Commit
·
cbf9056
1
Parent(s):
f1bbf33
fix trainer
Browse files
main.py
CHANGED
@@ -123,6 +123,7 @@ if __name__ == "__main__":
|
|
123 |
# save_steps=5,
|
124 |
# eval_steps=5,
|
125 |
)
|
|
|
126 |
|
127 |
# PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
|
128 |
last_checkpoint_path = None
|
@@ -163,16 +164,20 @@ if __name__ == "__main__":
|
|
163 |
)
|
164 |
test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
|
165 |
# Init trainer
|
166 |
-
trainer
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
|
|
176 |
|
177 |
logging.get_logger().info('Train shard idx: {} / {}'.format(train_dataset_shard_idx + 1, num_train_shards))
|
178 |
logging.get_logger().info(
|
|
|
123 |
# save_steps=5,
|
124 |
# eval_steps=5,
|
125 |
)
|
126 |
+
trainer = None
|
127 |
|
128 |
# PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
|
129 |
last_checkpoint_path = None
|
|
|
164 |
)
|
165 |
test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
|
166 |
# Init trainer
|
167 |
+
if trainer is None:
|
168 |
+
trainer = Trainer(
|
169 |
+
model=w2v_ctc_model,
|
170 |
+
data_collator=data_collator,
|
171 |
+
args=training_args,
|
172 |
+
compute_metrics=compute_metrics_fn(w2v_ctc_processor),
|
173 |
+
train_dataset=train_dataset,
|
174 |
+
eval_dataset=test_dataset,
|
175 |
+
tokenizer=w2v_ctc_processor.feature_extractor,
|
176 |
+
callbacks=[BreakEachEpoch()] # Manually break at the end of each epoch, because each epoch loops over a single shard
|
177 |
+
)
|
178 |
+
else:
|
179 |
+
trainer.train_dataset = train_dataset
|
180 |
+
trainer.eval_dataset = test_dataset
|
181 |
|
182 |
logging.get_logger().info('Train shard idx: {} / {}'.format(train_dataset_shard_idx + 1, num_train_shards))
|
183 |
logging.get_logger().info(
|