Trainer
ORTTrainer
class optimum.onnxruntime.ORTTrainer
< source >( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None args: ORTTrainingArguments = None data_collator: typing.Optional[DataCollator] = None train_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None eval_dataset: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None model_init: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict]] = None callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None )
Parameters
- model (PreTrainedModel or torch.nn.Module, optional) — The model to train, evaluate or use for predictions. If not provided, a model_init must be passed. ORTTrainer is optimized to work with the PreTrainedModel provided by the transformers library. You can still use your own models defined as torch.nn.Module for training with the ONNX Runtime backend and inference with the PyTorch backend, as long as they work the same way as the 🤗 Transformers models.
- args (ORTTrainingArguments, optional) — The arguments to tweak for training. Will default to a basic instance of ORTTrainingArguments with the output_dir set to a directory named tmp_trainer in the current directory if not provided.
- data_collator (DataCollator, optional) — The function to use to form a batch from a list of elements of train_dataset or eval_dataset. Will default to default_data_collator if no tokenizer is provided, an instance of DataCollatorWithPadding otherwise.
- train_dataset (torch.utils.data.Dataset or torch.utils.data.IterableDataset, optional) — The dataset to use for training. If it is a Dataset, columns not accepted by the model.forward() method are automatically removed. Note that if it’s a torch.utils.data.IterableDataset with some randomization and you are training in a distributed fashion, your iterable dataset should either use an internal attribute generator that is a torch.Generator for the randomization that must be identical on all processes (and the ORTTrainer will manually set the seed of this generator at each epoch) or have a set_epoch() method that internally sets the seed of the RNGs used.
- eval_dataset (Union[torch.utils.data.Dataset, Dict[str, torch.utils.data.Dataset]], optional) — The dataset to use for evaluation. If it is a Dataset, columns not accepted by the model.forward() method are automatically removed. If it is a dictionary, it will evaluate on each dataset, prepending the dictionary key to the metric name.
- tokenizer (PreTrainedTokenizerBase, optional) — The tokenizer used to preprocess the data. If provided, it will be used to automatically pad the inputs to the maximum length when batching inputs, and it will be saved along with the model to make it easier to rerun an interrupted training or reuse the fine-tuned model.
- model_init (Callable[[], PreTrainedModel], optional) — A function that instantiates the model to be used. If provided, each call to ORTTrainer.train will start from a new instance of the model as given by this function. The function may have zero arguments, or a single one containing the optuna/Ray Tune/SigOpt trial object, to be able to choose different architectures according to hyperparameters (such as layer count, sizes of inner layers, dropout probabilities, etc.).
- compute_metrics (Callable[[EvalPrediction], Dict], optional) — The function that will be used to compute metrics at evaluation. Must take an EvalPrediction and return a dictionary mapping strings to metric values.
- callbacks (List of TrainerCallback, optional) — A list of callbacks to customize the training loop. Will add those to the list of default callbacks. If you want to remove one of the default callbacks used, use the ORTTrainer.remove_callback method.
- optimizers (Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR], optional) — A tuple containing the optimizer and the scheduler to use. Will default to an instance of AdamW on your model and a scheduler given by get_linear_schedule_with_warmup controlled by args.
- preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optional) — A function that preprocesses the logits right before caching them at each evaluation step. Must take two tensors, the logits and the labels, and return the logits once processed as desired. The modifications made by this function will be reflected in the predictions received by compute_metrics. Note that the labels (second parameter) will be None if the dataset does not have them.
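The set_epoch() option described for train_dataset can be sketched in plain Python. This is only an illustration under stated assumptions: ShuffledStream is a hypothetical class, it does not subclass torch.utils.data.IterableDataset as a real dataset would, and the seed formula (base seed plus epoch) is one possible choice rather than anything the library prescribes.

```python
import random
from typing import Iterable, Iterator


class ShuffledStream:
    """Plain-Python stand-in for a randomized iterable dataset (hypothetical)."""

    def __init__(self, items: Iterable[int], base_seed: int = 0) -> None:
        self.items = list(items)
        self.base_seed = base_seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Called once per epoch so that every process derives the same seed.
        self.epoch = epoch

    def __iter__(self) -> Iterator[int]:
        # A private RNG seeded from (base_seed, epoch): identical on all
        # processes, and reseeded whenever the epoch changes.
        rng = random.Random(self.base_seed + self.epoch)
        order = list(self.items)
        rng.shuffle(order)
        return iter(order)


# Two "processes" holding separate copies of the dataset.
rank0 = ShuffledStream(range(8), base_seed=42)
rank1 = ShuffledStream(range(8), base_seed=42)
orders = []
for epoch in range(2):
    rank0.set_epoch(epoch)
    rank1.set_epoch(epoch)
    orders.append((list(rank0), list(rank1)))
```

Because the RNG is rebuilt from the seed on every iteration, each process sees the same order within an epoch without sharing any state.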
ORTTrainer is a simple but feature-complete training and eval loop for ONNX Runtime, optimized for 🤗 Transformers.
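As a rough sketch of how preprocess_logits_for_metrics and compute_metrics fit together, the example below reduces logits to class ids and then computes an accuracy. It is an illustration only: plain Python lists stand in for the tensors and arrays a real run would pass, EvalPredictionLike is a hypothetical stand-in for EvalPrediction, and the "accuracy" metric name is an arbitrary choice.

```python
from typing import Dict, List, NamedTuple, Optional


class EvalPredictionLike(NamedTuple):
    """Hypothetical stand-in for transformers.EvalPrediction."""

    predictions: List[int]
    label_ids: List[int]


def preprocess_logits_for_metrics(
    logits: List[List[float]], labels: Optional[List[int]]
) -> List[int]:
    # Keep only the index of the largest logit per example, so that
    # class ids rather than full logit rows are cached.
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def compute_metrics(eval_pred: EvalPredictionLike) -> Dict[str, float]:
    # Receives whatever preprocess_logits_for_metrics returned as predictions.
    correct = sum(p == y for p, y in zip(eval_pred.predictions, eval_pred.label_ids))
    return {"accuracy": correct / len(eval_pred.label_ids)}


logits = [[0.1, 0.9], [2.0, -1.0], [0.3, 0.4], [1.5, 0.2]]
labels = [1, 0, 0, 0]
class_ids = preprocess_logits_for_metrics(logits, labels)
metrics = compute_metrics(EvalPredictionLike(class_ids, labels))
```

Reducing the logits before they are cached keeps the accumulated predictions small, which is the usual reason to supply this hook.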
Important attributes:
- model — Always points to the core model. If using a transformers model, it will be a PreTrainedModel subclass.
- model_wrapped — Always points to the most external model in case one or more other modules wrap the original model. This is the model that should be used for the forward pass. For example, under DeepSpeed, the inner model is first wrapped in ORTModule and then in DeepSpeed and then again in torch.nn.DistributedDataParallel. If the inner model hasn’t been wrapped, then self.model_wrapped is the same as self.model.
- is_model_parallel — Whether or not a model has been switched to a model parallel mode (different from data parallelism, this means some of the model layers are split on different GPUs).
- place_model_on_device — Whether or not to automatically place the model on the device. It will be set to False if model parallel or deepspeed is used, or if the default ORTTrainingArguments.place_model_on_device is overridden to return False.
- is_in_train — Whether or not a model is currently running train (e.g. when evaluate is called while in train).
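The relationship between model and model_wrapped can be pictured as a chain of wrappers around one core object. The sketch below is a plain-Python illustration, not the trainer's code: Wrapper and CoreModel are hypothetical classes, and it assumes every wrapper exposes the wrapped object as a .module attribute, as torch.nn.parallel.DistributedDataParallel does.

```python
class Wrapper:
    """Hypothetical wrapper exposing the wrapped object as `.module`."""

    def __init__(self, module: object) -> None:
        self.module = module


class CoreModel:
    """Hypothetical stand-in for the core model."""


def unwrap(model: object) -> object:
    # Peel wrappers one layer at a time until no `.module` attribute is left.
    while hasattr(model, "module"):
        model = model.module
    return model


core = CoreModel()
# Three layers of wrapping, mirroring the DeepSpeed example in the list above.
model_wrapped = Wrapper(Wrapper(Wrapper(core)))
```

Here unwrap(model_wrapped) returns the core object, and unwrap(core) returns core unchanged, matching the case where the model was never wrapped.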
create_optimizer
Set up the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple to the ORTTrainer’s init through optimizers, or subclass ORTTrainer and override this method.
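The optimizers parameter names get_linear_schedule_with_warmup as the default scheduler. As a rough sketch of the multiplier such a schedule applies to the base learning rate (a linear ramp during warmup, then a linear decay to zero), the function below uses hypothetical step counts. Its name is illustrative and is not part of transformers or optimum.

```python
def linear_warmup_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """Factor applied to the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        # Ramp linearly from 0 to 1 over the warmup phase.
        return step / max(1, warmup_steps)
    # Then decay linearly from 1 down to 0 at total_steps.
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


base_lr = 5e-5  # default learning_rate in the ORTTrainingArguments signature
steps = (0, 5, 10, 55, 100)
multipliers = [linear_warmup_multiplier(s, 10, 100) for s in steps]
learning_rates = [base_lr * m for m in multipliers]
```

With 10 warmup steps out of 100, the multiplier reaches 1.0 at step 10 and falls back to 0.0 at step 100.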
get_ort_optimizer_cls_and_kwargs
< source >( args: ORTTrainingArguments )
Returns the optimizer class and optimizer parameters implemented in ONNX Runtime, based on ORTTrainingArguments.
train
< source >( resume_from_checkpoint: typing.Union[bool, str, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), typing.Dict[str, typing.Any]] = None ignore_keys_for_eval: typing.Optional[typing.List[str]] = None **kwargs )
Parameters
- resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of ORTTrainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of ORTTrainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
- trial (optuna.Trial or Dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.
- ignore_keys_for_eval (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.
- kwargs (Dict[str, Any], optional) — Additional keyword arguments used to hide deprecated arguments.
Main entry point for training with the ONNX Runtime accelerator.
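To make the resume_from_checkpoint=True case concrete, here is a minimal sketch of locating the last checkpoint in an output directory. It is not the trainer's own lookup: find_last_checkpoint is a hypothetical helper, and it assumes checkpoints are saved in folders named checkpoint-<step>, the naming convention used by the transformers Trainer.

```python
import os
import re
import tempfile
from typing import Optional

_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint folder with the highest step number, if any."""
    candidates = []
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_RE.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    # Compare step numbers as integers so that 1000 sorts after 500.
    return os.path.join(output_dir, max(candidates)[1])


with tempfile.TemporaryDirectory() as output_dir:
    for step in (90, 500, 1000):
        os.mkdir(os.path.join(output_dir, f"checkpoint-{step}"))
    last = os.path.basename(find_last_checkpoint(output_dir))
```

Sorting by the integer step rather than by folder name matters here: a plain string sort would rank checkpoint-90 above checkpoint-1000.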
ORTSeq2SeqTrainer
class optimum.onnxruntime.ORTSeq2SeqTrainer
< source >( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None args: ORTTrainingArguments = None data_collator: typing.Optional[DataCollator] = None train_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None eval_dataset: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None model_init: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict]] = None callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None )
evaluate
< source >( eval_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'eval' **gen_kwargs )
Parameters
- eval_dataset (Dataset, optional) — Pass a dataset if you wish to override self.eval_dataset. If it is a Dataset, columns not accepted by the model.forward() method are automatically removed. It must implement the __len__ method.
- ignore_keys (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions.
- metric_key_prefix (str, optional, defaults to "eval") — An optional prefix to be used as the metrics key prefix. For example the metrics “bleu” will be named “eval_bleu” if the prefix is "eval" (default).
- max_length (int, optional) — The maximum target length to use when predicting with the generate method.
- num_beams (int, optional) — Number of beams for beam search that will be used when predicting with the generate method. 1 means no beam search.
- gen_kwargs — Additional generate specific kwargs.
Runs evaluation and returns metrics.
The calling script will be responsible for providing a method to compute metrics, as they are task-dependent (pass it to the init compute_metrics argument).
You can also subclass and override this method to inject custom behavior.
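The effect of metric_key_prefix on the returned metric names can be sketched as a simple key rename. This is an illustration of the naming rule described above, not the library's code: add_metric_prefix is a hypothetical helper, and leaving already-prefixed keys untouched is an assumption about the intended behavior.

```python
from typing import Dict


def add_metric_prefix(metrics: Dict[str, float], prefix: str = "eval") -> Dict[str, float]:
    """Return a copy of `metrics` whose keys all start with `<prefix>_`."""
    renamed = {}
    for key, value in metrics.items():
        # Keys that already carry the prefix are left untouched.
        new_key = key if key.startswith(f"{prefix}_") else f"{prefix}_{key}"
        renamed[new_key] = value
    return renamed


raw = {"bleu": 27.3, "eval_loss": 1.42}
eval_metrics = add_metric_prefix(raw)  # default prefix "eval"
test_metrics = add_metric_prefix({"bleu": 27.3}, prefix="test")
```

Using a distinct prefix per call keeps metrics from evaluation and prediction runs from overwriting each other when they are logged together.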
predict
< source >( test_dataset: Dataset ignore_keys: typing.Optional[typing.List[str]] = None metric_key_prefix: str = 'test' **gen_kwargs )
Parameters
- test_dataset (Dataset) — Dataset to run the predictions on. If it is a Dataset, columns not accepted by the model.forward() method are automatically removed. It must implement the __len__ method.
- ignore_keys (List[str], optional) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions.
- metric_key_prefix (str, optional, defaults to "test") — An optional prefix to be used as the metrics key prefix. For example the metrics “bleu” will be named “test_bleu” if the prefix is "test" (default).
- max_length (int, optional) — The maximum target length to use when predicting with the generate method.
- num_beams (int, optional) — Number of beams for beam search that will be used when predicting with the generate method. 1 means no beam search.
- gen_kwargs — Additional generate specific kwargs.
Runs prediction and returns predictions and potential metrics.
Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method will also return metrics, like in evaluate().
If your predictions or labels have different sequence lengths (for instance because you’re doing dynamic padding in a token classification task) the predictions will be padded (on the right) to allow for concatenation into one array. The padding index is -100.
Returns: NamedTuple. A namedtuple with the following keys:
- predictions (np.ndarray): The predictions on test_dataset.
- label_ids (np.ndarray, optional): The labels (if the dataset contained some).
- metrics (Dict[str, float], optional): The potential dictionary of metrics (if the dataset contained labels).
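The right-padding with index -100 mentioned for predictions of different sequence lengths can be sketched as follows. This is a plain-Python illustration with nested lists standing in for arrays: pad_and_concat is a hypothetical helper, not the function the trainer uses.

```python
from typing import List

PAD_INDEX = -100  # padding index stated in the description of predict


def pad_and_concat(batches: List[List[List[int]]], pad: int = PAD_INDEX) -> List[List[int]]:
    """Right-pad every sequence to the longest length, then stack all batches."""
    max_len = max(len(seq) for batch in batches for seq in batch)
    return [seq + [pad] * (max_len - len(seq)) for batch in batches for seq in batch]


batch_a = [[5, 6, 7], [8, 9, 10]]  # two sequences of length 3
batch_b = [[1, 2, 3, 4, 5]]        # one sequence of length 5
merged = pad_and_concat([batch_a, batch_b])
```

After padding, every row has the same length, so the batches can be stored as one rectangular array. Code that consumes the predictions should skip positions equal to -100.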
ORTTrainingArguments
class optimum.onnxruntime.ORTTrainingArguments
< source >( output_dir: str overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False evaluation_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Optional[typing.Dict] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: typing.Optional[str] = 'passive' log_level_replica: typing.Optional[str] = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 
ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, typing.List[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[typing.List[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[typing.List[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Optional[str] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None deepspeed: typing.Optional[str] = None label_smoothing_factor: float = 0.0 optim: typing.Optional[str] = 'adamw_hf' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Optional[typing.List[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: bool = False hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Optional[dict] = None include_inputs_for_metrics: bool = False 
fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: typing.Optional[int] = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None dispatch_batches: typing.Optional[bool] = None split_batches: typing.Optional[bool] = False include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: float = None use_module_with_loss: typing.Optional[bool] = False save_onnx: typing.Optional[bool] = False onnx_prefix: typing.Optional[str] = None onnx_log_level: typing.Optional[str] = 'WARNING' )
Parameters
- optim (str or training_args.ORTOptimizerNames or transformers.training_args.OptimizerNames, optional, defaults to "adamw_hf") — The optimizer to use, including optimizers in Transformers (adamw_hf, adamw_torch, adamw_apex_fused, or adafactor) and optimizers implemented by ONNX Runtime (adamw_ort_fused).
ORTSeq2SeqTrainingArguments
class optimum.onnxruntime.ORTSeq2SeqTrainingArguments
< source >( output_dir: str overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False evaluation_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Optional[typing.Dict] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: typing.Optional[str] = 'passive' log_level_replica: typing.Optional[str] = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 
ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, typing.List[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[typing.List[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[typing.List[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Optional[str] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None deepspeed: typing.Optional[str] = None label_smoothing_factor: float = 0.0 optim: typing.Optional[str] = 'adamw_hf' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Optional[typing.List[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: bool = False hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Optional[dict] = None include_inputs_for_metrics: bool = False 
fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: typing.Optional[int] = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None dispatch_batches: typing.Optional[bool] = None split_batches: typing.Optional[bool] = False include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: float = None use_module_with_loss: typing.Optional[bool] = False save_onnx: typing.Optional[bool] = False onnx_prefix: typing.Optional[str] = None onnx_log_level: typing.Optional[str] = 'WARNING' sortish_sampler: bool = False predict_with_generate: bool = False generation_max_length: typing.Optional[int] = None generation_num_beams: typing.Optional[int] = None generation_config: typing.Union[str, pathlib.Path, transformers.generation.configuration_utils.GenerationConfig, NoneType] = None )
Parameters
- optim (str or training_args.ORTOptimizerNames or transformers.training_args.OptimizerNames, optional, defaults to "adamw_hf") — The optimizer to use, including optimizers in Transformers (adamw_hf, adamw_torch, adamw_apex_fused, or adafactor) and optimizers implemented by ONNX Runtime (adamw_ort_fused).