File size: 5,757 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import logging
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
class CTC(torch.nn.Module):
    """CTC module.

    Projects encoder outputs to ``odim`` classes with a linear layer and
    computes the CTC loss with either ``torch.nn.CTCLoss`` ("builtin") or
    ``warpctc_pytorch`` ("warpctc").

    Args:
        odim: dimension of outputs
        encoder_output_sizse: number of encoder projection units
            (the spelling is part of the keyword interface, so it is kept)
        dropout_rate: dropout rate (0.0 ~ 1.0)
        ctc_type: builtin or warpctc
        reduce: reduce the CTC loss into a scalar
        ignore_nan_grad: for the builtin loss only; samples whose CTC
            gradient is not finite are dropped and the loss is recomputed
            on the remaining samples

    Raises:
        NotImplementedError: if ``ignore_nan_grad`` is set with "warpctc".
        ValueError: if ``ctc_type`` is neither "builtin" nor "warpctc".
    """

    def __init__(
        self,
        odim: int,
        encoder_output_sizse: int,
        dropout_rate: float = 0.0,
        ctc_type: str = "builtin",
        reduce: bool = True,
        ignore_nan_grad: bool = False,
    ):
        assert check_argument_types()
        super().__init__()
        eprojs = encoder_output_sizse
        self.dropout_rate = dropout_rate
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.ctc_type = ctc_type
        self.ignore_nan_grad = ignore_nan_grad

        if self.ctc_type == "builtin":
            # Per-sample losses are kept so that loss_fn can do its own
            # batch-size normalisation and nan-gradient filtering.
            self.ctc_loss = torch.nn.CTCLoss(reduction="none")
        elif self.ctc_type == "warpctc":
            # Imported lazily so the dependency is only needed when used.
            import warpctc_pytorch as warp_ctc

            if ignore_nan_grad:
                raise NotImplementedError(
                    "ignore_nan_grad option is not supported for warp_ctc"
                )
            self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
        else:
            raise ValueError(
                f'ctc_type must be "builtin" or "warpctc": {self.ctc_type}'
            )

        self.reduce = reduce

    def loss_fn(self, th_pred, th_target, th_ilen, th_olen) -> torch.Tensor:
        """Compute the CTC loss with the configured backend.

        Args:
            th_pred: unnormalised frame activations (L, B, O)
            th_target: target ids of all samples concatenated into 1-D
            th_ilen: input lengths (B)
            th_olen: target lengths (B)

        Returns:
            With ``reduce=True`` the summed loss divided by the number of
            samples used (builtin) or the summed warpctc output. With
            ``reduce=False`` and the builtin backend, the per-sample
            losses divided by the number of samples used.
        """
        if self.ctc_type == "builtin":
            th_pred = th_pred.log_softmax(2)
            loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)

            if loss.requires_grad and self.ignore_nan_grad:
                # Call the loss' backward node directly to obtain the
                # gradient w.r.t. the log-probabilities.
                # ctc_grad: (L, B, O)
                ctc_grad = loss.grad_fn(torch.ones_like(loss))
                # One value per sample; it is non-finite as soon as any
                # entry of that sample's gradient is non-finite.
                ctc_grad = ctc_grad.sum([0, 2])
                indices = torch.isfinite(ctc_grad)
                size = indices.long().sum()
                if size == 0:
                    # Return as is
                    logging.warning(
                        "All samples in this mini-batch got nan grad."
                        " Returning nan value instead of CTC loss"
                    )
                elif size != th_pred.size(1):
                    logging.warning(
                        f"{th_pred.size(1) - size}/{th_pred.size(1)}"
                        " samples got nan grad."
                        " These were ignored for CTC loss."
                    )

                    # Create mask for target: th_target is the flat
                    # concatenation of all targets, so walk it by the
                    # per-sample lengths and clear rejected segments.
                    target_mask = torch.full(
                        [th_target.size(0)],
                        1,
                        dtype=torch.bool,
                        device=th_target.device,
                    )
                    s = 0
                    for ind, le in enumerate(th_olen):
                        if not indices[ind]:
                            target_mask[s : s + le] = 0
                        s += le

                    # Calc loss again using masked data
                    loss = self.ctc_loss(
                        th_pred[:, indices, :],
                        th_target[target_mask],
                        th_ilen[indices],
                        th_olen[indices],
                    )
            else:
                size = th_pred.size(1)

            if self.reduce:
                # Batch-size average
                loss = loss.sum() / size
            else:
                loss = loss / size
            return loss

        elif self.ctc_type == "warpctc":
            # warpctc only supports float32
            th_pred = th_pred.to(dtype=torch.float32)

            # Targets and lengths are handed over as CPU int tensors.
            th_target = th_target.cpu().int()
            th_ilen = th_ilen.cpu().int()
            th_olen = th_olen.cpu().int()
            loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
            if self.reduce:
                # NOTE: sum() is needed to keep consistency since warpctc
                # return as tensor w/ shape (1,)
                # but builtin return as tensor w/o shape (scalar).
                loss = loss.sum()
            return loss

        else:
            raise NotImplementedError

    def forward(self, hs_pad, hlens, ys_pad, ys_lens):
        """Calculate CTC loss.

        Args:
            hs_pad: batch of padded hidden state sequences (B, Tmax, D)
            hlens: batch of lengths of hidden state sequences (B)
            ys_pad: batch of padded character id sequence tensor (B, Lmax)
            ys_lens: batch of lengths of character sequence (B)

        Returns:
            The CTC loss moved to the device and dtype of ``hs_pad``.
        """
        # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
        # F.dropout defaults to training=True, so the module's mode has to
        # be passed explicitly; otherwise dropout stays active in eval().
        ys_hat = self.ctc_lo(
            F.dropout(hs_pad, p=self.dropout_rate, training=self.training)
        )
        # ys_hat: (B, L, D) -> (L, B, D)
        ys_hat = ys_hat.transpose(0, 1)

        # (B, L) -> (BxL,): strip the padding and concatenate the targets
        ys_true = torch.cat(
            [ys_pad[i, :length] for i, length in enumerate(ys_lens)]
        )

        loss = self.loss_fn(ys_hat, ys_true, hlens, ys_lens).to(
            device=hs_pad.device, dtype=hs_pad.dtype
        )

        return loss

    def log_softmax(self, hs_pad):
        """log_softmax of frame activations

        Args:
            Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)

    def argmax(self, hs_pad):
        """argmax of frame activations

        Args:
            torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: argmax applied 2d tensor (B, Tmax)
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
|