FlaxAutoModelForSeq2SeqLM bug fixed
Browse files
events.out.tfevents.1625830870.t1v-n-55481057-w-0.76616.3.v2
ADDED
Binary file (89.5 kB). View file
|
|
events.out.tfevents.1625831337.t1v-n-55481057-w-0.78122.3.v2
ADDED
Binary file (580 kB). View file
|
|
run_translation_t5_flax.py
CHANGED
@@ -50,6 +50,7 @@ from transformers import (
|
|
50 |
FlaxAutoModelForSeq2SeqLM,
|
51 |
HfArgumentParser,
|
52 |
TrainingArguments,
|
|
|
53 |
is_tensorboard_available,
|
54 |
)
|
55 |
from transformers.file_utils import is_offline_mode
|
@@ -374,11 +375,11 @@ def main():
|
|
374 |
)
|
375 |
|
376 |
if model_args.model_name_or_path:
|
377 |
-
model =
|
378 |
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
379 |
)
|
380 |
else:
|
381 |
-
model =
|
382 |
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
383 |
)
|
384 |
|
|
|
50 |
FlaxAutoModelForSeq2SeqLM,
|
51 |
HfArgumentParser,
|
52 |
TrainingArguments,
|
53 |
+
FlaxT5ForConditionalGeneration,
|
54 |
is_tensorboard_available,
|
55 |
)
|
56 |
from transformers.file_utils import is_offline_mode
|
|
|
375 |
)
|
376 |
|
377 |
if model_args.model_name_or_path:
|
378 |
+
model = FlaxT5ForConditionalGeneration.from_pretrained(
|
379 |
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
380 |
)
|
381 |
else:
|
382 |
+
model = FlaxT5ForConditionalGeneration.from_config(
|
383 |
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
384 |
)
|
385 |
|