from typing import Callable, Dict, Iterable, List

from torch import nn

# These functions are taken from the transformers repo.


def grad_status(model: nn.Module) -> Iterable[bool]:
    """Lazily yield ``requires_grad`` for each parameter of *model*, in ``model.parameters()`` order."""
    return (par.requires_grad for par in model.parameters())


def freeze_params(model: nn.Module) -> None:
    """Set ``requires_grad = False`` on every parameter of *model* (in place)."""
    for par in model.parameters():
        par.requires_grad = False


def freeze_embeds(model: nn.Module) -> None:
    """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.

    First tries the BART-style layout ``model.model.shared`` and
    ``model.model.{encoder,decoder}.{embed_positions,embed_tokens}``. If any of
    those lookups raises ``AttributeError``, falls back to the T5-style layout
    ``model.shared`` and ``model.{encoder,decoder}.embed_tokens``.
    """
    try:
        freeze_params(model.model.shared)
        for d in [model.model.encoder, model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)
    except AttributeError:
        # NOTE(review): anything frozen before the failing lookup stays frozen;
        # if the T5 layout is also missing, this raises AttributeError.
        freeze_params(model.shared)
        for d in [model.encoder, model.decoder]:
            freeze_params(d.embed_tokens)


def assert_not_all_frozen(model: nn.Module) -> None:
    """Raise ``AssertionError`` if no parameter of *model* requires grad.

    A model with zero parameters also fails the check.
    """
    model_grads: List[bool] = list(grad_status(model))
    npars = len(model_grads)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not any(model_grads):
        raise AssertionError(f"none of {npars} weights require grad")


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """Label-smoothed NLL loss averaged over counted tokens (from fairseq).

    Args:
        lprobs: log-probabilities whose last dimension indexes the vocabulary
            (presumably log-softmax output — TODO confirm at call sites).
        target: class ids; either one dim fewer than ``lprobs`` or already
            carrying a trailing singleton dim.
        epsilon: smoothing weight in
            ``(1 - epsilon) * nll + (epsilon / vocab_size) * smooth``.
        ignore_index: target value whose positions are excluded from both the
            loss and the normalizer. ``None`` counts every position.

    Returns:
        ``(loss, nll_loss)``: scalar tensors, each summed over the counted
        positions and divided by the number of counted positions (at least 1,
        so an all-ignored batch yields 0 rather than nan).
    """
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        # ``gather`` rejects out-of-range ids such as the default -100, so send
        # ignored positions to a valid id; their values are zeroed just below.
        nll_loss = -lprobs.gather(dim=-1, index=target.masked_fill(pad_mask, 0))
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
        nll_loss = nll_loss.masked_fill(pad_mask, 0.0)
        smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0)
        # Normalize by the number of NON-ignored tokens (not the padding count,
        # which is 0 for unpadded batches and produced inf/nan).
        num_tokens = (~pad_mask).sum().clamp(min=1)
    else:
        nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)
        smooth_loss = -lprobs.sum(dim=-1)
        # Every position counts; equals lprobs.shape[0] for 2-D lprobs.
        num_tokens = max(target.numel(), 1)

    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss / num_tokens, nll_loss / num_tokens