pere committed on
Commit
b9e7413
1 Parent(s): 3960395

New script for partitioning, based on base_wmt14enfr.

Browse files
finetune_large_mt5_sentencefix_v4_16.gin CHANGED
@@ -25,15 +25,35 @@ RANDOM_SEED = 0
25
  INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_large/checkpoint_1000000"
26
 
27
  train_script.train:
28
- eval_period = 500
 
 
 
 
 
 
 
 
 
29
 
30
- # `num_decodes` is equivalent to a beam size in a beam search decoding.
31
  models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4
32
 
33
- partitioning.PjitPartitioner.num_partitions = 4
 
 
 
 
 
 
 
 
 
34
 
35
- #from t5.models import mesh_transformer
36
- #import t5.models
37
- #mesh_transformer.learning_rate_schedules.constant_learning_rate.learning_rate = 0.0005
38
- #run.learning_rate_schedule = @learning_rate_schedules.constant_learning_rate
 
 
 
39
 
 
25
  INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_large/checkpoint_1000000"
26
 
27
  train_script.train:
28
+ eval_period = 100
29
+
30
+ utils.RestoreCheckpointConfig:
31
+ path = %INITIAL_CHECKPOINT_PATH
32
+ mode = 'specific'
33
+
34
+ train_script.train:
35
+ train_dataset_cfg = @train/utils.DatasetConfig()
36
+ train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
37
+ infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
38
 
 
39
  models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4
40
 
41
+ infer_eval/utils.DatasetConfig:
42
+ mixture_or_task_name = %MIXTURE_OR_TASK_NAME
43
+ task_feature_lengths = %TASK_FEATURE_LENGTHS
44
+ split = 'validation'
45
+ batch_size = 64
46
+ shuffle = False
47
+ seed = 42
48
+ use_cached = %USE_CACHED_TASKS
49
+ pack = False
50
+ module = %MIXTURE_OR_TASK_MODULE
51
 
52
+ seqio.Evaluator:
53
+ logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
54
+ num_examples = None # Use all examples in the dataset.
55
+ use_memory_cache = True
56
+
57
+
58
+ partitioning.PjitPartitioner.num_partitions = 4
59