File size: 553 Bytes
1e275bf
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl, logging


class BreakEachEpoch(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that stops training at the end of every epoch.

    On each ``on_epoch_end`` event it sets ``control.should_training_stop``, which
    asks the ``Trainer`` to leave its training loop after the current epoch.
    Per the log message, this lets the caller reload a new shard of the dataset
    between epochs. Presumably the caller then resumes training (e.g. by calling
    ``train()`` again, possibly from a checkpoint) — NOTE(review): confirm against
    the calling code.
    """

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Request that training stop once the current epoch has ended.

        Args:
            args: The training arguments (unused).
            state: The current trainer state (unused).
            control: The trainer's control object; mutated in place.
            **kwargs: Additional objects passed by the Trainer (unused).

        Returns:
            The same ``control`` object, with ``should_training_stop`` set to ``True``.
        """
        control.should_training_stop = True
        # get_logger() without a name returns the transformers library root logger,
        # so this INFO message follows transformers' verbosity setting and may be
        # hidden when verbosity is WARNING or higher.
        logging.get_logger().info("Break each epoch for reload new shard dataset")
        return control