from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl, logging | |
class BreakEachEpoch(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that requests a training stop at the end of every epoch.

    In :meth:`on_epoch_end` it sets ``control.should_training_stop`` to ``True``, which the
    ``Trainer`` treats as a request to end training. Going by the log message it emits, the
    intent is to hand control back to the caller after each epoch, presumably so the next
    shard of the dataset can be loaded before training is resumed (confirm against the caller).
    """

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Flag the trainer to stop once the current epoch has finished.

        Args:
            args: The training arguments (not used by this callback).
            state: The current trainer state (not used by this callback).
            control: The trainer control object; its ``should_training_stop`` flag is set here.
            **kwargs: Any extra objects passed by the callback handler (not used).

        Returns:
            The same ``control`` object, with ``should_training_stop`` set to ``True``.
        """
        # Stop after every epoch (not just the last one) so the outer driver can
        # swap in a new data shard between epochs.
        control.should_training_stop = True
        # get_logger() with no name uses the transformers library logger, so this
        # INFO message follows the library's configured verbosity.
        logging.get_logger().info("Break each epoch for reload new shard dataset")
        return control