diff --git "a/lmflow/pipeline/utils/raft_trainer.py" "b/lmflow/pipeline/utils/raft_trainer.py"
new file mode 100644--- /dev/null
+++ "b/lmflow/pipeline/utils/raft_trainer.py"
@@ -0,0 +1,3782 @@
+import contextlib
+import functools
+import glob
+import inspect
+import math
+import os
+import random
+import re
+import shutil
+import sys
+import time
+import warnings
+from collections.abc import Mapping
+from distutils.util import strtobool
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+
+from tqdm.auto import tqdm
+
+
+# Integrations must be imported before ML frameworks:
+# isort: off
+from transformers.integrations import (
+ default_hp_search_backend,
+ get_reporting_integration_callbacks,
+ hp_params,
+ is_fairscale_available,
+ is_optuna_available,
+ is_ray_tune_available,
+ is_sigopt_available,
+ is_wandb_available,
+ run_hp_search_optuna,
+ run_hp_search_ray,
+ run_hp_search_sigopt,
+ run_hp_search_wandb,
+)
+
+# isort: on
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from huggingface_hub import Repository, create_repo
+from packaging import version
+from torch import nn
+from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
+from torch.utils.data.distributed import DistributedSampler
+
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
+from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
+from transformers.deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
+from transformers.dependency_versions_check import dep_version_check
+from transformers.modelcard import TrainingSummary
+from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
+from transformers.optimization import Adafactor, get_scheduler
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from transformers.trainer_callback import (
+ CallbackHandler,
+ DefaultFlowCallback,
+ PrinterCallback,
+ ProgressCallback,
+ TrainerCallback,
+ TrainerControl,
+ TrainerState,
+)
+from transformers.trainer_pt_utils import (
+ DistributedLengthGroupedSampler,
+ DistributedSamplerWithLoop,
+ DistributedTensorGatherer,
+ IterableDatasetShard,
+ LabelSmoother,
+ LengthGroupedSampler,
+ SequentialDistributedSampler,
+ ShardSampler,
+ distributed_broadcast_scalars,
+ distributed_concat,
+ find_batch_size,
+ get_module_class_from_name,
+ get_parameter_names,
+ nested_concat,
+ nested_detach,
+ nested_numpify,
+ nested_truncate,
+ nested_xla_mesh_reduce,
+ reissue_pt_warnings,
+)
+from transformers.trainer_utils import (
+ PREFIX_CHECKPOINT_DIR,
+ BestRun,
+ EvalLoopOutput,
+ EvalPrediction,
+ FSDPOption,
+ HPSearchBackend,
+ HubStrategy,
+ IntervalStrategy,
+ PredictionOutput,
+ RemoveColumnsCollator,
+ ShardedDDPOption,
+ TrainerMemoryTracker,
+ TrainOutput,
+ default_compute_objective,
+ default_hp_space,
+ denumpify_detensorize,
+ enable_full_determinism,
+ find_executable_batch_size,
+ get_last_checkpoint,
+ has_length,
+ number_of_arguments,
+ seed_worker,
+ set_seed,
+ speed_metrics,
+)
+from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
+from transformers.utils import (
+ CONFIG_NAME,
+ WEIGHTS_INDEX_NAME,
+ WEIGHTS_NAME,
+ can_return_loss,
+ find_labels,
+ get_full_repo_name,
+ is_accelerate_available,
+ is_apex_available,
+ is_datasets_available,
+ is_in_notebook,
+ is_ipex_available,
+ is_sagemaker_dp_enabled,
+ is_sagemaker_mp_enabled,
+ is_torch_compile_available,
+ is_torch_neuroncore_available,
+ is_torch_tpu_available,
+ logging,
+)
+from transformers.utils.generic import ContextManagers
+
+
+_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10
+
+DEFAULT_CALLBACKS = [DefaultFlowCallback]
+DEFAULT_PROGRESS_CALLBACK = ProgressCallback
+
+if is_in_notebook():
+ from transformers.utils.notebook import NotebookProgressCallback
+
+ DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
+
+if is_apex_available():
+ from apex import amp
+
+if is_datasets_available():
+ import datasets
+
+if is_torch_tpu_available(check_device=False):
+ import torch_xla.core.xla_model as xm
+ import torch_xla.debug.metrics as met
+ import torch_xla.distributed.parallel_loader as pl
+
+if is_fairscale_available():
+ dep_version_check("fairscale")
+ import fairscale
+ from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
+ from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
+ from fairscale.nn.wrap import auto_wrap
+ from fairscale.optim import OSS
+ from fairscale.optim.grad_scaler import ShardedGradScaler
+
+
+if is_sagemaker_mp_enabled():
+ import smdistributed.modelparallel.torch as smp
+ from smdistributed.modelparallel import __version__ as SMP_VERSION
+
+ IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
+
+ from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
+else:
+ IS_SAGEMAKER_MP_POST_1_10 = False
+
+
+skip_first_batches = None
+
+
+
+logger = logging.get_logger(__name__)
+
+
+# Name of the files used for checkpointing
+TRAINING_ARGS_NAME = "training_args.bin"
+TRAINER_STATE_NAME = "trainer_state.json"
+OPTIMIZER_NAME = "optimizer.pt"
+SCHEDULER_NAME = "scheduler.pt"
+SCALER_NAME = "scaler.pt"
+
+
+class RaftTrainer:
+ """
+ Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
+ Args:
+ model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
+ The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
+
+ [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
+ your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
+ models.
+
+ args ([`TrainingArguments`], *optional*):
+ The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
+ `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
+ data_collator (`DataCollator`, *optional*):
+ The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
+ default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
+ [`DataCollatorWithPadding`] otherwise.
+ train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
+ The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
+ `model.forward()` method are automatically removed.
+ Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
+ distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a
+ `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
+ manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
+ sets the seed of the RNGs used.
+ eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*):
+ The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
+ `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
+ dataset prepending the dictionary key to the metric name.
+ tokenizer ([`PreTrainedTokenizerBase`], *optional*):
+ The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
+ maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
+ interrupted training or reuse the fine-tuned model.
+ model_init (`Callable[[], PreTrainedModel]`, *optional*):
+ A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
+ from a new instance of the model as given by this function.
+ The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
+ be able to choose different architectures according to hyper parameters (such as layer count, sizes of
+ inner layers, dropout probabilities etc).
+ compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
+ The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
+ a dictionary string to metric values.
+ callbacks (List of [`TrainerCallback`], *optional*):
+ A list of callbacks to customize the training loop. Will add those to the list of default callbacks
+ detailed in [here](callback).
+ If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
+ optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple
+ containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model
+ and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
+ A function that preprocess the logits right before caching them at each evaluation step. Must take two
+ tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
+ by this function will be reflected in the predictions received by `compute_metrics`.
+ Note that the labels (second parameter) will be `None` if the dataset does not have them.
+ Important attributes:
+ - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
+ subclass.
+ - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
+ original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
+ the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
+ model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
+ - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
+ data parallelism, this means some of the model layers are split on different GPUs).
+ - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
+ to `False` if model parallel or deepspeed is used, or if the default
+ `TrainingArguments.place_model_on_device` is overridden to return `False` .
+ - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
+ in `train`)
+ """
+
+ from transformers.trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
+
+ def __init__(
+ self,
+ model: Union[PreTrainedModel, nn.Module] = None,
+ args: TrainingArguments = None,
+ data_collator: Optional[DataCollator] = None,
+ train_dataset: Optional[Dataset] = None,
+ eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
+ compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
+ callbacks: Optional[List[TrainerCallback]] = None,
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+ ):
+ ############
+ self.save_counter = 0
+ ##############
+ if args is None:
+ output_dir = "tmp_trainer"
+ logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
+ args = TrainingArguments(output_dir=output_dir)
+ self.args = args
+ # Seed must be set before instantiating the model when using model
+ enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
+ self.hp_name = None
+ self.deepspeed = None
+ self.is_in_train = False
+
+ # memory metrics - must set up as early as possible
+ self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
+ self._memory_tracker.start()
+
+ # set the correct log level depending on the node
+ log_level = args.get_process_log_level()
+ logging.set_verbosity(log_level)
+
+ # force device and distributed setup init explicitly
+ args._setup_devices
+
+ if model is None:
+ if model_init is not None:
+ self.model_init = model_init
+ model = self.call_model_init()
+ else:
+ raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
+ else:
+ if model_init is not None:
+ warnings.warn(
+ "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
+ " overwrite your model when calling the `train` method. This will become a fatal error in the next"
+ " release.",
+ FutureWarning,
+ )
+ self.model_init = model_init
+
+ if model.__class__.__name__ in MODEL_MAPPING_NAMES:
+ raise ValueError(
+ f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
+ "computes hidden states and does not accept any labels. You should choose a model with a head "
+ "suitable for your task like any of the `AutoModelForXxx` listed at "
+ "https://huggingface.co/docs/transformers/model_doc/auto."
+ )
+
+ if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
+ self.is_model_parallel = True
+ else:
+ self.is_model_parallel = False
+
+ # At this stage the model is already loaded
+ if getattr(model, "is_loaded_in_8bit", False):
+ if getattr(model, "_is_int8_training_enabled", False):
+ logger.info(
+ "The model is loaded in 8-bit precision. To train this model you need to add additional modules"
+ " inside the model such as adapters using `peft` library and freeze the model weights. Please"
+ " check "
+ " the examples in https://github.com/huggingface/peft for more details."
+ )
+ else:
+ raise ValueError(
+ "The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit"
+ " model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
+ )
+
+ # Setup Sharded DDP training
+ self.sharded_ddp = None
+ if len(args.sharded_ddp) > 0:
+ if args.deepspeed:
+ raise ValueError(
+ "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
+ )
+ if len(args.fsdp) > 0:
+ raise ValueError(
+ "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
+ )
+
+ if args.local_rank == -1:
+ raise ValueError("Using sharded DDP only works in distributed training.")
+ elif not is_fairscale_available():
+ raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
+ elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
+ raise ImportError(
+ "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
+ f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
+ )
+ elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
+ self.sharded_ddp = ShardedDDPOption.SIMPLE
+ elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
+ self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
+ elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
+ self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
+
+ self.fsdp = None
+ if len(args.fsdp) > 0:
+ if args.deepspeed:
+ raise ValueError(
+ "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
+ )
+ if not args.fsdp_config["xla"] and args.local_rank == -1:
+ raise ValueError("Using fsdp only works in distributed training.")
+
+ # dep_version_check("torch>=1.12.0")
+ # Would have to update setup.py with torch>=1.12.0
+ # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0
+ # below is the current alternative.
+ if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
+ raise ValueError("FSDP requires PyTorch >= 1.12.0")
+
+ from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy
+
+ if FSDPOption.FULL_SHARD in args.fsdp:
+ self.fsdp = ShardingStrategy.FULL_SHARD
+ elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
+ self.fsdp = ShardingStrategy.SHARD_GRAD_OP
+ elif FSDPOption.NO_SHARD in args.fsdp:
+ self.fsdp = ShardingStrategy.NO_SHARD
+
+ self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
+ if "backward_prefetch" in self.args.fsdp_config and "backward_pos" not in self.backward_prefetch:
+ self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
+
+ self.forword_prefetch = False
+ if self.args.fsdp_config.get("forword_prefect", False):
+ self.forword_prefetch = True
+
+ self.limit_all_gathers = False
+ if self.args.fsdp_config.get("limit_all_gathers", False):
+ self.limit_all_gathers = True
+
+ # one place to sort out whether to place the model on device or not
+ # postpone switching model to cuda when:
+ # 1. MP - since we are trying to fit a much bigger than 1 gpu model
+ # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
+ # and we only use deepspeed for training at the moment
+ # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
+ # 4. Sharded DDP - same as MP
+ # 5. FSDP - same as MP
+ self.place_model_on_device = args.place_model_on_device
+ if (
+ self.is_model_parallel
+ or args.deepspeed
+ or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
+ or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
+ or (self.fsdp is not None)
+ ):
+ self.place_model_on_device = False
+
+ default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
+ self.data_collator = data_collator if data_collator is not None else default_collator
+ self.train_dataset = train_dataset
+ self.eval_dataset = eval_dataset
+ self.tokenizer = tokenizer
+
+ if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False):
+ self._move_model_to_device(model, args.device)
+
+ # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
+ if self.is_model_parallel:
+ self.args._n_gpu = 1
+
+ # later use `self.model is self.model_wrapped` to check if it's wrapped or not
+ self.model_wrapped = model
+ self.model = model
+
+ self.compute_metrics = compute_metrics
+ self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
+ self.optimizer, self.lr_scheduler = optimizers
+ if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
+ raise RuntimeError(
+ "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
+ "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
+ )
+ if is_torch_tpu_available() and self.optimizer is not None:
+ for param in self.model.parameters():
+ model_device = param.device
+ break
+ for param_group in self.optimizer.param_groups:
+ if len(param_group["params"]) > 0:
+ optimizer_device = param_group["params"][0].device
+ break
+ if model_device != optimizer_device:
+ raise ValueError(
+ "The model and the optimizer parameters are not on the same device, which probably means you"
+ " created an optimizer around your model **before** putting on the device and passing it to the"
+ " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
+ " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
+ )
+ if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and (
+ self.optimizer is not None or self.lr_scheduler is not None
+ ):
+ raise RuntimeError(
+ "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
+ "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
+ )
+ default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
+ callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
+ self.callback_handler = CallbackHandler(
+ callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
+ )
+ self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
+
+ # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
+ self._loggers_initialized = False
+
+ # Create clone of distant repo and output directory if needed
+ if self.args.push_to_hub:
+ self.init_git_repo(at_init=True)
+ # In case of pull, we need to make sure every process has the latest.
+ if is_torch_tpu_available():
+ xm.rendezvous("init git repo")
+ elif args.local_rank != -1:
+ dist.barrier()
+
+ if self.args.should_save:
+ os.makedirs(self.args.output_dir, exist_ok=True)
+
+ if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
+ raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
+
+ if args.max_steps > 0:
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
+
+ if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
+ raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
+
+ if (
+ train_dataset is not None
+ and isinstance(train_dataset, torch.utils.data.IterableDataset)
+ and args.group_by_length
+ ):
+ raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset")
+
+ self._signature_columns = None
+
+ # Mixed precision setup
+ self.use_apex = False
+ self.use_cuda_amp = False
+ self.use_cpu_amp = False
+
+ # Mixed precision setup for SageMaker Model Parallel
+ if is_sagemaker_mp_enabled():
+ # BF16 + model parallelism in SageMaker: currently not supported, raise an error
+ if args.bf16:
+ raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
+
+ if IS_SAGEMAKER_MP_POST_1_10:
+ # When there's mismatch between SMP config and trainer argument, use SMP config as truth
+ if args.fp16 != smp.state.cfg.fp16:
+ logger.warning(
+ f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16},"
+ f"but FP16 provided in trainer argument is {args.fp16},"
+ f"setting to {smp.state.cfg.fp16}"
+ )
+ args.fp16 = smp.state.cfg.fp16
+ else:
+ # smp < 1.10 does not support fp16 in trainer.
+ if hasattr(smp.state.cfg, "fp16"):
+ logger.warning(
+ f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
+ "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
+ )
+
+ if args.fp16 or args.bf16:
+ if args.half_precision_backend == "auto":
+ if args.device == torch.device("cpu"):
+ if args.fp16:
+ raise ValueError("Tried to use `fp16` but it is not supported on cpu")
+ elif _is_native_cpu_amp_available:
+ args.half_precision_backend = "cpu_amp"
+ else:
+ raise ValueError("Tried to use cpu amp but native cpu amp is not available")
+ else:
+ args.half_precision_backend = "cuda_amp"
+
+ logger.info(f"Using {args.half_precision_backend} half precision backend")
+
+ self.do_grad_scaling = False
+ if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()):
+ # deepspeed and SageMaker Model Parallel manage their own half precision
+ if args.half_precision_backend == "cuda_amp":
+ self.use_cuda_amp = True
+ self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
+ # bf16 does not need grad scaling
+ self.do_grad_scaling = self.amp_dtype == torch.float16
+ if self.do_grad_scaling:
+ if self.sharded_ddp is not None:
+ self.scaler = ShardedGradScaler()
+ elif self.fsdp is not None:
+ from torch.distributed.fsdp.sharded_grad_scaler import (
+ ShardedGradScaler as FSDPShardedGradScaler,
+ )
+
+ self.scaler = FSDPShardedGradScaler()
+ elif is_torch_tpu_available():
+ from torch_xla.amp import GradScaler
+
+ self.scaler = GradScaler()
+ else:
+ self.scaler = torch.cuda.amp.GradScaler()
+ elif args.half_precision_backend == "cpu_amp":
+ self.use_cpu_amp = True
+ self.amp_dtype = torch.bfloat16
+ else:
+ if not is_apex_available():
+ raise ImportError(
+ "Using FP16 with APEX but APEX is not installed, please refer to"
+ " https://www.github.com/nvidia/apex."
+ )
+ self.use_apex = True
+
+ # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
+ if (
+ is_sagemaker_mp_enabled()
+ and self.use_cuda_amp
+ and args.max_grad_norm is not None
+ and args.max_grad_norm > 0
+ ):
+ raise ValueError(
+ "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
+ "along 'max_grad_norm': 0 in your hyperparameters."
+ )
+
+ # Label smoothing
+ if self.args.label_smoothing_factor != 0:
+ self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
+ else:
+ self.label_smoother = None
+
+ self.state = TrainerState(
+ is_local_process_zero=self.is_local_process_zero(),
+ is_world_process_zero=self.is_world_process_zero(),
+ )
+
+ self.control = TrainerControl()
+ # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
+ # returned to 0 every time flos need to be logged
+ self.current_flos = 0
+ self.hp_search_backend = None
+ self.use_tune_checkpoints = False
+ default_label_names = find_labels(self.model.__class__)
+ self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
+ self.can_return_loss = can_return_loss(self.model.__class__)
+ self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
+
+ # Internal variables to keep track of the original batch size
+ self._train_batch_size = args.train_batch_size
+
+ # very last
+ self._memory_tracker.stop_and_update_metrics()
+
+ # torch.compile
+ if args.torch_compile and not is_torch_compile_available():
+ raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
+
+ def add_callback(self, callback):
+ """
+ Add a callback to the current list of [`~transformer.TrainerCallback`].
+ Args:
+ callback (`type` or [`~transformer.TrainerCallback`]):
+ A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the
+ first case, will instantiate a member of that class.
+ """
+ self.callback_handler.add_callback(callback)
+
+ def pop_callback(self, callback):
+ """
+ Remove a callback from the current list of [`~transformer.TrainerCallback`] and returns it.
+ If the callback is not found, returns `None` (and no error is raised).
+ Args:
+ callback (`type` or [`~transformer.TrainerCallback`]):
+ A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the
+ first case, will pop the first member of that class found in the list of callbacks.
+ Returns:
+ [`~transformer.TrainerCallback`]: The callback removed, if found.
+ """
+ return self.callback_handler.pop_callback(callback)
+
+ def remove_callback(self, callback):
+ """
+ Remove a callback from the current list of [`~transformer.TrainerCallback`].
+ Args:
+ callback (`type` or [`~transformer.TrainerCallback`]):
+ A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the
+ first case, will remove the first member of that class found in the list of callbacks.
+ """
+ self.callback_handler.remove_callback(callback)
+
+ def _move_model_to_device(self, model, device):
+ model = model.to(device)
+ # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
+ if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
+ model.tie_weights()
+
+ def _set_signature_columns_if_needed(self):
+ if self._signature_columns is None:
+ # Inspect model forward signature to keep only the arguments it accepts.
+ signature = inspect.signature(self.model.forward)
+ self._signature_columns = list(signature.parameters.keys())
+ # Labels may be named label or label_ids, the default data collator handles that.
+ self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
+
+ def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
+ if not self.args.remove_unused_columns:
+ return dataset
+ self._set_signature_columns_if_needed()
+ signature_columns = self._signature_columns
+
+ ignored_columns = list(set(dataset.column_names) - set(signature_columns))
+ if len(ignored_columns) > 0:
+ dset_description = "" if description is None else f"in the {description} set"
+ logger.info(
+ f"The following columns {dset_description} don't have a corresponding argument in "
+ f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
+ f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
+ " you can safely ignore this message."
+ )
+
+ columns = [k for k in signature_columns if k in dataset.column_names]
+
+ if version.parse(datasets.__version__) < version.parse("1.4.0"):
+ dataset.set_format(
+ type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
+ )
+ return dataset
+ else:
+ return dataset.remove_columns(ignored_columns)
+
+ def _get_collator_with_removed_columns(
+ self, data_collator: Callable, description: Optional[str] = None
+ ) -> Callable:
+ """Wrap the data collator in a callable removing unused columns."""
+ if not self.args.remove_unused_columns:
+ return data_collator
+ self._set_signature_columns_if_needed()
+ signature_columns = self._signature_columns
+
+ remove_columns_collator = RemoveColumnsCollator(
+ data_collator=data_collator,
+ signature_columns=signature_columns,
+ logger=logger,
+ description=description,
+ model_name=self.model.__class__.__name__,
+ )
+ return remove_columns_collator
+
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+ if self.train_dataset is None or not has_length(self.train_dataset):
+ return None
+
+ generator = None
+ if self.args.world_size <= 1:
+ generator = torch.Generator()
+ # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
+ # `args.seed`) if data_seed isn't provided.
+ # Further on in this method, we default to `args.seed` instead.
+ if self.args.data_seed is None:
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
+ else:
+ seed = self.args.data_seed
+ generator.manual_seed(seed)
+
+ seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
+
+ # Build the sampler.
+ if self.args.group_by_length:
+ if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
+ lengths = (
+ self.train_dataset[self.args.length_column_name]
+ if self.args.length_column_name in self.train_dataset.column_names
+ else None
+ )
+ else:
+ lengths = None
+ model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
+ if self.args.world_size <= 1:
+ return LengthGroupedSampler(
+ self.args.train_batch_size * self.args.gradient_accumulation_steps,
+ dataset=self.train_dataset,
+ lengths=lengths,
+ model_input_name=model_input_name,
+ generator=generator,
+ )
+ else:
+ return DistributedLengthGroupedSampler(
+ self.args.train_batch_size * self.args.gradient_accumulation_steps,
+ dataset=self.train_dataset,
+ num_replicas=self.args.world_size,
+ rank=self.args.process_index,
+ lengths=lengths,
+ model_input_name=model_input_name,
+ seed=seed,
+ )
+
+ else:
+ if self.args.world_size <= 1:
+ return RandomSampler(self.train_dataset, generator=generator)
+ elif (
+ self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
+ and not self.args.dataloader_drop_last
+ ):
+ # Use a loop for TPUs when drop_last is False to have all batches have the same size.
+ return DistributedSamplerWithLoop(
+ self.train_dataset,
+ batch_size=self.args.per_device_train_batch_size,
+ num_replicas=self.args.world_size,
+ rank=self.args.process_index,
+ seed=seed,
+ )
+ else:
+ return DistributedSampler(
+ self.train_dataset,
+ num_replicas=self.args.world_size,
+ rank=self.args.process_index,
+ seed=seed,
+ )
+
+ def get_train_dataloader(self) -> DataLoader:
+ """
+ Returns the training [`~torch.utils.data.DataLoader`].
+ Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
+ training if necessary) otherwise.
+ Subclass and override this method if you want to inject some custom behavior.
+ """
+ if self.train_dataset is None:
+ raise ValueError("Trainer: training requires a train_dataset.")
+
+ train_dataset = self.train_dataset
+ data_collator = self.data_collator
+ if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+ train_dataset = self._remove_unused_columns(train_dataset, description="training")
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
+
+ if isinstance(train_dataset, torch.utils.data.IterableDataset):
+ if self.args.world_size > 1:
+ train_dataset = IterableDatasetShard(
+ train_dataset,
+ batch_size=self._train_batch_size,
+ drop_last=self.args.dataloader_drop_last,
+ num_processes=self.args.world_size,
+ process_index=self.args.process_index,
+ )
+
+ return DataLoader(
+ train_dataset,
+ batch_size=self._train_batch_size,
+ collate_fn=data_collator,
+ num_workers=self.args.dataloader_num_workers,
+ pin_memory=self.args.dataloader_pin_memory,
+ )
+
+ train_sampler = self._get_train_sampler()
+
+ return DataLoader(
+ train_dataset,
+ batch_size=self._train_batch_size,
+ sampler=train_sampler,
+ collate_fn=data_collator,
+ drop_last=self.args.dataloader_drop_last,
+ num_workers=self.args.dataloader_num_workers,
+ pin_memory=self.args.dataloader_pin_memory,
+ worker_init_fn=seed_worker,
+ )
+
+ def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
+ # Deprecated code
+ if self.args.use_legacy_prediction_loop:
+ if is_torch_tpu_available():
+ return SequentialDistributedSampler(
+ eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
+ )
+ elif is_sagemaker_mp_enabled():
+ return SequentialDistributedSampler(
+ eval_dataset,
+ num_replicas=smp.dp_size(),
+ rank=smp.dp_rank(),
+ batch_size=self.args.per_device_eval_batch_size,
+ )
+ elif self.args.local_rank != -1:
+ return SequentialDistributedSampler(eval_dataset)
+ else:
+ return SequentialSampler(eval_dataset)
+
+ if self.args.world_size <= 1:
+ return SequentialSampler(eval_dataset)
+ else:
+ return ShardSampler(
+ eval_dataset,
+ batch_size=self.args.per_device_eval_batch_size,
+ num_processes=self.args.world_size,
+ process_index=self.args.process_index,
+ )
+
+ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
+ """
+ Returns the evaluation [`~torch.utils.data.DataLoader`].
+ Subclass and override this method if you want to inject some custom behavior.
+ Args:
+ eval_dataset (`torch.utils.data.Dataset`, *optional*):
+ If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
+ by the `model.forward()` method are automatically removed. It must implement `__len__`.
+ """
+ if eval_dataset is None and self.eval_dataset is None:
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+ data_collator = self.data_collator
+
+ if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
+ eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")
+
+ if isinstance(eval_dataset, torch.utils.data.IterableDataset):
+ if self.args.world_size > 1:
+ eval_dataset = IterableDatasetShard(
+ eval_dataset,
+ batch_size=self.args.per_device_eval_batch_size,
+ drop_last=self.args.dataloader_drop_last,
+ num_processes=self.args.world_size,
+ process_index=self.args.process_index,
+ )
+ return DataLoader(
+ eval_dataset,
+ batch_size=self.args.eval_batch_size,
+ collate_fn=data_collator,
+ num_workers=self.args.dataloader_num_workers,
+ pin_memory=self.args.dataloader_pin_memory,
+ )
+
+ eval_sampler = self._get_eval_sampler(eval_dataset)
+
+ return DataLoader(
+ eval_dataset,
+ sampler=eval_sampler,
+ batch_size=self.args.eval_batch_size,
+ collate_fn=data_collator,
+ drop_last=self.args.dataloader_drop_last,
+ num_workers=self.args.dataloader_num_workers,
+ pin_memory=self.args.dataloader_pin_memory,
+ )
+
+ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
+ """
+ Returns the test [`~torch.utils.data.DataLoader`].
+ Subclass and override this method if you want to inject some custom behavior.
+ Args:
+ test_dataset (`torch.utils.data.Dataset`, *optional*):
+ The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
+ `model.forward()` method are automatically removed. It must implement `__len__`.
+ """
+ data_collator = self.data_collator
+
+ if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
+ test_dataset = self._remove_unused_columns(test_dataset, description="test")
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description="test")
+
+ if isinstance(test_dataset, torch.utils.data.IterableDataset):
+ if self.args.world_size > 1:
+ test_dataset = IterableDatasetShard(
+ test_dataset,
+ batch_size=self.args.eval_batch_size,
+ drop_last=self.args.dataloader_drop_last,
+ num_processes=self.args.world_size,
+ process_index=self.args.process_index,
+ )
+ return DataLoader(
+ test_dataset,
+ batch_size=self.args.eval_batch_size,
+ collate_fn=data_collator,
+ num_workers=self.args.dataloader_num_workers,
+ pin_memory=self.args.dataloader_pin_memory,
+ )
+
+ test_sampler = self._get_eval_sampler(test_dataset)
+
+ # We use the same batch_size as for eval.
+ return DataLoader(
+ test_dataset,
+ sampler=test_sampler,
+ batch_size=self.args.eval_batch_size,
+ collate_fn=data_collator,
+ drop_last=self.args.dataloader_drop_last,
+ num_workers=self.args.dataloader_num_workers,
+ pin_memory=self.args.dataloader_pin_memory,
+ )
+
+ def create_optimizer_and_scheduler(self, num_training_steps: int):
+ """
+ Setup the optimizer and the learning rate scheduler.
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+ Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
+ `create_scheduler`) in a subclass.
+ """
+ self.create_optimizer()
+ if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
+ # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
+ optimizer = self.optimizer.optimizer
+ else:
+ optimizer = self.optimizer
+ self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
+
+ def create_optimizer(self):
+ """
+ Setup the optimizer.
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+ """
+ opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+
+ if self.optimizer is None:
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+
+ if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+ self.optimizer = OSS(
+ params=optimizer_grouped_parameters,
+ optim=optimizer_cls,
+ **optimizer_kwargs,
+ )
+ else:
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+ if optimizer_cls.__name__ == "Adam8bit":
+ import bitsandbytes
+
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+ skipped = 0
+ for module in opt_model.modules():
+ if isinstance(module, nn.Embedding):
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+ print(f"skipped {module}: {skipped/2**20}M params")
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+ print(f"skipped: {skipped/2**20}M params")
+
+ if is_sagemaker_mp_enabled():
+ self.optimizer = smp.DistributedOptimizer(self.optimizer)
+
+ return self.optimizer
+
+ @staticmethod
+ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
+ """
+ Returns the optimizer class and optimizer parameters based on the training arguments.
+ Args:
+ args (`transformers.training_args.TrainingArguments`):
+ The training arguments for the training session.
+ """
+
+ # parse args.optim_args
+ optim_args = {}
+ if args.optim_args:
+ for mapping in args.optim_args.replace(" ", "").split(","):
+ key, value = mapping.split("=")
+ optim_args[key] = value
+
+ optimizer_kwargs = {"lr": args.learning_rate}
+
+ adam_kwargs = {
+ "betas": (args.adam_beta1, args.adam_beta2),
+ "eps": args.adam_epsilon,
+ }
+ if args.optim == OptimizerNames.ADAFACTOR:
+ optimizer_cls = Adafactor
+ optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+ elif args.optim == OptimizerNames.ADAMW_HF:
+ from transformers.optimization import AdamW
+
+ optimizer_cls = AdamW
+ optimizer_kwargs.update(adam_kwargs)
+ elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
+ from torch.optim import AdamW
+
+ optimizer_cls = AdamW
+ optimizer_kwargs.update(adam_kwargs)
+ if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
+ optimizer_kwargs.update({"fused": True})
+ elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
+ try:
+ from torch_xla.amp.syncfree import AdamW
+
+ optimizer_cls = AdamW
+ optimizer_kwargs.update(adam_kwargs)
+ except ImportError:
+ raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
+ elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
+ try:
+ from apex.optimizers import FusedAdam
+
+ optimizer_cls = FusedAdam
+ optimizer_kwargs.update(adam_kwargs)
+ except ImportError:
+ raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
+ elif args.optim == OptimizerNames.ADAMW_BNB:
+ try:
+ from bitsandbytes.optim import Adam8bit
+
+ optimizer_cls = Adam8bit
+ optimizer_kwargs.update(adam_kwargs)
+ except ImportError:
+ raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
+ elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
+ try:
+ from torchdistx.optimizers import AnyPrecisionAdamW
+
+ optimizer_cls = AnyPrecisionAdamW
+ optimizer_kwargs.update(adam_kwargs)
+
+ # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
+ optimizer_kwargs.update(
+ {
+ "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
+ "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
+ "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
+ "compensation_buffer_dtype": getattr(
+ torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
+ ),
+ }
+ )
+ except ImportError:
+ raise ValueError("Please install https://github.com/pytorch/torchdistx")
+ elif args.optim == OptimizerNames.SGD:
+ optimizer_cls = torch.optim.SGD
+ elif args.optim == OptimizerNames.ADAGRAD:
+ optimizer_cls = torch.optim.Adagrad
+ else:
+ raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
+ return optimizer_cls, optimizer_kwargs
+
+ def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
+ """
+ Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
+ passed as an argument.
+ Args:
+ num_training_steps (int): The number of training steps to do.
+ """
+ ############
+ num_training_steps *= 3
+ ############
+ if self.lr_scheduler is None:
+ self.lr_scheduler = get_scheduler(
+ self.args.lr_scheduler_type,
+ optimizer=self.optimizer if optimizer is None else optimizer,
+ num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+ num_training_steps=num_training_steps,
+ )
+ return self.lr_scheduler
+
+ def num_examples(self, dataloader: DataLoader) -> int:
+ """
+ Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
+ dataloader.dataset does not exist or has no length, estimates as best it can
+ """
+ try:
+ dataset = dataloader.dataset
+ # Special case for IterableDatasetShard, we need to dig deeper
+ if isinstance(dataset, IterableDatasetShard):
+ return len(dataloader.dataset.dataset)
+ return len(dataloader.dataset)
+ except (NameError, AttributeError, TypeError): # no dataset or length, estimate by length of dataloader
+ return len(dataloader) * self.args.per_device_train_batch_size
+
+ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
+ """HP search setup code"""
+ self._trial = trial
+
+ if self.hp_search_backend is None or trial is None:
+ return
+ if self.hp_search_backend == HPSearchBackend.OPTUNA:
+ params = self.hp_space(trial)
+ elif self.hp_search_backend == HPSearchBackend.RAY:
+ params = trial
+ params.pop("wandb", None)
+ elif self.hp_search_backend == HPSearchBackend.SIGOPT:
+ params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
+ elif self.hp_search_backend == HPSearchBackend.WANDB:
+ params = trial
+
+ for key, value in params.items():
+ if not hasattr(self.args, key):
+ logger.warning(
+ f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
+ " `TrainingArguments`."
+ )
+ continue
+ old_attr = getattr(self.args, key, None)
+ # Casting value to the proper type
+ if old_attr is not None:
+ value = type(old_attr)(value)
+ setattr(self.args, key, value)
+ if self.hp_search_backend == HPSearchBackend.OPTUNA:
+ logger.info(f"Trial: {trial.params}")
+ if self.hp_search_backend == HPSearchBackend.SIGOPT:
+ logger.info(f"SigOpt Assignments: {trial.assignments}")
+ if self.hp_search_backend == HPSearchBackend.WANDB:
+ logger.info(f"W&B Sweep parameters: {trial}")
+ if self.args.deepspeed:
+ # Rebuild the deepspeed config to reflect the updated training parameters
+ from transformers.deepspeed import HfTrainerDeepSpeedConfig
+
+ self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
+ self.args.hf_deepspeed_config.trainer_config_process(self.args)
+
+ def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
+ if self.hp_search_backend is None or trial is None:
+ return
+ self.objective = self.compute_objective(metrics.copy())
+ if self.hp_search_backend == HPSearchBackend.OPTUNA:
+ import optuna
+
+ trial.report(self.objective, step)
+ if trial.should_prune():
+ self.callback_handler.on_train_end(self.args, self.state, self.control)
+ raise optuna.TrialPruned()
+ elif self.hp_search_backend == HPSearchBackend.RAY:
+ from ray import tune
+
+ if self.control.should_save:
+ self._tune_save_checkpoint()
+ tune.report(objective=self.objective, **metrics)
+
+ def _tune_save_checkpoint(self):
+ from ray import tune
+
+ if not self.use_tune_checkpoints:
+ return
+ with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
+ output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
+ self.save_model(output_dir, _internal_call=True)
+ if self.args.should_save:
+ self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+ torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+ torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+
+ def call_model_init(self, trial=None):
+ model_init_argcount = number_of_arguments(self.model_init)
+ if model_init_argcount == 0:
+ model = self.model_init()
+ elif model_init_argcount == 1:
+ model = self.model_init(trial)
+ else:
+ raise RuntimeError("model_init should have 0 or 1 argument.")
+
+ if model is None:
+ raise RuntimeError("model_init should not return None.")
+
+ return model
+
+ def torch_jit_model_eval(self, model, dataloader, training=False):
+ if not training:
+ if dataloader is None:
+ logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
+ return model
+ example_batch = next(iter(dataloader))
+ example_batch = self._prepare_inputs(example_batch)
+ try:
+ jit_model = model.eval()
+ with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]):
+ if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"):
+ if isinstance(example_batch, dict):
+ jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
+ else:
+ jit_model = torch.jit.trace(
+ jit_model,
+ example_kwarg_inputs={key: example_batch[key] for key in example_batch},
+ strict=False,
+ )
+ else:
+ jit_inputs = []
+ for key in example_batch:
+ example_tensor = torch.ones_like(example_batch[key])
+ jit_inputs.append(example_tensor)
+ jit_inputs = tuple(jit_inputs)
+ jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
+ jit_model = torch.jit.freeze(jit_model)
+ with torch.no_grad():
+ jit_model(**example_batch)
+ jit_model(**example_batch)
+ model = jit_model
+ self.use_cpu_amp = False
+ self.use_cuda_amp = False
+ except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
+ logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
+
+ return model
+
+ def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
+ if not is_ipex_available():
+ raise ImportError(
+ "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
+ " to https://github.com/intel/intel-extension-for-pytorch."
+ )
+
+ import intel_extension_for_pytorch as ipex
+
+ if not training:
+ model.eval()
+ dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype
+ # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings
+ model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train)
+ else:
+ if not model.training:
+ model.train()
+ model, self.optimizer = ipex.optimize(
+ model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
+ )
+
+ return model
+
+ def _wrap_model(self, model, training=True, dataloader=None):
+ if self.args.torch_compile:
+ model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)
+
+ if self.args.use_ipex:
+ dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
+ model = self.ipex_optimize_model(model, training, dtype=dtype)
+
+ if is_sagemaker_mp_enabled():
+ # Wrapping the base model twice in a DistributedModel will raise an error.
+ if isinstance(self.model_wrapped, smp.model.DistributedModel):
+ return self.model_wrapped
+ return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
+
+ # already initialized its own DDP and AMP
+ if self.deepspeed:
+ return self.deepspeed
+
+ # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
+ if unwrap_model(model) is not model:
+ return model
+
+ # Mixed precision training with apex (torch < 1.6)
+ if self.use_apex and training:
+ model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
+
+ # Multi-gpu training (should be after apex fp16 initialization)
+ if self.args.n_gpu > 1:
+ model = nn.DataParallel(model)
+
+ if self.args.jit_mode_eval:
+ start_time = time.time()
+ model = self.torch_jit_model_eval(model, dataloader, training)
+ self.jit_compilation_time = round(time.time() - start_time, 4)
+
+ # Note: in torch.distributed mode, there's no point in wrapping the model
+ # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
+ if not training:
+ return model
+
+ # Distributed training (should be after apex fp16 initialization)
+ if self.sharded_ddp is not None:
+ # Sharded DDP!
+ if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+ model = ShardedDDP(model, self.optimizer)
+ else:
+ mixed_precision = self.args.fp16 or self.args.bf16
+ cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
+ zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
+ # XXX: Breaking the self.model convention but I see no way around it for now.
+ if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
+ model = auto_wrap(model)
+ self.model = model = FullyShardedDDP(
+ model,
+ mixed_precision=mixed_precision,
+ reshard_after_forward=zero_3,
+ cpu_offload=cpu_offload,
+ ).to(self.args.device)
+ # Distributed training using PyTorch FSDP
+ elif self.fsdp is not None:
+ if not self.args.fsdp_config["xla"]:
+ # PyTorch FSDP!
+ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
+ from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
+
+ if FSDPOption.OFFLOAD in self.args.fsdp:
+ cpu_offload = CPUOffload(offload_params=True)
+ else:
+ cpu_offload = CPUOffload(offload_params=False)
+
+ auto_wrap_policy = None
+
+ if FSDPOption.AUTO_WRAP in self.args.fsdp:
+ if self.args.fsdp_config["fsdp_min_num_params"] > 0:
+ auto_wrap_policy = functools.partial(
+ size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
+ )
+ elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
+ transformer_cls_to_wrap = set()
+ for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
+ transformer_cls = get_module_class_from_name(model, layer_class)
+ if transformer_cls is None:
+ raise Exception("Could not find the transformer layer class to wrap in the model.")
+ else:
+ transformer_cls_to_wrap.add(transformer_cls)
+ auto_wrap_policy = functools.partial(
+ transformer_auto_wrap_policy,
+ # Transformer layer class to wrap
+ transformer_layer_cls=transformer_cls_to_wrap,
+ )
+ mixed_precision_policy = None
+ dtype = None
+ if self.args.fp16:
+ dtype = torch.float16
+ elif self.args.bf16:
+ dtype = torch.bfloat16
+ if dtype is not None:
+ mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
+ if type(model) != FSDP:
+ # XXX: Breaking the self.model convention but I see no way around it for now.
+ self.model = model = FSDP(
+ model,
+ sharding_strategy=self.fsdp,
+ cpu_offload=cpu_offload,
+ auto_wrap_policy=auto_wrap_policy,
+ mixed_precision=mixed_precision_policy,
+ device_id=self.args.device,
+ backward_prefetch=self.backward_prefetch,
+ forward_prefetch=self.forword_prefetch,
+ limit_all_gathers=self.limit_all_gathers,
+ )
+ else:
+ try:
+ from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
+ from torch_xla.distributed.fsdp import checkpoint_module
+ from torch_xla.distributed.fsdp.wrap import (
+ size_based_auto_wrap_policy,
+ transformer_auto_wrap_policy,
+ )
+ except ImportError:
+ raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
+ auto_wrap_policy = None
+ auto_wrapper_callable = None
+ if self.args.fsdp_config["fsdp_min_num_params"] > 0:
+ auto_wrap_policy = functools.partial(
+ size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
+ )
+ elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
+ transformer_cls_to_wrap = set()
+ for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
+ transformer_cls = get_module_class_from_name(model, layer_class)
+ if transformer_cls is None:
+ raise Exception("Could not find the transformer layer class to wrap in the model.")
+ else:
+ transformer_cls_to_wrap.add(transformer_cls)
+ auto_wrap_policy = functools.partial(
+ transformer_auto_wrap_policy,
+ # Transformer layer class to wrap
+ transformer_layer_cls=transformer_cls_to_wrap,
+ )
+ fsdp_kwargs = self.args.xla_fsdp_config
+ if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
+ # Apply gradient checkpointing to auto-wrapped sub-modules if specified
+ def auto_wrapper_callable(m, *args, **kwargs):
+ return FSDP(checkpoint_module(m), *args, **kwargs)
+
+ # Wrap the base model with an outer FSDP wrapper
+ self.model = model = FSDP(
+ model,
+ auto_wrap_policy=auto_wrap_policy,
+ auto_wrapper_callable=auto_wrapper_callable,
+ **fsdp_kwargs,
+ )
+
+ # Patch `xm.optimizer_step` should not reduce gradients in this case,
+ # as FSDP does not need gradient reduction over sharded parameters.
+ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
+ loss = optimizer.step(**optimizer_args)
+ if barrier:
+ xm.mark_step()
+ return loss
+
+ xm.optimizer_step = patched_optimizer_step
+ elif is_sagemaker_dp_enabled():
+ model = nn.parallel.DistributedDataParallel(
+ model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
+ )
+ elif self.args.local_rank != -1:
+ kwargs = {}
+ if self.args.ddp_find_unused_parameters is not None:
+ kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
+ elif isinstance(model, PreTrainedModel):
+ # find_unused_parameters breaks checkpointing as per
+ # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
+ kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
+ else:
+ kwargs["find_unused_parameters"] = True
+
+ if self.args.ddp_bucket_cap_mb is not None:
+ kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
+ if is_torch_neuroncore_available():
+ return model
+ model = nn.parallel.DistributedDataParallel(
+ model,
+ device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
+ output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
+ **kwargs,
+ )
+
+ return model
+
+ def train(
+ self,
+ resume_from_checkpoint: Optional[Union[str, bool]] = None,
+ trial: Union["optuna.Trial", Dict[str, Any]] = None,
+ ignore_keys_for_eval: Optional[List[str]] = None,
+ is_first_time = False,
+ **kwargs,
+ ):
+ """
+ Main training entry point.
+ Args:
+ resume_from_checkpoint (`str` or `bool`, *optional*):
+ If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
+ `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
+ of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
+ trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
+ The trial run or the hyperparameter dictionary for hyperparameter search.
+ ignore_keys_for_eval (`List[str]`, *optional*)
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+ gathering predictions for evaluation during the training.
+ kwargs:
+ Additional keyword arguments used to hide deprecated arguments
+ """
+ if resume_from_checkpoint is False:
+ resume_from_checkpoint = None
+
+ # memory metrics - must set up as early as possible
+ self._memory_tracker.start()
+
+ args = self.args
+
+ #self.is_in_train = True
+
+ # do_train is not a reliable argument, as it might not be set and .train() still called, so
+ # the following is a workaround:
+ if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
+ self._move_model_to_device(self.model, args.device)
+
+ if "model_path" in kwargs:
+ resume_from_checkpoint = kwargs.pop("model_path")
+ warnings.warn(
+ "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
+ "instead.",
+ FutureWarning,
+ )
+ if len(kwargs) > 0:
+ raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
+ # This might change the seed so needs to run first.
+ self._hp_search_setup(trial)
+ self._train_batch_size = self.args.train_batch_size
+
+ # Model re-init
+ model_reloaded = False
+ if self.model_init is not None:
+ # Seed must be set before instantiating the model when using model_init.
+ enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
+ self.model = self.call_model_init(trial)
+ model_reloaded = True
+ # Reinitializes optimizer and scheduler
+ self.optimizer, self.lr_scheduler = None, None
+
+ # Load potential model checkpoint
+ if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
+ resume_from_checkpoint = get_last_checkpoint(args.output_dir)
+ if resume_from_checkpoint is None:
+ raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
+
+ if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and args.deepspeed is None:
+ self._load_from_checkpoint(resume_from_checkpoint)
+
+ # If model was re-initialized, put it on the right device and update self.model_wrapped
+ if model_reloaded:
+ if self.place_model_on_device:
+ self._move_model_to_device(self.model, args.device)
+ self.model_wrapped = self.model
+ if is_first_time:
+ inner_training_loop1 = find_executable_batch_size(
+ self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
+ )
+ return inner_training_loop1(
+ args=args,
+ resume_from_checkpoint=resume_from_checkpoint,
+ trial=trial,
+ ignore_keys_for_eval=ignore_keys_for_eval,
+ )
+ else:
+ inner_training_loop2 = find_executable_batch_size(
+ self._one_train, self._train_batch_size, args.auto_find_batch_size
+ )
+ return inner_training_loop2(
+ args=args,
+ resume_from_checkpoint=resume_from_checkpoint,
+ trial=trial,
+ ignore_keys_for_eval=ignore_keys_for_eval,
+ )
+
+
+ def _one_train(
+ self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+ ):
+ #print(self.lr_scheduler)
+ #print(dir(self.lr_scheduler))
+
+
+ self.state = TrainerState()
+ self.state.is_hyper_param_search = trial is not None
+ # Get dataloader
+ self._train_batch_size = batch_size
+ # Data loader and number of training steps
+ train_dataloader = self.get_train_dataloader()
+ #print("AAAAAAA", len(train_dataloader))
+
+ total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
+
+ len_dataloader = None
+ if has_length(train_dataloader):
+ len_dataloader = len(train_dataloader)
+ num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
+ num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+ num_examples = self.num_examples(train_dataloader)
+ if args.max_steps > 0:
+ max_steps = args.max_steps
+ num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
+ args.max_steps % num_update_steps_per_epoch > 0
+ )
+ # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
+ # the best we can do.
+ num_train_samples = args.max_steps * total_train_batch_size
+ else:
+ max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
+ num_train_epochs = math.ceil(args.num_train_epochs)
+ num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
+ elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
+ max_steps = args.max_steps
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
+ num_train_epochs = sys.maxsize
+ num_update_steps_per_epoch = max_steps
+ num_examples = total_train_batch_size * args.max_steps
+ num_train_samples = args.max_steps * total_train_batch_size
+ else:
+ raise ValueError(
+ "args.max_steps must be set to a positive value if dataloader does not have a length, was"
+ f" {args.max_steps}"
+ )
+ ###########
+ #num_train_epochs = 5
+
+ # Train!
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {num_examples}")
+ logger.info(f" Num Epochs = {num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {max_steps}")
+ logger.info(
+ f" Number of trainable parameters = {sum(p.numel() for p in self.tmp_model.parameters() if p.requires_grad)}"
+ )
+
+ self.state.epoch = 0
+ start_time = time.time()
+ epochs_trained = 0
+ steps_trained_in_current_epoch = 0
+ steps_trained_progress_bar = None
+
+ # Update the references
+ self.callback_handler.model = self.model
+ self.callback_handler.optimizer = self.optimizer
+ self.callback_handler.lr_scheduler = self.lr_scheduler
+ self.callback_handler.train_dataloader = train_dataloader
+ if self.hp_name is not None and self._trial is not None:
+ # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
+ # parameter to Train when using DDP.
+ self.state.trial_name = self.hp_name(self._trial)
+ if trial is not None:
+ assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
+ self.state.trial_params = hp_params(assignments)
+ else:
+ self.state.trial_params = None
+ # This should be the same if the state has been saved but in case the training arguments changed, it's safer
+ # to set this after the load.
+ self.state.max_steps = max_steps
+ self.state.num_train_epochs = num_train_epochs
+ self.state.is_local_process_zero = self.is_local_process_zero()
+ self.state.is_world_process_zero = self.is_world_process_zero()
+
+ # tr_loss is a tensor to avoid synchronization of TPUs through .item()
+ tr_loss = torch.tensor(0.0).to(args.device)
+ # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
+ self._total_loss_scalar = 0.0
+ self._globalstep_last_logged = self.state.global_step
+ #model.zero_grad()
+ self.tmp_model.zero_grad()
+
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
+
+ # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
+ if not args.ignore_data_skip:
+ #print("I skip!") called
+ for epoch in range(epochs_trained):
+ is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
+ train_dataloader.sampler, RandomSampler
+ )
+ if is_torch_less_than_1_11 or not is_random_sampler:
+ # We just need to begin an iteration to create the randomization of the sampler.
+ # That was before PyTorch 1.11 however...
+ for _ in train_dataloader:
+ break
+ else:
+ # Otherwise we need to call the whooooole sampler cause there is some random operation added
+ # AT THE VERY END!
+ _ = list(train_dataloader.sampler)
+
+ ###############
+ #num_train_epochs = 10
+ self.is_in_train = True
+ #print("The number of epoches: ", num_train_epochs)
+ #############
+ total_batched_samples = 0
+ for epoch in range(epochs_trained, num_train_epochs):
+ if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
+ train_dataloader.sampler.set_epoch(epoch)
+ elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
+ train_dataloader.dataset.set_epoch(epoch)
+
+ if is_torch_tpu_available():
+ parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
+ epoch_iterator = parallel_loader
+ else:
+ epoch_iterator = train_dataloader
+
+ # Reset the past mems state at the beginning of each epoch if necessary.
+ if args.past_index >= 0:
+ self._past = None
+
+ steps_in_epoch = (
+ len(epoch_iterator)
+ if len_dataloader is not None
+ else args.max_steps * args.gradient_accumulation_steps
+ )
+ self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
+
+ if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
+ self._load_rng_state(resume_from_checkpoint)
+
+ rng_to_sync = False
+ steps_skipped = 0
+ if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
+ epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
+ steps_skipped = steps_trained_in_current_epoch
+ steps_trained_in_current_epoch = 0
+ rng_to_sync = True
+
+ #print("The number of one epoch: ", len(epoch_iterator))
+ step = -1
+ for step, inputs in enumerate(epoch_iterator):
+ total_batched_samples += 1
+ if rng_to_sync:
+ self._load_rng_state(resume_from_checkpoint)
+ rng_to_sync = False
+
+ # Skip past any already trained steps if resuming training
+ if steps_trained_in_current_epoch > 0:
+ steps_trained_in_current_epoch -= 1
+ if steps_trained_progress_bar is not None:
+ steps_trained_progress_bar.update(1)
+ if steps_trained_in_current_epoch == 0:
+ self._load_rng_state(resume_from_checkpoint)
+ continue
+ elif steps_trained_progress_bar is not None:
+ steps_trained_progress_bar.close()
+ steps_trained_progress_bar = None
+
+ if step % args.gradient_accumulation_steps == 0:
+ self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
+
+ if (
+ (total_batched_samples % args.gradient_accumulation_steps != 0)
+ and args.local_rank != -1
+ and args._no_sync_in_gradient_accumulation
+ ):
+ # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
+ with self.tmp_model.no_sync():
+ tr_loss_step = self.training_step(self.tmp_model, inputs)
+ #with model.no_sync():
+ #tr_loss_step = self.training_step(model, inputs)
+ else:
+ tr_loss_step = self.training_step(self.tmp_model, inputs)
+
+ if (
+ args.logging_nan_inf_filter
+ and not is_torch_tpu_available()
+ and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
+ ):
+ # if loss is nan or inf simply add the average of previous logged losses
+ tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
+ else:
+ tr_loss += tr_loss_step
+
+ self.current_flos += float(self.floating_point_ops(inputs))
+
+ # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
+ if self.deepspeed:
+ self.deepspeed.step()
+
+ if total_batched_samples % args.gradient_accumulation_steps == 0 or (
+ # last step in epoch but step is always smaller than gradient_accumulation_steps
+ steps_in_epoch <= args.gradient_accumulation_steps
+ and (step + 1) == steps_in_epoch
+ ):
+ # Gradient clipping
+ if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
+ # deepspeed does its own clipping
+
+ if self.do_grad_scaling:
+ # Reduce gradients first for XLA
+ if is_torch_tpu_available():
+ gradients = xm._fetch_gradients(self.optimizer)
+ xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
+ # AMP: gradients need unscaling
+ self.scaler.unscale_(self.optimizer)
+
+ if is_sagemaker_mp_enabled() and args.fp16:
+ self.optimizer.clip_master_grads(args.max_grad_norm)
+ elif hasattr(self.optimizer, "clip_grad_norm"):
+ # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
+ self.optimizer.clip_grad_norm(args.max_grad_norm)
+ elif hasattr(model, "clip_grad_norm_"):
+ # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
+ model.clip_grad_norm_(args.max_grad_norm)
+ else:
+ # Revert to normal clipping otherwise, handling Apex or full precision
+ nn.utils.clip_grad_norm_(
+ amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
+ args.max_grad_norm,
+ )
+
+ # Optimizer step
+ optimizer_was_run = True
+ if self.deepspeed:
+ pass # called outside the loop
+ elif is_torch_tpu_available():
+ if self.do_grad_scaling:
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ else:
+ xm.optimizer_step(self.optimizer)
+ elif self.do_grad_scaling:
+ scale_before = self.scaler.get_scale()
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ scale_after = self.scaler.get_scale()
+ optimizer_was_run = scale_before <= scale_after
+ else:
+ self.optimizer.step()
+
+ if optimizer_was_run and not self.deepspeed:
+ self.lr_scheduler.step()
+
+ self.tmp_model.zero_grad()
+ self.state.global_step += 1
+ self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+
+ self._maybe_log_save_evaluate(tr_loss, self.tmp_model, trial, epoch, ignore_keys_for_eval)
+ else:
+ self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
+
+ if self.control.should_epoch_stop or self.control.should_training_stop:
+ break
+ if step < 0:
+ logger.warning(
+ "There seems to be not a single sample in your epoch_iterator, stopping training at step"
+ f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
+ f" num_steps ({max_steps}) higher than the number of available samples."
+ )
+ self.control.should_training_stop = True
+
+ self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
+ self._maybe_log_save_evaluate(tr_loss, self.tmp_model, trial, epoch, ignore_keys_for_eval)
+
+ if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
+ if is_torch_tpu_available():
+ # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+ xm.master_print(met.metrics_report())
+ else:
+ logger.warning(
+ "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
+ "configured. Check your training configuration if this is unexpected."
+ )
+ if self.control.should_training_stop:
+ break
+
+ if args.past_index and hasattr(self, "_past"):
+ # Clean the state at the end of training
+ delattr(self, "_past")
+
+ logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
+ if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
+ # Wait for everyone to get here so we are sur the model has been saved by process 0.
+ if is_torch_tpu_available():
+ xm.rendezvous("load_best_model_at_end")
+ elif args.local_rank != -1:
+ dist.barrier()
+ elif is_sagemaker_mp_enabled():
+ smp.barrier()
+
+ self._load_best_model()
+
+ # add remaining tr_loss
+ self._total_loss_scalar += tr_loss.item()
+ train_loss = self._total_loss_scalar / self.state.global_step
+
+ metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
+ self.store_flos()
+ metrics["total_flos"] = self.state.total_flos
+ metrics["train_loss"] = train_loss
+
+ self.is_in_train = False
+
+ self._memory_tracker.stop_and_update_metrics(metrics)
+
+ self.log(metrics)
+
+ run_dir = self._get_output_dir(trial)
+ checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
+
+ # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
+ if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
+ for checkpoint in checkpoints_sorted:
+ if checkpoint != self.state.best_model_checkpoint:
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+ shutil.rmtree(checkpoint)
+
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+
+ return TrainOutput(self.state.global_step, train_loss, metrics)
+
+ def _inner_training_loop(
+ self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+ ):
+ '''
+ 0 This function serves to train one time
+ 1 Update the self.train_dataset before calling this function
+ '''
+ # 1 Get dataloader
+ self._train_batch_size = batch_size
+ # Data loader and number of training steps
+ train_dataloader = self.get_train_dataloader()
+ total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
+
+ len_dataloader = None
+ if has_length(train_dataloader):
+ len_dataloader = len(train_dataloader)
+ num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
+ num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+ num_examples = self.num_examples(train_dataloader)
+ if args.max_steps > 0:
+ max_steps = args.max_steps
+ num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
+ args.max_steps % num_update_steps_per_epoch > 0
+ )
+ # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
+ # the best we can do.
+ num_train_samples = args.max_steps * total_train_batch_size
+ else:
+ max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
+ num_train_epochs = math.ceil(args.num_train_epochs)
+ num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
+ elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
+ max_steps = args.max_steps
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
+ num_train_epochs = sys.maxsize
+ num_update_steps_per_epoch = max_steps
+ num_examples = total_train_batch_size * args.max_steps
+ num_train_samples = args.max_steps * total_train_batch_size
+ else:
+ raise ValueError(
+ "args.max_steps must be set to a positive value if dataloader does not have a length, was"
+ f" {args.max_steps}"
+ )
+
+ if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
+ if self.args.n_gpu > 1:
+ # nn.DataParallel(model) replicates the model, creating new variables and module
+ # references registered here no longer work on other gpus, breaking the module
+ raise ValueError(
+ "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
+ " (torch.distributed.launch)."
+ )
+ else:
+ debug_overflow = DebugUnderflowOverflow(self.model) # noqa
+
+ delay_optimizer_creation = (
+ self.sharded_ddp is not None
+ and self.sharded_ddp != ShardedDDPOption.SIMPLE
+ or is_sagemaker_mp_enabled()
+ or self.fsdp is not None
+ )
+ if args.deepspeed:
+ deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
+ self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
+ )
+ self.model = deepspeed_engine.module
+ self.model_wrapped = deepspeed_engine
+ self.deepspeed = deepspeed_engine
+ self.optimizer = optimizer
+ self.lr_scheduler = lr_scheduler
+ #print("I just create a optimizer here!") # called
+ elif not delay_optimizer_creation:
+ self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+ self.state = TrainerState()
+ self.state.is_hyper_param_search = trial is not None
+
+ # Activate gradient checkpointing if needed
+ if args.gradient_checkpointing:
+ self.model.gradient_checkpointing_enable()
+
+ #model = self._wrap_model(self.model_wrapped)
+ self.tmp_model = self._wrap_model(self.model_wrapped)
+
+
+ #if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
+ # self._load_from_checkpoint(resume_from_checkpoint, model)
+
+ # for the rest of this function `model` is the outside model, whether it was wrapped or not
+ if self.tmp_model is not self.model:
+ self.model_wrapped = self.tmp_model
+
+ if delay_optimizer_creation:
+ print("I create here!") # not called
+ self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+ return True
+ # Check if saved optimizer or scheduler states exist
+ #self._load_optimizer_and_scheduler(resume_from_checkpoint)
+
+ # important: at this point:
+ # self.model is the Transformers Model
+ # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
+
+
+
+ def _get_output_dir(self, trial):
+ if self.hp_search_backend is not None and trial is not None:
+ if self.hp_search_backend == HPSearchBackend.OPTUNA:
+ run_id = trial.number
+ elif self.hp_search_backend == HPSearchBackend.RAY:
+ from ray import tune
+
+ run_id = tune.get_trial_id()
+ elif self.hp_search_backend == HPSearchBackend.SIGOPT:
+ run_id = trial.id
+ elif self.hp_search_backend == HPSearchBackend.WANDB:
+ import wandb
+
+ run_id = wandb.run.id
+ run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
+ run_dir = os.path.join(self.args.output_dir, run_name)
+ else:
+ run_dir = self.args.output_dir
+ return run_dir
+
+ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
+ if model is None:
+ model = self.model
+
+ if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
+ os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
+ ):
+ raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+
+ logger.info(f"Loading model from {resume_from_checkpoint}.")
+
+ if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
+ config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
+ checkpoint_version = config.transformers_version
+ if checkpoint_version is not None and checkpoint_version != __version__:
+ logger.warning(
+ f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
+ f"Transformers but your current version is {__version__}. This is not recommended and could "
+ "yield to errors or unwanted behaviors."
+ )
+
+ if os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
+ # If the model is on the GPU, it still works!
+ if is_sagemaker_mp_enabled():
+ if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
+ # If the 'user_content.pt' file exists, load with the new smp api.
+ # Checkpoint must have been saved with the new smp api.
+ smp.resume_from_checkpoint(
+ path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
+ )
+ else:
+ # If the 'user_content.pt' file does NOT exist, load with the old smp api.
+ # Checkpoint must have been saved with the old smp api.
+ if hasattr(self.args, "fp16") and self.args.fp16 is True:
+ logger.warning(
+ "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
+ )
+ state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
+ # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
+ state_dict["_smp_is_partial"] = False
+ load_result = model.load_state_dict(state_dict, strict=True)
+ # release memory
+ del state_dict
+ else:
+ # We load the model state dict on the CPU to avoid an OOM error.
+ state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
+ # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+ # which takes *args instead of **kwargs
+ load_result = model.load_state_dict(state_dict, False)
+ # release memory
+ del state_dict
+ self._issue_warnings_after_load(load_result)
+ else:
+ # We load the sharded checkpoint
+ load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled())
+ if not is_sagemaker_mp_enabled():
+ self._issue_warnings_after_load(load_result)
+
+ def _load_best_model(self):
+ logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
+ best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+ model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+ if os.path.exists(best_model_path):
+ if self.deepspeed:
+ if self.model_wrapped is not None:
+ # this removes the pre-hooks from the previous engine
+ self.model_wrapped.destroy()
+ self.model_wrapped = None
+
+ # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
+ deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
+ self,
+ num_training_steps=self.args.max_steps,
+ resume_from_checkpoint=self.state.best_model_checkpoint,
+ )
+ self.model = deepspeed_engine.module
+ self.model_wrapped = deepspeed_engine
+ self.deepspeed = deepspeed_engine
+ self.optimizer = optimizer
+ self.lr_scheduler = lr_scheduler
+ else:
+ if is_sagemaker_mp_enabled():
+ if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
+ # If the 'user_content.pt' file exists, load with the new smp api.
+ # Checkpoint must have been saved with the new smp api.
+ smp.resume_from_checkpoint(
+ path=self.state.best_model_checkpoint,
+ tag=WEIGHTS_NAME,
+ partial=False,
+ load_optimizer=False,
+ )
+ else:
+ # If the 'user_content.pt' file does NOT exist, load with the old smp api.
+ # Checkpoint must have been saved with the old smp api.
+ state_dict = torch.load(best_model_path, map_location="cpu")
+ state_dict["_smp_is_partial"] = False
+ load_result = model.load_state_dict(state_dict, strict=True)
+ else:
+ # We load the model state dict on the CPU to avoid an OOM error.
+ state_dict = torch.load(best_model_path, map_location="cpu")
+ # If the model is on the GPU, it still works!
+ # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+ # which takes *args instead of **kwargs
+ load_result = model.load_state_dict(state_dict, False)
+ if not is_sagemaker_mp_enabled():
+ self._issue_warnings_after_load(load_result)
+ elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
+ load_result = load_sharded_checkpoint(
+ model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
+ )
+ if not is_sagemaker_mp_enabled():
+ self._issue_warnings_after_load(load_result)
+ else:
+ logger.warning(
+ f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
+ "on multiple nodes, you should activate `--save_on_each_node`."
+ )
+
+ def _issue_warnings_after_load(self, load_result):
+ if len(load_result.missing_keys) != 0:
+ if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
+ self.model._keys_to_ignore_on_save
+ ):
+ self.model.tie_weights()
+ else:
+ logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
+ if len(load_result.unexpected_keys) != 0:
+ logger.warning(
+ f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
+ )
+
+ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
+ if self.control.should_log:
+ if is_torch_tpu_available():
+ xm.mark_step()
+
+ logs: Dict[str, float] = {}
+
+ # all_gather + mean() to get average loss over all processes
+ tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
+
+ # reset tr_loss to zero
+ tr_loss -= tr_loss
+
+ logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+ logs["learning_rate"] = self._get_learning_rate()
+
+ self._total_loss_scalar += tr_loss_scalar
+ self._globalstep_last_logged = self.state.global_step
+ self.store_flos()
+
+ self.log(logs)
+
+ metrics = None
+ if self.control.should_evaluate:
+ if isinstance(self.eval_dataset, dict):
+ for eval_dataset_name, eval_dataset in self.eval_dataset.items():
+ metrics = self.evaluate(
+ eval_dataset=eval_dataset,
+ ignore_keys=ignore_keys_for_eval,
+ metric_key_prefix=f"eval_{eval_dataset_name}",
+ )
+ else:
+ metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
+ self._report_to_hp_search(trial, self.state.global_step, metrics)
+
+ if self.control.should_save:
+ self._save_checkpoint(model, trial, metrics=metrics)
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+
+ def _load_rng_state(self, checkpoint):
+ # Load RNG states from `checkpoint`
+ if checkpoint is None:
+ return
+
+ if self.args.world_size > 1:
+ process_index = self.args.process_index
+ rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
+ if not os.path.isfile(rng_file):
+ logger.info(
+ f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
+ "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
+ )
+ return
+ else:
+ rng_file = os.path.join(checkpoint, "rng_state.pth")
+ if not os.path.isfile(rng_file):
+ logger.info(
+ "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
+ "fashion, reproducibility is not guaranteed."
+ )
+ return
+
+ checkpoint_rng_state = torch.load(rng_file)
+ random.setstate(checkpoint_rng_state["python"])
+ np.random.set_state(checkpoint_rng_state["numpy"])
+ torch.random.set_rng_state(checkpoint_rng_state["cpu"])
+ if torch.cuda.is_available():
+ if self.args.local_rank != -1:
+ torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
+ else:
+ try:
+ torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
+ except Exception as e:
+ logger.info(
+ f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
+ "\nThis won't yield the same results as if the training had not been interrupted."
+ )
+ if is_torch_tpu_available():
+ xm.set_rng_state(checkpoint_rng_state["xla"])
+
+ def _save_checkpoint(self, model, trial, metrics=None):
+ # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
+ # want to save except FullyShardedDDP.
+ # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
+
+ # Save model checkpoint
+ #checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.save_counter}"
+ ##########
+ self.save_counter += 1
+ ##########
+ if self.hp_search_backend is None and trial is None:
+ self.store_flos()
+
+ run_dir = self._get_output_dir(trial=trial)
+ output_dir = os.path.join(run_dir, checkpoint_folder)
+ self.save_model(output_dir, _internal_call=True)
+ if self.deepspeed:
+ # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
+ # config `stage3_gather_16bit_weights_on_model_save` is True
+ self.deepspeed.save_checkpoint(output_dir)
+
+ # Save optimizer and scheduler
+ if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+ self.optimizer.consolidate_state_dict()
+
+ if is_torch_tpu_available():
+ xm.rendezvous("saving_optimizer_states")
+ xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+ reissue_pt_warnings(caught_warnings)
+ elif is_sagemaker_mp_enabled():
+ opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
+ smp.barrier()
+ if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
+ smp.save(
+ opt_state_dict,
+ os.path.join(output_dir, OPTIMIZER_NAME),
+ partial=True,
+ v3=smp.state.cfg.shard_optimizer_state,
+ )
+ if self.args.should_save:
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+ reissue_pt_warnings(caught_warnings)
+ if self.do_grad_scaling:
+ torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+ elif self.args.should_save and not self.deepspeed:
+ # deepspeed.save_checkpoint above saves model/optim/sched
+ torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+ reissue_pt_warnings(caught_warnings)
+ if self.do_grad_scaling:
+ torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+
+ # Determine the new best metric / best model checkpoint
+ if metrics is not None and self.args.metric_for_best_model is not None:
+ metric_to_check = self.args.metric_for_best_model
+ if not metric_to_check.startswith("eval_"):
+ metric_to_check = f"eval_{metric_to_check}"
+ metric_value = metrics[metric_to_check]
+
+ operator = np.greater if self.args.greater_is_better else np.less
+ if (
+ self.state.best_metric is None
+ or self.state.best_model_checkpoint is None
+ or operator(metric_value, self.state.best_metric)
+ ):
+ self.state.best_metric = metric_value
+ self.state.best_model_checkpoint = output_dir
+
+ # Save the Trainer state
+ if self.args.should_save:
+ self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+
+ # Save RNG state in non-distributed training
+ rng_states = {
+ "python": random.getstate(),
+ "numpy": np.random.get_state(),
+ "cpu": torch.random.get_rng_state(),
+ }
+ if torch.cuda.is_available():
+ if self.args.local_rank == -1:
+ # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
+ rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
+ else:
+ rng_states["cuda"] = torch.cuda.random.get_rng_state()
+
+ if is_torch_tpu_available():
+ rng_states["xla"] = xm.get_rng_state()
+
+ # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
+ # not yet exist.
+ os.makedirs(output_dir, exist_ok=True)
+
+ if self.args.world_size <= 1:
+ torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
+ else:
+ torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
+
+ if self.args.push_to_hub:
+ self._push_from_checkpoint(output_dir)
+
+ # Maybe delete some older checkpoints.
+ if self.args.should_save:
+ self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+
+ def _load_optimizer_and_scheduler(self, checkpoint):
+ """If optimizer and scheduler states exist, load them."""
+ if checkpoint is None:
+ return
+
+ if self.deepspeed:
+ # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
+ return
+
+ checkpoint_file_exists = (
+ glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
+ if is_sagemaker_mp_enabled()
+ else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
+ )
+ if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
+ # Load in optimizer and scheduler states
+ if is_torch_tpu_available():
+ # On TPU we have to take some extra precautions to properly load the states on the right device.
+ optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
+ reissue_pt_warnings(caught_warnings)
+
+ xm.send_cpu_data_to_device(optimizer_state, self.args.device)
+ xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
+
+ self.optimizer.load_state_dict(optimizer_state)
+ self.lr_scheduler.load_state_dict(lr_scheduler_state)
+ else:
+ map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
+ if is_sagemaker_mp_enabled():
+ if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
+ # Optimizer checkpoint was saved with smp >= 1.10
+ def opt_load_hook(mod, opt):
+ opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+ else:
+ # Optimizer checkpoint was saved with smp < 1.10
+ def opt_load_hook(mod, opt):
+ if IS_SAGEMAKER_MP_POST_1_10:
+ opt.load_state_dict(
+ smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
+ )
+ else:
+ opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+ self.model_wrapped.register_post_step_hook(opt_load_hook)
+ else:
+ self.optimizer.load_state_dict(
+ torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
+ )
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
+ reissue_pt_warnings(caught_warnings)
+ if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
+ self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
+
+ def hyperparameter_search(
+ self,
+ hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
+ compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
+ n_trials: int = 20,
+ direction: str = "minimize",
+ backend: Optional[Union["str", HPSearchBackend]] = None,
+ hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
+ **kwargs,
+ ) -> BestRun:
+ """
+ Launch an hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined
+ by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
+ the sum of all metrics otherwise.
+
+ To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
+ reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
+ subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
+ optimizer/scheduler.
+
+ Args:
+ hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
+ A function that defines the hyperparameter search space. Will default to
+ [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
+ [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
+ compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
+ A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
+ method. Will default to [`~trainer_utils.default_compute_objective`].
+ n_trials (`int`, *optional*, defaults to 100):
+ The number of trial runs to test.
+ direction (`str`, *optional*, defaults to `"minimize"`):
+ Whether to optimize greater or lower objects. Can be `"minimize"` or `"maximize"`, you should pick
+ `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
+ backend (`str` or [`~training_utils.HPSearchBackend`], *optional*):
+ The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
+ on which one is installed. If all are installed, will default to optuna.
+ hp_name (`Callable[["optuna.Trial"], str]]`, *optional*):
+ A function that defines the trial/run name. Will default to None.
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more
+ information see:
+ - the documentation of
+ [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
+ - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)
+ - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)
+ Returns:
+ [`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in
+ `run_summary` attribute for Ray backend.
+ """
+ if backend is None:
+ backend = default_hp_search_backend()
+ if backend is None:
+ raise RuntimeError(
+ "At least one of optuna or ray should be installed. "
+ "To install optuna run `pip install optuna`. "
+ "To install ray run `pip install ray[tune]`. "
+ "To install sigopt run `pip install sigopt`."
+ )
+ backend = HPSearchBackend(backend)
+ if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
+ raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
+ if backend == HPSearchBackend.RAY and not is_ray_tune_available():
+ raise RuntimeError(
+ "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
+ )
+ if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
+ raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
+ if backend == HPSearchBackend.WANDB and not is_wandb_available():
+ raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
+ self.hp_search_backend = backend
+ if self.model_init is None:
+ raise RuntimeError(
+ "To use hyperparameter search, you need to pass your model through a model_init function."
+ )
+
+ self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
+ self.hp_name = hp_name
+ self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
+
+ backend_dict = {
+ HPSearchBackend.OPTUNA: run_hp_search_optuna,
+ HPSearchBackend.RAY: run_hp_search_ray,
+ HPSearchBackend.SIGOPT: run_hp_search_sigopt,
+ HPSearchBackend.WANDB: run_hp_search_wandb,
+ }
+ best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
+
+ self.hp_search_backend = None
+ return best_run
+
+ def log(self, logs: Dict[str, float]) -> None:
+ """
+ Log `logs` on the various objects watching training.
+ Subclass and override this method to inject custom behavior.
+ Args:
+ logs (`Dict[str, float]`):
+ The values to log.
+ """
+ if self.state.epoch is not None:
+ logs["epoch"] = round(self.state.epoch, 2)
+
+ output = {**logs, **{"step": self.state.global_step}}
+ self.state.log_history.append(output)
+ self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
+
+ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
+ """
+ Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
+ """
+ if isinstance(data, Mapping):
+ return type(data)({k: self._prepare_input(v) for k, v in data.items()})
+ elif isinstance(data, (tuple, list)):
+ return type(data)(self._prepare_input(v) for v in data)
+ elif isinstance(data, torch.Tensor):
+ kwargs = {"device": self.args.device}
+ if self.deepspeed and (torch.is_floating_point(data) or torch.is_complex(data)):
+ # NLP models inputs are int/uint and those get adjusted to the right dtype of the
+ # embedding. Other models such as wav2vec2's inputs are already float and thus
+ # may need special handling to match the dtypes of the model
+ kwargs.update({"dtype": self.args.hf_deepspeed_config.dtype()})
+ return data.to(**kwargs)
+ return data
+
+ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
+ """
+ Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
+ handling potential state.
+ """
+ inputs = self._prepare_input(inputs)
+ if len(inputs) == 0:
+ raise ValueError(
+ "The batch received was empty, your model won't be able to train on it. Double-check that your "
+ f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
+ )
+ if self.args.past_index >= 0 and self._past is not None:
+ inputs["mems"] = self._past
+
+ return inputs
+
+ def compute_loss_context_manager(self):
+ """
+ A helper wrapper to group together context managers.
+ """
+ return self.autocast_smart_context_manager()
+
+ def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
+ """
+ A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
+ arguments, depending on the situation.
+ """
+ if self.use_cuda_amp or self.use_cpu_amp:
+ if is_torch_greater_or_equal_than_1_10:
+ ctx_manager = (
+ torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
+ if self.use_cpu_amp
+ else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
+ )
+ else:
+ ctx_manager = torch.cuda.amp.autocast()
+ else:
+ ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()
+
+ return ctx_manager
+
+ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+ """
+ Perform a training step on a batch of inputs.
+ Subclass and override to inject custom behavior.
+ Args:
+ model (`nn.Module`):
+ The model to train.
+ inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+ The inputs and targets of the model.
+ The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+ argument `labels`. Check your model's documentation for all accepted arguments.
+ Return:
+ `torch.Tensor`: The tensor with training loss on this batch.
+ """
+ model.train()
+ inputs = self._prepare_inputs(inputs)
+
+ if is_sagemaker_mp_enabled():
+ loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
+ return loss_mb.reduce_mean().detach().to(self.args.device)
+
+ with self.compute_loss_context_manager():
+ loss = self.compute_loss(model, inputs)
+
+ if self.args.n_gpu > 1:
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
+
+ if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
+ # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
+ loss = loss / self.args.gradient_accumulation_steps
+
+ if self.do_grad_scaling:
+ self.scaler.scale(loss).backward()
+ elif self.use_apex:
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+ scaled_loss.backward()
+ elif self.deepspeed:
+ # loss gets scaled under gradient_accumulation_steps in deepspeed
+ loss = self.deepspeed.backward(loss)
+ else:
+ loss.backward()
+
+ return loss.detach()
+
+ def compute_loss(self, model, inputs, return_outputs=False):
+ """
+ How the loss is computed by Trainer. By default, all models return the loss in the first element.
+ Subclass and override for custom behavior.
+ """
+ if self.label_smoother is not None and "labels" in inputs:
+ labels = inputs.pop("labels")
+ else:
+ labels = None
+ outputs = model(**inputs)
+ # Save past state if it exists
+ # TODO: this needs to be fixed and made cleaner later.
+ if self.args.past_index >= 0:
+ self._past = outputs[self.args.past_index]
+
+ if labels is not None:
+ if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
+ loss = self.label_smoother(outputs, labels, shift_labels=True)
+ else:
+ loss = self.label_smoother(outputs, labels)
+ else:
+ if isinstance(outputs, dict) and "loss" not in outputs:
+ raise ValueError(
+ "The model did not return a loss from the inputs, only the following keys: "
+ f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
+ )
+ # We don't use .loss here since the model may return tuples instead of ModelOutput.
+ loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+
+ return (loss, outputs) if return_outputs else loss
+
+ def is_local_process_zero(self) -> bool:
+ """
+ Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
+ machines) main process.
+ """
+ return self.args.local_process_index == 0
+
+ def is_world_process_zero(self) -> bool:
+ """
+ Whether or not this process is the global main process (when training in a distributed fashion on several
+ machines, this is only going to be `True` for one process).
+ """
+ # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
+ # process index.
+ if is_sagemaker_mp_enabled():
+ return smp.rank() == 0
+ else:
+ return self.args.process_index == 0
+
+ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
+ """
+ Will save the model, so you can reload it using `from_pretrained()`.
+ Will only save from the main process.
+ """
+
+ if output_dir is None:
+ output_dir = self.args.output_dir
+
+ if is_torch_tpu_available():
+ self._save_tpu(output_dir)
+ elif is_sagemaker_mp_enabled():
+ # Calling the state_dict needs to be done on the wrapped model and on all processes.
+ os.makedirs(output_dir, exist_ok=True)
+ state_dict = self.model_wrapped.state_dict()
+ if self.args.should_save:
+ self._save(output_dir, state_dict=state_dict)
+ if IS_SAGEMAKER_MP_POST_1_10:
+ # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
+ Path(os.path.join(output_dir, "user_content.pt")).touch()
+ elif (
+ ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
+ or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
+ or self.fsdp is not None
+ ):
+ state_dict = self.model.state_dict()
+
+ if self.args.should_save:
+ self._save(output_dir, state_dict=state_dict)
+ elif self.deepspeed:
+ # this takes care of everything as long as we aren't under zero3
+ if self.args.should_save:
+ self._save(output_dir)
+
+ if is_deepspeed_zero3_enabled():
+ # It's too complicated to try to override different places where the weights dump gets
+ # saved, so since under zero3 the file is bogus, simply delete it. The user should
+ # either user deepspeed checkpoint to resume or to recover full weights use
+ # zero_to_fp32.py stored in the checkpoint.
+ if self.args.should_save:
+ file = os.path.join(output_dir, WEIGHTS_NAME)
+ if os.path.isfile(file):
+ # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
+ os.remove(file)
+
+ # now save the real model if stage3_gather_16bit_weights_on_model_save=True
+ # if false it will not be saved.
+ # This must be called on all ranks
+ if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME):
+ logger.warning(
+ "deepspeed.save_16bit_model didn't save the model, since"
+ " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
+ " zero_to_fp32.py to recover weights"
+ )
+ self.deepspeed.save_checkpoint(output_dir)
+
+ elif self.args.should_save:
+ self._save(output_dir)
+
+ # Push to the Hub when `save_model` is called by the user.
+ if self.args.push_to_hub and not _internal_call:
+ self.push_to_hub(commit_message="Model save")
+
+ def _save_tpu(self, output_dir: Optional[str] = None):
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
+ logger.info(f"Saving model checkpoint to {output_dir}")
+
+ if xm.is_master_ordinal():
+ os.makedirs(output_dir, exist_ok=True)
+ torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+ # Save a trained model and configuration using `save_pretrained()`.
+ # They can then be reloaded using `from_pretrained()`
+ xm.rendezvous("saving_checkpoint")
+ if not isinstance(self.model, PreTrainedModel):
+ if isinstance(unwrap_model(self.model), PreTrainedModel):
+ unwrap_model(self.model).save_pretrained(
+ output_dir,
+ is_main_process=self.args.should_save,
+ state_dict=self.model.state_dict(),
+ save_function=xm.save,
+ )
+ else:
+ logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+ state_dict = self.model.state_dict()
+ xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+ else:
+ self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
+ if self.tokenizer is not None and self.args.should_save:
+ self.tokenizer.save_pretrained(output_dir)
+
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ # If we are executing this function, we are the process zero, so we don't check for that.
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
+ os.makedirs(output_dir, exist_ok=True)
+ logger.info(f"Saving model checkpoint to {output_dir}")
+ # Save a trained model and configuration using `save_pretrained()`.
+ # They can then be reloaded using `from_pretrained()`
+ if not isinstance(self.model, PreTrainedModel):
+ if isinstance(unwrap_model(self.model), PreTrainedModel):
+ if state_dict is None:
+ state_dict = self.model.state_dict()
+ unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
+ else:
+ logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+ if state_dict is None:
+ state_dict = self.model.state_dict()
+ torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+ else:
+ self.model.save_pretrained(output_dir, state_dict=state_dict)
+ if self.tokenizer is not None:
+ self.tokenizer.save_pretrained(output_dir)
+
+ # Good practice: save your training arguments together with the trained model
+ torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+ def store_flos(self):
+ # Storing the number of floating-point operations that went into the model
+ if self.args.local_rank != -1:
+ self.state.total_flos += (
+ distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
+ )
+ self.current_flos = 0
+ else:
+ self.state.total_flos += self.current_flos
+ self.current_flos = 0
+
+ def _sorted_checkpoints(
+ self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
+ ) -> List[str]:
+ ordering_and_checkpoint_path = []
+
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
+
+ for path in glob_checkpoints:
+ if use_mtime:
+ ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
+ else:
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
+ if regex_match is not None and regex_match.groups() is not None:
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
+
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
+ # Make sure we don't delete the best model.
+ if self.state.best_model_checkpoint is not None:
+ best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
+ for i in range(best_model_index, len(checkpoints_sorted) - 2):
+ checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
+ return checkpoints_sorted
+
+ def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
+ if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
+ return
+
+ # Check if we should delete older checkpoint(s)
+ checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
+ if len(checkpoints_sorted) <= self.args.save_total_limit:
+ return
+
+ # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
+ # we don't do to allow resuming.
+ save_total_limit = self.args.save_total_limit
+ if (
+ self.state.best_model_checkpoint is not None
+ and self.args.save_total_limit == 1
+ and checkpoints_sorted[-1] != self.state.best_model_checkpoint
+ ):
+ save_total_limit = 2
+
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
+ for checkpoint in checkpoints_to_be_deleted:
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+ shutil.rmtree(checkpoint, ignore_errors=True)
+
+ def evaluate(
+ self,
+ eval_dataset: Optional[Dataset] = None,
+ ignore_keys: Optional[List[str]] = None,
+ metric_key_prefix: str = "eval",
+ ) -> Dict[str, float]:
+ """
+ Run evaluation and returns metrics.
+ The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
+ (pass it to the init `compute_metrics` argument).
+ You can also subclass and override this method to inject custom behavior.
+ Args:
+ eval_dataset (`Dataset`, *optional*):
+ Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
+ not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
+ method.
+ ignore_keys (`Lst[str]`, *optional*):
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+ gathering predictions.
+ metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
+ An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
+ "eval_bleu" if the prefix is "eval" (default)
+ Returns:
+ A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
+ dictionary also contains the epoch number which comes from the training state.
+ """
+ # memory metrics - must set up as early as possible
+ self._memory_tracker.start()
+
+ eval_dataloader = self.get_eval_dataloader(eval_dataset)
+ start_time = time.time()
+
+ eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+ output = eval_loop(
+ eval_dataloader,
+ description="Evaluation",
+ # No point gathering the predictions if there are no metrics, otherwise we defer to
+ # self.args.prediction_loss_only
+ prediction_loss_only=True if self.compute_metrics is None else None,
+ ignore_keys=ignore_keys,
+ metric_key_prefix=metric_key_prefix,
+ )
+
+ total_batch_size = self.args.eval_batch_size * self.args.world_size
+ if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+ start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+ output.metrics.update(
+ speed_metrics(
+ metric_key_prefix,
+ start_time,
+ num_samples=output.num_samples,
+ num_steps=math.ceil(output.num_samples / total_batch_size),
+ )
+ )
+
+ self.log(output.metrics)
+
+ if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
+ # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+ xm.master_print(met.metrics_report())
+
+ self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
+
+ self._memory_tracker.stop_and_update_metrics(output.metrics)
+
+ return output.metrics
+
+ def predict(
+ self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
+ ) -> PredictionOutput:
+ """
+ Run prediction and returns predictions and potential metrics.
+ Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
+ will also return metrics, like in `evaluate()`.
+ Args:
+ test_dataset (`Dataset`):
+ Dataset to run the predictions on. If it is an `datasets.Dataset`, columns not accepted by the
+ `model.forward()` method are automatically removed. Has to implement the method `__len__`
+ ignore_keys (`Lst[str]`, *optional*):
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+ gathering predictions.
+ metric_key_prefix (`str`, *optional*, defaults to `"test"`):
+ An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
+ "test_bleu" if the prefix is "test" (default)
+
+ If your predictions or labels have different sequence length (for instance because you're doing dynamic padding
+ in a token classification task) the predictions will be padded (on the right) to allow for concatenation into
+ one array. The padding index is -100.
+
+ Returns: *NamedTuple* A namedtuple with the following keys:
+ - predictions (`np.ndarray`): The predictions on `test_dataset`.
+ - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
+ - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
+ labels).
+ """
+ # memory metrics - must set up as early as possible
+ self._memory_tracker.start()
+
+ test_dataloader = self.get_test_dataloader(test_dataset)
+ start_time = time.time()
+
+ eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+ output = eval_loop(
+ test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
+ )
+ total_batch_size = self.args.eval_batch_size * self.args.world_size
+ if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+ start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+ output.metrics.update(
+ speed_metrics(
+ metric_key_prefix,
+ start_time,
+ num_samples=output.num_samples,
+ num_steps=math.ceil(output.num_samples / total_batch_size),
+ )
+ )
+
+ self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
+ self._memory_tracker.stop_and_update_metrics(output.metrics)
+
+ return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
+
+ def evaluation_loop(
+ self,
+ dataloader: DataLoader,
+ description: str,
+ prediction_loss_only: Optional[bool] = None,
+ ignore_keys: Optional[List[str]] = None,
+ metric_key_prefix: str = "eval",
+ ) -> EvalLoopOutput:
+ """
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+ Works both with or without labels.
+ """
+ args = self.args
+
+ prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
+
+ # if eval is called w/o train init deepspeed here
+ if args.deepspeed and not self.deepspeed:
+ # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
+ # from the checkpoint eventually
+ deepspeed_engine, _, _ = deepspeed_init(
+ self, num_training_steps=0, resume_from_checkpoint=None, inference=True
+ )
+ self.model = deepspeed_engine.module
+ self.model_wrapped = deepspeed_engine
+ self.deepspeed = deepspeed_engine
+
+ model = self._wrap_model(self.model, training=False, dataloader=dataloader)
+
+ # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
+ # while ``train`` is running, cast it to the right dtype first and then put on device
+ if not self.is_in_train:
+ if args.fp16_full_eval:
+ model = model.to(dtype=torch.float16, device=args.device)
+ elif args.bf16_full_eval:
+ model = model.to(dtype=torch.bfloat16, device=args.device)
+
+ batch_size = self.args.eval_batch_size
+
+ logger.info(f"***** Running {description} *****")
+ if has_length(dataloader):
+ logger.info(f" Num examples = {self.num_examples(dataloader)}")
+ else:
+ logger.info(" Num examples: Unknown")
+ logger.info(f" Batch size = {batch_size}")
+
+ model.eval()
+
+ self.callback_handler.eval_dataloader = dataloader
+ # Do this before wrapping.
+ eval_dataset = getattr(dataloader, "dataset", None)
+
+ if is_torch_tpu_available():
+ dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
+
+ if args.past_index >= 0:
+ self._past = None
+
+ # Initialize containers
+ # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
+ losses_host = None
+ preds_host = None
+ labels_host = None
+ inputs_host = None
+
+ # losses/preds/labels on CPU (final containers)
+ all_losses = None
+ all_preds = None
+ all_labels = None
+ all_inputs = None
+ # Will be useful when we have an iterable dataset so don't know its length.
+
+ observed_num_examples = 0
+ # Main evaluation loop
+ for step, inputs in enumerate(dataloader):
+ # Update the observed num examples
+ observed_batch_size = find_batch_size(inputs)
+ if observed_batch_size is not None:
+ observed_num_examples += observed_batch_size
+ # For batch samplers, batch_size is not known by the dataloader in advance.
+ if batch_size is None:
+ batch_size = observed_batch_size
+
+ # Prediction step
+ loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
+ inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
+
+ if is_torch_tpu_available():
+ xm.mark_step()
+
+ # Update containers on host
+ if loss is not None:
+ losses = self._nested_gather(loss.repeat(batch_size))
+ losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
+ if labels is not None:
+ labels = self._pad_across_processes(labels)
+ labels = self._nested_gather(labels)
+ labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
+ if inputs_decode is not None:
+ inputs_decode = self._pad_across_processes(inputs_decode)
+ inputs_decode = self._nested_gather(inputs_decode)
+ inputs_host = (
+ inputs_decode
+ if inputs_host is None
+ else nested_concat(inputs_host, inputs_decode, padding_index=-100)
+ )
+ if logits is not None:
+ logits = self._pad_across_processes(logits)
+ logits = self._nested_gather(logits)
+ if self.preprocess_logits_for_metrics is not None:
+ logits = self.preprocess_logits_for_metrics(logits, labels)
+ preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
+ self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
+
+ # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
+ if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
+ if losses_host is not None:
+ losses = nested_numpify(losses_host)
+ all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
+ if preds_host is not None:
+ logits = nested_numpify(preds_host)
+ all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
+ if inputs_host is not None:
+ inputs_decode = nested_numpify(inputs_host)
+ all_inputs = (
+ inputs_decode
+ if all_inputs is None
+ else nested_concat(all_inputs, inputs_decode, padding_index=-100)
+ )
+ if labels_host is not None:
+ labels = nested_numpify(labels_host)
+ all_labels = (
+ labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+ )
+
+ # Set back to None to begin a new accumulation
+ losses_host, preds_host, inputs_host, labels_host = None, None, None, None
+
+ if args.past_index and hasattr(self, "_past"):
+ # Clean the state at the end of the evaluation loop
+ delattr(self, "_past")
+
+ # Gather all remaining tensors and put them back on the CPU
+ if losses_host is not None:
+ losses = nested_numpify(losses_host)
+ all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
+ if preds_host is not None:
+ logits = nested_numpify(preds_host)
+ all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
+ if inputs_host is not None:
+ inputs_decode = nested_numpify(inputs_host)
+ all_inputs = (
+ inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
+ )
+ if labels_host is not None:
+ labels = nested_numpify(labels_host)
+ all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+
+ # Number of samples
+ if has_length(eval_dataset):
+ num_samples = len(eval_dataset)
+ # The instance check is weird and does not actually check for the type, but whether the dataset has the right
+ # methods. Therefore we need to make sure it also has the attribute.
+ elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
+ num_samples = eval_dataset.num_examples
+ else:
+ if has_length(dataloader):
+ num_samples = self.num_examples(dataloader)
+ else: # both len(dataloader.dataset) and len(dataloader) fail
+ num_samples = observed_num_examples
+ if num_samples == 0 and observed_num_examples > 0:
+ num_samples = observed_num_examples
+
+ # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
+ # samplers has been rounded to a multiple of batch_size, so we truncate.
+ if all_losses is not None:
+ all_losses = all_losses[:num_samples]
+ if all_preds is not None:
+ all_preds = nested_truncate(all_preds, num_samples)
+ if all_labels is not None:
+ all_labels = nested_truncate(all_labels, num_samples)
+ if all_inputs is not None:
+ all_inputs = nested_truncate(all_inputs, num_samples)
+
+ # Metrics!
+ if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
+ if args.include_inputs_for_metrics:
+ metrics = self.compute_metrics(
+ EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
+ )
+ else:
+ metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
+ else:
+ metrics = {}
+
+ # To be JSON-serializable, we need to remove numpy types or zero-d tensors
+ metrics = denumpify_detensorize(metrics)
+
+ if all_losses is not None:
+ metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
+ if hasattr(self, "jit_compilation_time"):
+ metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
+
+ # Prefix all keys with metric_key_prefix + '_'
+ for key in list(metrics.keys()):
+ if not key.startswith(f"{metric_key_prefix}_"):
+ metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+ return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
+
+ def _nested_gather(self, tensors, name=None):
+ """
+ Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
+ concatenating them to `gathered`
+ """
+ if tensors is None:
+ return
+ if is_torch_tpu_available():
+ if name is None:
+ name = "nested_gather"
+ tensors = nested_xla_mesh_reduce(tensors, name)
+ elif is_sagemaker_mp_enabled():
+ tensors = smp_gather(tensors)
+ elif self.args.local_rank != -1:
+ tensors = distributed_concat(tensors)
+ return tensors
+
+ # Copied from Accelerate.
+ def _pad_across_processes(self, tensor, pad_index=-100):
+ """
+ Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
+ they can safely be gathered.
+ """
+ if isinstance(tensor, (list, tuple)):
+ return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
+ elif isinstance(tensor, dict):
+ return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()})
+ elif not isinstance(tensor, torch.Tensor):
+ raise TypeError(
+ f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
+ )
+
+ if len(tensor.shape) < 2:
+ return tensor
+ # Gather all sizes
+ size = torch.tensor(tensor.shape, device=tensor.device)[None]
+ sizes = self._nested_gather(size).cpu()
+
+ max_size = max(s[1] for s in sizes)
+ # When extracting XLA graphs for compilation, max_size is 0,
+ # so use inequality to avoid errors.
+ if tensor.shape[1] >= max_size:
+ return tensor
+
+ # Then pad to the maximum size
+ old_size = tensor.shape
+ new_size = list(old_size)
+ new_size[1] = max_size
+ new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
+ new_tensor[:, : old_size[1]] = tensor
+ return new_tensor
+
+ def prediction_step(
+ self,
+ model: nn.Module,
+ inputs: Dict[str, Union[torch.Tensor, Any]],
+ prediction_loss_only: bool,
+ ignore_keys: Optional[List[str]] = None,
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+ """
+ Perform an evaluation step on `model` using `inputs`.
+ Subclass and override to inject custom behavior.
+ Args:
+ model (`nn.Module`):
+ The model to evaluate.
+ inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+ The inputs and targets of the model.
+ The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+ argument `labels`. Check your model's documentation for all accepted arguments.
+ prediction_loss_only (`bool`):
+ Whether or not to return the loss only.
+ ignore_keys (`Lst[str]`, *optional*):
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+ gathering predictions.
+ Return:
+ Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
+ logits and labels (each being optional).
+ """
+ has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
+ # For CLIP-like models capable of returning loss values.
+ # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
+ # is `True` in `model.forward`.
+ return_loss = inputs.get("return_loss", None)
+ if return_loss is None:
+ return_loss = self.can_return_loss
+ loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
+
+ inputs = self._prepare_inputs(inputs)
+ if ignore_keys is None:
+ if hasattr(self.model, "config"):
+ ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
+ else:
+ ignore_keys = []
+
+ # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
+ if has_labels or loss_without_labels:
+ labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
+ if len(labels) == 1:
+ labels = labels[0]
+ else:
+ labels = None
+
+ with torch.no_grad():
+ if is_sagemaker_mp_enabled():
+ raw_outputs = smp_forward_only(model, inputs)
+ if has_labels or loss_without_labels:
+ if isinstance(raw_outputs, dict):
+ loss_mb = raw_outputs["loss"]
+ logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
+ else:
+ loss_mb = raw_outputs[0]
+ logits_mb = raw_outputs[1:]
+
+ loss = loss_mb.reduce_mean().detach().cpu()
+ logits = smp_nested_concat(logits_mb)
+ else:
+ loss = None
+ if isinstance(raw_outputs, dict):
+ logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
+ else:
+ logits_mb = raw_outputs
+ logits = smp_nested_concat(logits_mb)
+ else:
+ if has_labels or loss_without_labels:
+ with self.compute_loss_context_manager():
+ loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+ loss = loss.mean().detach()
+
+ if isinstance(outputs, dict):
+ logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
+ else:
+ logits = outputs[1:]
+ else:
+ loss = None
+ with self.compute_loss_context_manager():
+ outputs = model(**inputs)
+ if isinstance(outputs, dict):
+ logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
+ else:
+ logits = outputs
+ # TODO: this needs to be fixed and made cleaner later.
+ if self.args.past_index >= 0:
+ self._past = outputs[self.args.past_index - 1]
+
+ if prediction_loss_only:
+ return (loss, None, None)
+
+ logits = nested_detach(logits)
+ if len(logits) == 1:
+ logits = logits[0]
+
+ return (loss, logits, labels)
+
+ def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
+ """
+ For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point
+ operations for every backward + forward pass. If using another model, either implement such a method in the
+ model or subclass and override this method.
+ Args:
+ inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+ The inputs and targets of the model.
+ Returns:
+ `int`: The number of floating-point operations.
+ """
+ if hasattr(self.model, "floating_point_ops"):
+ return self.model.floating_point_ops(inputs)
+ else:
+ return 0
+
+ def init_git_repo(self, at_init: bool = False):
+ """
+ Initializes a git repo in `self.args.hub_model_id`.
+ Args:
+ at_init (`bool`, *optional*, defaults to `False`):
+ Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
+ `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
+ out.
+ """
+ if not self.is_world_process_zero():
+ return
+ if self.args.hub_model_id is None:
+ repo_name = Path(self.args.output_dir).absolute().name
+ else:
+ repo_name = self.args.hub_model_id
+ if "/" not in repo_name:
+ repo_name = get_full_repo_name(repo_name, token=self.args.hub_token)
+
+ # Make sure the repo exists.
+ create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True)
+ try:
+ self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token)
+ except EnvironmentError:
+ if self.args.overwrite_output_dir and at_init:
+ # Try again after wiping output_dir
+ shutil.rmtree(self.args.output_dir)
+ self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token)
+ else:
+ raise
+
+ self.repo.git_pull()
+
+ # By default, ignore the checkpoint folders
+ if (
+ not os.path.exists(os.path.join(self.args.output_dir, ".gitignore"))
+ and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS
+ ):
+ with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
+ writer.writelines(["checkpoint-*/"])
+
+ # Add "*.sagemaker" to .gitignore if using SageMaker
+ if os.environ.get("SM_TRAINING_ENV"):
+ self._add_sm_patterns_to_gitignore()
+
+ self.push_in_progress = None
+
+ def create_model_card(
+ self,
+ language: Optional[str] = None,
+ license: Optional[str] = None,
+ tags: Union[str, List[str], None] = None,
+ model_name: Optional[str] = None,
+ finetuned_from: Optional[str] = None,
+ tasks: Union[str, List[str], None] = None,
+ dataset_tags: Union[str, List[str], None] = None,
+ dataset: Union[str, List[str], None] = None,
+ dataset_args: Union[str, List[str], None] = None,
+ ):
+ """
+ Creates a draft of a model card using the information available to the `Trainer`.
+ Args:
+ language (`str`, *optional*):
+ The language of the model (if applicable)
+ license (`str`, *optional*):
+ The license of the model. Will default to the license of the pretrained model used, if the original
+ model given to the `Trainer` comes from a repo on the Hub.
+ tags (`str` or `List[str]`, *optional*):
+ Some tags to be included in the metadata of the model card.
+ model_name (`str`, *optional*):
+ The name of the model.
+ finetuned_from (`str`, *optional*):
+ The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
+ of the original model given to the `Trainer` (if it comes from the Hub).
+ tasks (`str` or `List[str]`, *optional*):
+ One or several task identifiers, to be included in the metadata of the model card.
+ dataset_tags (`str` or `List[str]`, *optional*):
+ One or several dataset tags, to be included in the metadata of the model card.
+ dataset (`str` or `List[str]`, *optional*):
+ One or several dataset identifiers, to be included in the metadata of the model card.
+ dataset_args (`str` or `List[str]`, *optional*):
+ One or several dataset arguments, to be included in the metadata of the model card.
+ """
+ if not self.is_world_process_zero():
+ return
+
+ training_summary = TrainingSummary.from_trainer(
+ self,
+ language=language,
+ license=license,
+ tags=tags,
+ model_name=model_name,
+ finetuned_from=finetuned_from,
+ tasks=tasks,
+ dataset_tags=dataset_tags,
+ dataset=dataset,
+ dataset_args=dataset_args,
+ )
+ model_card = training_summary.to_model_card()
+ with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
+ f.write(model_card)
+
+ def _push_from_checkpoint(self, checkpoint_folder):
+ # Only push from one node.
+ if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
+ return
+ # If we haven't finished the last push, we don't do this one.
+ if self.push_in_progress is not None and not self.push_in_progress.is_done:
+ return
+
+ output_dir = self.args.output_dir
+ # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
+ modeling_files = [CONFIG_NAME, WEIGHTS_NAME]
+ for modeling_file in modeling_files:
+ if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
+ shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
+ # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
+ if self.tokenizer is not None:
+ self.tokenizer.save_pretrained(output_dir)
+ # Same for the training arguments
+ torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+ try:
+ if self.args.hub_strategy == HubStrategy.CHECKPOINT:
+ # Temporarily move the checkpoint just saved for the push
+ tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
+ # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
+ # subfolder.
+ if os.path.isdir(tmp_checkpoint):
+ shutil.rmtree(tmp_checkpoint)
+ shutil.move(checkpoint_folder, tmp_checkpoint)
+
+ if self.args.save_strategy == IntervalStrategy.STEPS:
+ commit_message = f"Training in progress, step {self.state.global_step}"
+ else:
+ commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
+ _, self.push_in_progress = self.repo.push_to_hub(
+ commit_message=commit_message, blocking=False, auto_lfs_prune=True
+ )
+ finally:
+ if self.args.hub_strategy == HubStrategy.CHECKPOINT:
+ # Move back the checkpoint to its place
+ shutil.move(tmp_checkpoint, checkpoint_folder)
+
+ def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
+ """
+ Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
+ Parameters:
+ commit_message (`str`, *optional*, defaults to `"End of training"`):
+ Message to commit while pushing.
+ blocking (`bool`, *optional*, defaults to `True`):
+ Whether the function should return only when the `git push` has finished.
+ kwargs:
+ Additional keyword arguments passed along to [`~Trainer.create_model_card`].
+ Returns:
+ The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of
+ the commit and an object to track the progress of the commit if `blocking=True`
+ """
+ # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
+ # it might fail.
+ if not hasattr(self, "repo"):
+ self.init_git_repo()
+
+ model_name = kwargs.pop("model_name", None)
+ if model_name is None and self.args.should_save:
+ if self.args.hub_model_id is None:
+ model_name = Path(self.args.output_dir).name
+ else:
+ model_name = self.args.hub_model_id.split("/")[-1]
+
+ # Needs to be executed on all processes for TPU training, but will only save on the processed determined by
+ # self.args.should_save.
+ self.save_model(_internal_call=True)
+
+ # Only push from one node.
+ if not self.is_world_process_zero():
+ return
+
+ # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
+ if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
+ self.push_in_progress._process.kill()
+ self.push_in_progress = None
+
+ git_head_commit_url = self.repo.push_to_hub(
+ commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
+ )
+ # push separately the model card to be independant from the rest of the model
+ if self.args.should_save:
+ self.create_model_card(model_name=model_name, **kwargs)
+ try:
+ self.repo.push_to_hub(
+ commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
+ )
+ except EnvironmentError as exc:
+ logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
+
+ return git_head_commit_url
+
+ #
+ # Deprecated code
+ #
+
+ def prediction_loop(
+ self,
+ dataloader: DataLoader,
+ description: str,
+ prediction_loss_only: Optional[bool] = None,
+ ignore_keys: Optional[List[str]] = None,
+ metric_key_prefix: str = "eval",
+ ) -> EvalLoopOutput:
+ """
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+ Works both with or without labels.
+ """
+ args = self.args
+
+ if not has_length(dataloader):
+ raise ValueError("dataloader must implement a working __len__")
+
+ prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
+
+ # if eval is called w/o train init deepspeed here
+ if args.deepspeed and not self.deepspeed:
+ # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
+ # from the checkpoint eventually
+ deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
+ self.model = deepspeed_engine.module
+ self.model_wrapped = deepspeed_engine
+ self.deepspeed = deepspeed_engine
+ # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
+ # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
+ # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
+ deepspeed_engine.optimizer.optimizer = None
+ deepspeed_engine.lr_scheduler = None
+
+ model = self._wrap_model(self.model, training=False, dataloader=dataloader)
+
+ # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
+ # while ``train`` is running, cast it to the right dtype first and then put on device
+ if not self.is_in_train:
+ if args.fp16_full_eval:
+ model = model.to(dtype=torch.float16, device=args.device)
+ elif args.bf16_full_eval:
+ model = model.to(dtype=torch.bfloat16, device=args.device)
+
+ batch_size = dataloader.batch_size
+ num_examples = self.num_examples(dataloader)
+ logger.info(f"***** Running {description} *****")
+ logger.info(f" Num examples = {num_examples}")
+ logger.info(f" Batch size = {batch_size}")
+ losses_host: torch.Tensor = None
+ preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
+ labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
+ inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None
+
+ world_size = max(1, args.world_size)
+
+ eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
+ if not prediction_loss_only:
+ # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
+ # a batch size to the sampler)
+ make_multiple_of = None
+ if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
+ make_multiple_of = dataloader.sampler.batch_size
+ preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+ labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+ inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+
+ model.eval()
+
+ if is_torch_tpu_available():
+ dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
+
+ if args.past_index >= 0:
+ self._past = None
+
+ self.callback_handler.eval_dataloader = dataloader
+
+ for step, inputs in enumerate(dataloader):
+ loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
+ inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
+
+ if loss is not None:
+ losses = loss.repeat(batch_size)
+ losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
+ if logits is not None:
+ preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
+ if labels is not None:
+ labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
+ if inputs_decode is not None:
+ inputs_host = (
+ inputs_decode
+ if inputs_host is None
+ else nested_concat(inputs_host, inputs_decode, padding_index=-100)
+ )
+ self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
+
+ # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
+ if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
+ eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
+ if not prediction_loss_only:
+ preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
+ labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
+ inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
+
+ # Set back to None to begin a new accumulation
+ losses_host, preds_host, labels_host, inputs_host = None, None, None, None
+
+ if args.past_index and hasattr(self, "_past"):
+ # Clean the state at the end of the evaluation loop
+ delattr(self, "_past")
+
+ # Gather all remaining tensors and put them back on the CPU
+ eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
+ if not prediction_loss_only:
+ preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
+ labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
+ inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
+
+ eval_loss = eval_losses_gatherer.finalize()
+ preds = preds_gatherer.finalize() if not prediction_loss_only else None
+ label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
+ inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None
+
+ if self.compute_metrics is not None and preds is not None and label_ids is not None:
+ if args.include_inputs_for_metrics:
+ metrics = self.compute_metrics(
+ EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
+ )
+ else:
+ metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
+ else:
+ metrics = {}
+
+ # To be JSON-serializable, we need to remove numpy types or zero-d tensors
+ metrics = denumpify_detensorize(metrics)
+
+ if eval_loss is not None:
+ metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
+
+ # Prefix all keys with metric_key_prefix + '_'
+ for key in list(metrics.keys()):
+ if not key.startswith(f"{metric_key_prefix}_"):
+ metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+ return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)
+
+ def _gather_and_numpify(self, tensors, name):
+ """
+ Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
+ concatenating them to `gathered`
+ """
+ if tensors is None:
+ return
+ if is_torch_tpu_available():
+ tensors = nested_xla_mesh_reduce(tensors, name)
+ elif is_sagemaker_mp_enabled():
+ tensors = smp_gather(tensors)
+ elif self.args.local_rank != -1:
+ tensors = distributed_concat(tensors)
+
+ return nested_numpify(tensors)
+
+ def _add_sm_patterns_to_gitignore(self) -> None:
+ """Add SageMaker Checkpointing patterns to .gitignore file."""
+ # Make sure we only do this on the main process
+ if not self.is_world_process_zero():
+ return
+
+ patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]
+
+ # Get current .gitignore content
+ if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")):
+ with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f:
+ current_content = f.read()
+ else:
+ current_content = ""
+
+ # Add the patterns to .gitignore
+ content = current_content
+ for pattern in patterns:
+ if pattern not in content:
+ if content.endswith("\n"):
+ content += pattern
+ else:
+ content += f"\n{pattern}"
+
+ # Write the .gitignore file if it has changed
+ if content != current_content:
+ with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f:
+ logger.debug(f"Writing .gitignore file. Content: {content}")
+ f.write(content)
+
+ self.repo.git_add(".gitignore")
+
+ # avoid race condition with git status
+ time.sleep(0.5)
+
+ if not self.repo.is_repo_clean():
+ self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
+ self.repo.git_push()