from typing import Any, Dict, Union

import torch
from torch import nn
from transformers import Trainer as HFTrainer
from transformers.file_utils import is_apex_available

if is_apex_available():
    from apex import amp

from utils import label_smoothed_nll_loss
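# Note: a minimal sketch of what utils.label_smoothed_nll_loss is assumed to compute,
# following the fairseq-style implementation used by the transformers seq2seq examples
# (the actual utils module in this repo may differ slightly):
#
#     nll_loss    = -lprobs.gather(dim=-1, index=labels.unsqueeze(-1))   # true-label term
#     smooth_loss = -lprobs.sum(dim=-1, keepdim=True)                    # uniform-smoothing term
#     loss = (1 - epsilon) * nll_loss.sum() + (epsilon / vocab_size) * smooth_loss.sum()
#
# with positions equal to ignore_index masked out of both terms before summing.
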
class Trainer(HFTrainer):
    def __init__(self, label_smoothing: float = 0, **kwargs):
        super().__init__(**kwargs)
        self.label_smoothing = label_smoothing
    # Overridden to support label smoothing (matches the private _training_step API of older transformers Trainers).
    def _training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
    ) -> float:
        model.train()
        # Move all tensor inputs to the training device.
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        # Our model outputs do not work with DataParallel, so forcing return tuple.
        if isinstance(model, nn.DataParallel):
            inputs["return_tuple"] = True
        if self.label_smoothing == 0:
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
        else:
            # Compute the smoothed loss ourselves: pop the labels so the model does not
            # compute its own loss, replace the -100 padding sentinel with the real pad id,
            # and apply label-smoothed NLL over the log-softmaxed logits.
            labels = inputs.pop("labels")
            labels[labels == -100] = model.config.pad_token_id
            outputs = model(**inputs)
            lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, labels, self.label_smoothing, ignore_index=model.config.pad_token_id
            )
        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16:
            # apex AMP: scale the loss before backward to avoid fp16 gradient underflow.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item()
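
# Usage sketch: how this Trainer might be instantiated for seq2seq fine-tuning with
# label smoothing enabled. The checkpoint name, TrainingArguments values, and
# `train_dataset` below are illustrative assumptions, not part of this module.
#
# from transformers import BartForConditionalGeneration, TrainingArguments
#
# model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
# training_args = TrainingArguments(
#     output_dir="outputs",
#     per_device_train_batch_size=4,
#     gradient_accumulation_steps=2,
#     fp16=False,
# )
# trainer = Trainer(
#     label_smoothing=0.1,          # > 0 switches on the smoothed-loss branch above
#     model=model,
#     args=training_args,
#     train_dataset=train_dataset,  # assumed to be a tokenized torch Dataset defined elsewhere
# )
# trainer.train()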