New script for partitioning, based on base_wmt14enfr.
Browse files
finetune_large_mt5_sentencefix_v4_16.gin
CHANGED
@@ -25,15 +25,35 @@ RANDOM_SEED = 0
|
|
25 |
INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_large/checkpoint_1000000"
|
26 |
|
27 |
train_script.train:
|
28 |
-
eval_period =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
# `num_decodes` is equivalent to a beam size in a beam search decoding.
|
31 |
models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4
|
32 |
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
39 |
|
|
|
25 |
INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_large/checkpoint_1000000"
|
26 |
|
27 |
train_script.train:
|
28 |
+
eval_period = 100
|
29 |
+
|
30 |
+
utils.RestoreCheckpointConfig:
|
31 |
+
path = %INITIAL_CHECKPOINT_PATH
|
32 |
+
mode = 'specific'
|
33 |
+
|
34 |
+
train_script.train:
|
35 |
+
train_dataset_cfg = @train/utils.DatasetConfig()
|
36 |
+
train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
|
37 |
+
infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
|
38 |
|
|
|
39 |
models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4
|
40 |
|
41 |
+
infer_eval/utils.DatasetConfig:
|
42 |
+
mixture_or_task_name = %MIXTURE_OR_TASK_NAME
|
43 |
+
task_feature_lengths = %TASK_FEATURE_LENGTHS
|
44 |
+
split = 'validation'
|
45 |
+
batch_size = 64
|
46 |
+
shuffle = False
|
47 |
+
seed = 42
|
48 |
+
use_cached = %USE_CACHED_TASKS
|
49 |
+
pack = False
|
50 |
+
module = %MIXTURE_OR_TASK_MODULE
|
51 |
|
52 |
+
seqio.Evaluator:
|
53 |
+
logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
|
54 |
+
num_examples = None # Use all examples in the dataset.
|
55 |
+
use_memory_cache = True
|
56 |
+
|
57 |
+
|
58 |
+
partitioning.PjitPartitioner.num_partitions = 4
|
59 |
|