pere commited on
Commit
75a76d3
1 Parent(s): 9be6ad3
finetune_large_mt5_sentencefix_v4_16.gin CHANGED
@@ -25,35 +25,17 @@ RANDOM_SEED = 0
25
  INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_large/checkpoint_1000000"
26
 
27
  train_script.train:
28
- eval_period = 100
29
- partitioner = @partitioning.PjitPartitioner()
30
 
31
- utils.RestoreCheckpointConfig:
32
- path = %INITIAL_CHECKPOINT_PATH
33
- mode = 'specific'
34
 
35
- #train_script.train:
36
- # train_dataset_cfg = @train/utils.DatasetConfig()
37
- # train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
38
- # infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
39
 
40
- models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4
41
 
42
- infer_eval/utils.DatasetConfig:
43
- mixture_or_task_name = %MIXTURE_OR_TASK_NAME
44
- task_feature_lengths = %TASK_FEATURE_LENGTHS
45
- split = 'validation'
46
- batch_size = 64
47
- shuffle = False
48
- seed = 42
49
- use_cached = %USE_CACHED_TASKS
50
- pack = False
51
- module = %MIXTURE_OR_TASK_MODULE
52
-
53
- partitioning.PjitPartitioner:
54
- num_partitions = 4
55
- model_parallel_submesh = None
56
- logical_axis_rules = @partitioning.standard_logical_axis_rules()
57
-
58
- #partitioning.PjitPartitioner.num_partitions = 4
59
 
 
25
  INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/mt5_large/checkpoint_1000000"
26
 
27
  train_script.train:
28
+ eval_period = 500
29
+ partitioner = @partitioning.ModelBasedPjitPartitioner()
30
 
31
+ # `num_decodes` is equivalent to a beam size in a beam search decoding.
32
+ models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4
 
33
 
34
+ partitioning.ModelBasedPjitPartitioner.num_partitions = 2
 
 
 
35
 
 
36
 
37
+ #from t5.models import mesh_transformer
38
+ #import t5.models
39
+ #mesh_transformer.learning_rate_schedules.constant_learning_rate.learning_rate = 0.0005
40
+ #run.learning_rate_schedule = @learning_rate_schedules.constant_learning_rate
 
 
 
 
 
 
 
 
 
 
 
 
 
41