|
from __gin__ import dynamic_registration |
|
import tasks_v4 |
|
|
|
import __main__ as train_script |
|
from t5.data import mixtures |
|
from t5x import models |
|
from t5x import partitioning |
|
from t5x import utils |
|
|
|
# Model architecture. This must match the size of INITIAL_CHECKPOINT_PATH
# below: the mT5-Large checkpoint would not load into an mT5-Small network
# (parameter shapes differ), so the Large model config is included here.
include "t5x/examples/t5/mt5/large.gin"

# Generic fine-tuning run defaults; the macros below override them.
include "t5x/configs/runs/finetune.gin"

# Task to fine-tune on; presumably registered by the `tasks_v4` import above.
MIXTURE_OR_TASK_NAME = "sentencefix"
# Maximum sequence length for each feature fed to the model.
TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256}
# Total training steps. Assumed to include the 1,000,000 steps already stored
# in INITIAL_CHECKPOINT_PATH, i.e. 1M pre-training + 100k fine-tuning steps --
# confirm against the T5X version in use.
TRAIN_STEPS = 1_100_000
# Use the task's data directly instead of a pre-cached SeqIO copy.
USE_CACHED_TASKS = False
# Dropout disabled for fine-tuning.
DROPOUT_RATE = 0.0
# Fixed random seed.
RANDOM_SEED = 0

# The public mT5 checkpoints were pre-trained with Mesh TensorFlow, for which
# T5X recommends normalizing the loss by a constant equal to
# (pre-training batch size) * (target length) = 1024 * 229 for mT5.
# Gin cannot evaluate arithmetic, hence the literal.
LOSS_NORMALIZING_FACTOR = 234496

# Pre-trained mT5-Large weights to start fine-tuning from (step 1,000,000).
INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_large/checkpoint_1000000"
|
|
|
# Parameters of the training entry point (`train` in the launching script),
# written as fully-qualified bindings instead of a scoped block.
# Evaluate every 500 training steps.
train_script.train.eval_period = 500
# Use T5X's model-based pjit partitioner; its partition count is bound
# separately on `partitioning.ModelBasedPjitPartitioner`.
train_script.train.partitioner = @partitioning.ModelBasedPjitPartitioner()
|
|
|
|
|
# Number of model-parallel partitions for the pjit partitioner that
# train_script.train is configured to use.
partitioning.ModelBasedPjitPartitioner.num_partitions = 2

# Number of decodes produced per input by predict_batch_with_aux during
# inference-based evaluation (presumably the beam width when beam search is
# the decoding method -- confirm against the model config).
models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|