import logging

import torch
import torch.nn.functional as F
from typeguard import check_argument_types

class CTC(torch.nn.Module):
    """CTC module.

    Args:
        odim: dimension of outputs
        encoder_output_size: number of encoder projection units
        dropout_rate: dropout rate (0.0 ~ 1.0)
        ctc_type: builtin or warpctc
        reduce: reduce the CTC loss into a scalar
        ignore_nan_grad: drop samples whose CTC gradient is NaN
            (only supported for the builtin CTC)
    """

    def __init__(
        self,
        odim: int,
        encoder_output_size: int,
        dropout_rate: float = 0.0,
        ctc_type: str = "builtin",
        reduce: bool = True,
        ignore_nan_grad: bool = False,
    ):
        assert check_argument_types()
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.ctc_type = ctc_type
        self.ignore_nan_grad = ignore_nan_grad

        if self.ctc_type == "builtin":
            self.ctc_loss = torch.nn.CTCLoss(reduction="none")
        elif self.ctc_type == "warpctc":
            import warpctc_pytorch as warp_ctc

            if ignore_nan_grad:
                raise NotImplementedError(
                    "ignore_nan_grad option is not supported for warp_ctc"
                )
            self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
        else:
            raise ValueError(
                f'ctc_type must be "builtin" or "warpctc": {self.ctc_type}'
            )

        self.reduce = reduce

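    # Illustrative sketch (not part of the original module): constructing the
    # CTC head for a hypothetical encoder with 256-dim projections and a
    # 5000-token vocabulary. The keyword values below are example assumptions.
    #
    #     ctc = CTC(
    #         odim=5000,
    #         encoder_output_size=256,
    #         dropout_rate=0.1,
    #         ctc_type="builtin",
    #         reduce=True,
    #     )
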
    def loss_fn(self, th_pred, th_target, th_ilen, th_olen) -> torch.Tensor:
        if self.ctc_type == "builtin":
            th_pred = th_pred.log_softmax(2)
            loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)

            if loss.requires_grad and self.ignore_nan_grad:
                # ctc_grad: (Tmax, B, odim)
                ctc_grad = loss.grad_fn(torch.ones_like(loss))
                ctc_grad = ctc_grad.sum([0, 2])
                indices = torch.isfinite(ctc_grad)
                size = indices.long().sum()
                if size == 0:
                    # Every sample produced a NaN gradient: keep the loss as is
                    logging.warning(
                        "All samples in this mini-batch got nan grad."
                        " Returning nan value instead of CTC loss"
                    )
                elif size != th_pred.size(1):
                    logging.warning(
                        f"{th_pred.size(1) - size}/{th_pred.size(1)}"
                        " samples got nan grad."
                        " These were ignored for CTC loss."
                    )

                    # Build a mask that keeps only the targets of the samples
                    # whose gradient is finite
                    target_mask = torch.full(
                        [th_target.size(0)],
                        1,
                        dtype=torch.bool,
                        device=th_target.device,
                    )
                    s = 0
                    for ind, le in enumerate(th_olen):
                        if not indices[ind]:
                            target_mask[s : s + le] = 0
                        s += le

                    # Recompute the loss using only the kept samples
                    loss = self.ctc_loss(
                        th_pred[:, indices, :],
                        th_target[target_mask],
                        th_ilen[indices],
                        th_olen[indices],
                    )
            else:
                size = th_pred.size(1)

            if self.reduce:
                # Batch-size average
                loss = loss.sum() / size
            else:
                loss = loss / size
            return loss

        elif self.ctc_type == "warpctc":
            # warp-ctc only supports float32
            th_pred = th_pred.to(dtype=torch.float32)

            th_target = th_target.cpu().int()
            th_ilen = th_ilen.cpu().int()
            th_olen = th_olen.cpu().int()
            loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
            if self.reduce:
                # warp-ctc returns a tensor of shape (1,); sum() squeezes it
                # to a scalar for consistency with the builtin branch
                loss = loss.sum()
            return loss
        else:
            raise NotImplementedError

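    # Worked illustration of the ignore_nan_grad masking above (hypothetical
    # numbers, not part of the original module). With a batch of 3 utterances
    # and th_olen = [2, 3, 2], th_target is the 7-element concatenation of the
    # three label sequences. If only utterance 1 yields a non-finite gradient,
    # then indices = [True, False, True] and the loop builds
    # target_mask = [1, 1, 0, 0, 0, 1, 1], so the loss is recomputed from
    # utterances 0 and 2 only, and the batch average is taken over size = 2.
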
    def forward(self, hs_pad, hlens, ys_pad, ys_lens):
        """Calculate CTC loss.

        Args:
            hs_pad: batch of padded hidden state sequences (B, Tmax, D)
            hlens: batch of lengths of hidden state sequences (B)
            ys_pad: batch of padded character id sequence tensor (B, Lmax)
            ys_lens: batch of lengths of character sequences (B)
        """
        # Apply dropout and project the hidden states to the output dimension
        ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
        # ys_hat: (B, Tmax, odim) -> (Tmax, B, odim)
        ys_hat = ys_hat.transpose(0, 1)

        # Flatten the unpadded target sequences into a single 1-D tensor
        ys_true = torch.cat([ys_pad[i, :l] for i, l in enumerate(ys_lens)])

        loss = self.loss_fn(ys_hat, ys_true, hlens, ys_lens).to(
            device=hs_pad.device, dtype=hs_pad.dtype
        )

        return loss

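    # Worked illustration of the target flattening in forward() (hypothetical
    # numbers, not part of the original module). With
    #
    #     ys_pad  = [[3, 5, 0],
    #                [7, 0, 0]]   # 0 used as padding here
    #     ys_lens = [2, 1]
    #
    # the list comprehension yields [tensor([3, 5]), tensor([7])] and
    # ys_true = tensor([3, 5, 7]): the unpadded labels of the whole batch
    # packed into one 1-D tensor, which is the layout loss_fn() expects.
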
    def log_softmax(self, hs_pad):
        """log_softmax of frame activations.

        Args:
            hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)

    def argmax(self, hs_pad):
        """argmax of frame activations.

        Args:
            hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: argmax applied 2d tensor (B, Tmax)
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
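

# Minimal usage sketch (illustrative only; the shapes and sizes below are
# example assumptions, not values taken from the original code). It builds a
# CTC head, computes the loss on random encoder outputs and random targets,
# and runs framewise greedy prediction with argmax().
if __name__ == "__main__":
    batch, tmax, eprojs, vocab = 2, 50, 256, 100

    ctc = CTC(odim=vocab, encoder_output_size=eprojs)

    # Random encoder outputs (B, Tmax, D) and their valid lengths (B,)
    hs_pad = torch.randn(batch, tmax, eprojs)
    hlens = torch.tensor([50, 42], dtype=torch.long)

    # Random padded targets (B, Lmax), drawn from 1..vocab-1 so they never
    # collide with the default blank index 0 of torch.nn.CTCLoss
    ys_pad = torch.randint(1, vocab, (batch, 12))
    ys_lens = torch.tensor([12, 9], dtype=torch.long)

    loss = ctc(hs_pad, hlens, ys_pad, ys_lens)
    hyps = ctc.argmax(hs_pad)  # framewise token ids (B, Tmax)
    print(loss.item(), hyps.shape)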