|
# Gin configuration: fine-tune mT5-large on the "sentencefix" task.
#
# NOTE(review): this file was recovered from a corrupted copy (table-extraction
# residue stripped the right-hand side of most bindings). Values marked
# "restored" below follow the standard t5x finetune recipe included here —
# confirm each against the original config before launching a run.

from __gin__ import dynamic_registration

import tasks_v4  # project-local module; registering it makes "sentencefix" resolvable by name

import __main__ as train_script
from t5.data import mixtures
from t5x import models
from t5x import partitioning
from t5x import utils

include "t5x/examples/t5/mt5/large.gin"
include "t5x/configs/runs/finetune.gin"

MIXTURE_OR_TASK_NAME = "sentencefix"
TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256}
TRAIN_STEPS = 1_200_000  # 1000000 pre-trained steps + 200000 fine-tuning steps.
USE_CACHED_TASKS = False
DROPOUT_RATE = 0.0
RANDOM_SEED = 0

# `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained
# using Mesh Tensorflow (e.g. the public T5 / mT5 / ByT5 models), this should be
# set to `pretraining batch_size` * `target_token_length`. For T5 and T5.1.1:
# `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`.
#LOSS_NORMALIZING_FACTOR = 234496

# Restored: standard public mT5-large checkpoint at 1M pre-training steps
# (consistent with TRAIN_STEPS above) — verify this is the intended path.
INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_large/checkpoint_1000000"

train_script.train:
  eval_period = 500  # restored: evaluate every 500 steps (finetune.gin default)
  partitioner = @partitioning.PjitPartitioner()

utils.RestoreCheckpointConfig:
  path = %INITIAL_CHECKPOINT_PATH
  mode = 'specific'  # restore exactly the checkpoint at `path`, not the latest in a dir

# Commented-out override kept from the original file (dataset configs come
# from the included finetune.gin):
#train_script.train:
#  train_dataset_cfg = @train/utils.DatasetConfig()
#  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
#  infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()

# Beam-search width used when decoding during inference-time evaluation.
models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4

# Restored per the finetune.gin recipe — confirm split/batch_size/seed.
infer_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'validation'
  batch_size = 32
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False
  module = None  # task is registered via `import tasks_v4` above

partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None
  logical_axis_rules = @partitioning.standard_logical_axis_rules()

#partitioning.PjitPartitioner.num_partitions = 2