pere committed on
Commit
d809591
1 Parent(s): 289029c

FlaxAutoModelForSeq2SeqLM bug fixed

Browse files
events.out.tfevents.1625830870.t1v-n-55481057-w-0.76616.3.v2 ADDED
Binary file (89.5 kB). View file
 
events.out.tfevents.1625831337.t1v-n-55481057-w-0.78122.3.v2 ADDED
Binary file (580 kB). View file
 
run_translation_t5_flax.py CHANGED
@@ -50,6 +50,7 @@ from transformers import (
50
  FlaxAutoModelForSeq2SeqLM,
51
  HfArgumentParser,
52
  TrainingArguments,
 
53
  is_tensorboard_available,
54
  )
55
  from transformers.file_utils import is_offline_mode
@@ -374,11 +375,11 @@ def main():
374
  )
375
 
376
  if model_args.model_name_or_path:
377
- model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
378
  model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
379
  )
380
  else:
381
- model = FlaxAutoModelForSeq2SeqLM.from_config(
382
  config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
383
  )
384
 
 
50
  FlaxAutoModelForSeq2SeqLM,
51
  HfArgumentParser,
52
  TrainingArguments,
53
+ FlaxT5ForConditionalGeneration,
54
  is_tensorboard_available,
55
  )
56
  from transformers.file_utils import is_offline_mode
 
375
  )
376
 
377
  if model_args.model_name_or_path:
378
+ model = FlaxT5ForConditionalGeneration.from_pretrained(
379
  model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
380
  )
381
  else:
382
+ model = FlaxT5ForConditionalGeneration.from_config(
383
  config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
384
  )
385