from typing import Dict
from typing import Tuple

import torch
import torch.nn.functional as F
from typeguard import check_argument_types

from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet2.lm.abs_model import AbsLM
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel


class ESPnetLanguageModel(AbsESPnetModel):
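    """Wrap an AbsLM so it can be trained as an ESPnet model.

    A minimal usage sketch; ``SomeLM`` is an illustrative placeholder for
    any concrete ``AbsLM`` implementation, not an actual ESPnet class::

        model = ESPnetLanguageModel(lm=SomeLM(...), vocab_size=5000)
        loss, stats, weight = model(text, text_lengths)
    """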

    def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
        assert check_argument_types()
        super().__init__()
        self.lm = lm
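        # <sos> and <eos> share the last index of the vocabulary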
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
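        # Targets beyond each sequence length are filled with ignore_id
        # and masked out of the loss in nll()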
        self.ignore_id = ignore_id
    def nll(
        self, text: torch.Tensor, text_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
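        """Compute the per-token negative log likelihood (unreduced).

        Args:
            text: (Batch, Length)
            text_lengths: (Batch,)
        Returns:
            nll: (Batch, Length + 1), zero at padded positions
            x_lengths: (Batch,), input lengths including the prepended <sos>
        """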
        batch_size = text.size(0)
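        # For data parallel: trim padding to the longest length in the batch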
        text = text[:, : text_lengths.max()]

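        # 1. Create the input/target pair: '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
        #    (<sos> and <eos> share one index, so padding with eos prepends <sos>)
        # text: (Batch, Length) -> x, t: (Batch, Length + 1)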
        x = F.pad(text, [1, 0], "constant", self.eos)
        t = F.pad(text, [0, 1], "constant", self.ignore_id)
        for i, l in enumerate(text_lengths):
            t[i, l] = self.sos
        x_lengths = text_lengths + 1

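        # 2. Forward the language model
        # x: (Batch, Length) -> y: (Batch, Length, NVocab)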
        y, _ = self.lm(x, None)

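        # 3. Calculate the negative log likelihood per token
        # nll: (Batch * Length,)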
        nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
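        # Zero out the losses at padded positions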
        nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
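        # nll: (Batch * Length,) -> (Batch, Length)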
        nll = nll.view(batch_size, -1)
        return nll, x_lengths

    def forward(
        self, text: torch.Tensor, text_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
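        """Compute the token-averaged loss, training stats, and batch weight."""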
        nll, y_lengths = self.nll(text, text_lengths)
        ntokens = y_lengths.sum()
        loss = nll.sum() / ntokens
        stats = dict(loss=loss.detach())

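        # force_gatherable: move to device and wrap scalars as tensors
        # so that DataParallel can gather the outputs across replicas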
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight

    def collect_feats(
        self, text: torch.Tensor, text_lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
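        # A language model consumes raw token ids, so there are no
        # input features whose statistics need to be collected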
        return {}