nguyenvulebinh committed on
Commit
cbf9056
·
1 Parent(s): f1bbf33

fix trainer

Browse files
Files changed (1) hide show
  1. main.py +15 -10
main.py CHANGED
@@ -123,6 +123,7 @@ if __name__ == "__main__":
123
  # save_steps=5,
124
  # eval_steps=5,
125
  )
 
126
 
127
  # PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
128
  last_checkpoint_path = None
@@ -163,16 +164,20 @@ if __name__ == "__main__":
163
  )
164
  test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
165
  # Init trainer
166
- trainer = Trainer(
167
- model=w2v_ctc_model,
168
- data_collator=data_collator,
169
- args=training_args,
170
- compute_metrics=compute_metrics_fn(w2v_ctc_processor),
171
- train_dataset=train_dataset,
172
- eval_dataset=test_dataset,
173
- tokenizer=w2v_ctc_processor.feature_extractor,
174
- callbacks=[BreakEachEpoch()] # Manual break end of epoch because each epoch loop over a shard
175
- )
 
 
 
 
176
 
177
  logging.get_logger().info('Train shard idx: {} / {}'.format(train_dataset_shard_idx + 1, num_train_shards))
178
  logging.get_logger().info(
 
123
  # save_steps=5,
124
  # eval_steps=5,
125
  )
126
+ trainer = None
127
 
128
  # PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
129
  last_checkpoint_path = None
 
164
  )
165
  test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
166
  # Init trainer
167
+ if trainer is None:
168
+ trainer = Trainer(
169
+ model=w2v_ctc_model,
170
+ data_collator=data_collator,
171
+ args=training_args,
172
+ compute_metrics=compute_metrics_fn(w2v_ctc_processor),
173
+ train_dataset=train_dataset,
174
+ eval_dataset=test_dataset,
175
+ tokenizer=w2v_ctc_processor.feature_extractor,
176
+ callbacks=[BreakEachEpoch()] # Manual break end of epoch because each epoch loop over a shard
177
+ )
178
+ else:
179
+ trainer.train_dataset = train_dataset
180
+ trainer.eval_dataset = test_dataset
181
 
182
  logging.get_logger().info('Train shard idx: {} / {}'.format(train_dataset_shard_idx + 1, num_train_shards))
183
  logging.get_logger().info(