|
from typing import List, Optional, Tuple |
|
|
|
import torch |
|
from torch import nn |
|
from modules.wenet_extractor.utils.common import get_activation, get_rnn |
|
|
|
|
|
def ApplyPadding(input, padding, pad_value) -> torch.Tensor:
    """Blend ``input`` with ``pad_value`` according to a padding mask.

    Positions where ``padding`` is 1 take ``pad_value``; positions where it
    is 0 keep ``input``.

    Args:
        input: [bs, max_time_step, dim]
        padding: [bs, max_time_step] mask, 1 marks a padded position
        pad_value: value (broadcastable tensor) substituted at padded spots

    Returns:
        Tensor with the same shape as ``input``.
    """
    keep = 1 - padding
    return input * keep + pad_value * padding
|
|
|
|
|
class PredictorBase(torch.nn.Module):
    """Abstract base class for transducer predictor networks.

    Defines the interface shared by the concrete predictors in this file:
    state creation (``init_state``), cache layout conversion
    (``batch_to_cache`` / ``cache_to_batch``), training-time ``forward``
    and single-step inference (``forward_step``). Every method here raises
    ``NotImplementedError`` and must be overridden by subclasses.
    """

    def __init__(self) -> None:
        super().__init__()

    def init_state(
        self, batch_size: int, device: torch.device, method: str = "zero"
    ) -> List[torch.Tensor]:
        """Create the initial predictor state for a batch of utterances."""
        _, _, _ = batch_size, method, device
        # Fixed typo: "precictor" -> "predictor".
        raise NotImplementedError("this is a base predictor")

    def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        """Split a batched cache into a list of per-utterance caches."""
        _ = cache
        raise NotImplementedError("this is a base predictor")

    def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """Merge per-utterance caches back into one batched cache."""
        _ = cache
        raise NotImplementedError("this is a base predictor")

    def forward(
        self,
        input: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ):
        """Training-time forward over a whole label sequence."""
        _, _ = input, cache
        raise NotImplementedError("this is a base predictor")

    def forward_step(
        self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Single decoding step for incremental inference."""
        _, _, _ = input, padding, cache
        raise NotImplementedError("this is a base predictor")
|
|
|
|
|
class RNNPredictor(PredictorBase):
    """RNN transducer predictor: embed -> dropout -> rnn -> projection.

    NOTE(review): state handling assumes an LSTM-style (h, c) state pair;
    a GRU from ``get_rnn`` returns a single state tensor, so
    ``forward_step`` and ``init_state`` would need adjusting — confirm
    ``rnn_type`` is always "lstm" in practice.
    """

    def __init__(
        self,
        voca_size: int,
        embed_size: int,
        output_size: int,
        embed_dropout: float,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        rnn_type: str = "lstm",
        dropout: float = 0.1,
    ) -> None:
        """
        Args:
            voca_size: vocabulary size (number of embedding rows).
            embed_size: token embedding dimension.
            output_size: dimension of the projected output.
            embed_dropout: dropout rate applied after the embedding.
            hidden_size: RNN hidden dimension.
            num_layers: number of stacked RNN layers.
            bias: whether the RNN uses bias weights.
            rnn_type: RNN flavor accepted by get_rnn (e.g. "lstm").
            dropout: inter-layer RNN dropout rate.
        """
        super().__init__()
        self.n_layers = num_layers
        self.hidden_size = hidden_size

        self.embed = nn.Embedding(voca_size, embed_size)
        self.dropout = nn.Dropout(embed_dropout)
        # batch_first=True: all inputs/outputs are [batch, time, dim].
        self.rnn = get_rnn(rnn_type=rnn_type)(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            batch_first=True,
            dropout=dropout,
        )
        self.projection = nn.Linear(hidden_size, output_size)

    def forward(
        self,
        input: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Training-time forward over a whole label sequence.

        Args:
            input (torch.Tensor): [batch, max_time] token ids.
            cache: rnn predictor cache[0] == state_m,
                   cache[1] == state_c; a zero state is created when None.
                   (Fixed docstring: this method takes no ``padding`` arg.)
        Returns:
            output: [batch, max_time, output_size]
        """
        embed = self.embed(input)
        embed = self.dropout(embed)
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        if cache is None:
            state = self.init_state(batch_size=input.size(0), device=input.device)
            states = (state[0], state[1])
        else:
            assert len(cache) == 2
            states = (cache[0], cache[1])
        # Final RNN states are not needed for the training forward
        # (previously destructured into (m, c) and discarded); incremental
        # decoding with state uses forward_step instead.
        out, _ = self.rnn(embed, states)
        out = self.projection(out)
        return out

    def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        """Split a batched (state_m, state_c) cache per utterance.

        Args:
            cache: [state_ms, state_cs]
                state_ms: [1*n_layers, bs, ...]
                state_cs: [1*n_layers, bs, ...]
        Returns:
            new_cache: [[state_m_1, state_c_1], [state_m_2, state_c_2], ...]
        """
        assert len(cache) == 2
        state_ms = cache[0]
        state_cs = cache[1]

        assert state_ms.size(1) == state_cs.size(1)

        new_cache: List[List[torch.Tensor]] = []
        for state_m, state_c in zip(
            torch.split(state_ms, 1, dim=1), torch.split(state_cs, 1, dim=1)
        ):
            new_cache.append([state_m, state_c])
        return new_cache

    def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """Merge per-utterance caches into one batched cache.

        Args:
            cache: [[state_m_1, state_c_1], [state_m_2, state_c_2], ...]
        Returns:
            new_cache: [state_ms, state_cs]
                state_ms: [1*n_layers, bs, ...]
                state_cs: [1*n_layers, bs, ...]
        """
        state_ms = torch.cat([states[0] for states in cache], dim=1)
        state_cs = torch.cat([states[1] for states in cache], dim=1)
        return [state_ms, state_cs]

    def init_state(
        self,
        batch_size: int,
        device: torch.device,
        method: str = "zero",
    ) -> List[torch.Tensor]:
        """Return zero-initialized [state_m, state_c] for a batch."""
        assert batch_size > 0
        _ = method  # only zero initialization is supported
        return [
            torch.zeros(1 * self.n_layers, batch_size, self.hidden_size, device=device),
            torch.zeros(1 * self.n_layers, batch_size, self.hidden_size, device=device),
        ]

    def forward_step(
        self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Single decoding step.

        Args:
            input (torch.Tensor): [batch_size, time_step=1]
            padding (torch.Tensor): [batch_size, 1], 1 is padding value
            cache: rnn predictor cache[0] == state_m,
                   cache[1] == state_c
        Returns:
            (output, [state_m, state_c]); padded utterances keep their
            previous states.
        """
        assert len(cache) == 2
        state_m, state_c = cache[0], cache[1]
        embed = self.embed(input)
        embed = self.dropout(embed)
        out, (m, c) = self.rnn(embed, (state_m, state_c))

        out = self.projection(out)
        # padding: [bs, 1] -> [1, bs, 1] so it broadcasts against the
        # [n_layers, bs, hidden] state tensors.
        m = ApplyPadding(m, padding.unsqueeze(0), state_m)
        c = ApplyPadding(c, padding.unsqueeze(0), state_c)

        return (out, [m, c])
|
|
|
|
|
class EmbeddingPredictor(PredictorBase):
    """Embedding predictor

    Described in:
        https://arxiv.org/pdf/2109.07513.pdf

    embed -> proj -> layer norm -> swish
    """

    def __init__(
        self,
        voca_size: int,
        embed_size: int,
        embed_dropout: float,
        n_head: int,
        history_size: int = 2,
        activation: str = "swish",
        bias: bool = False,
        layer_norm_epsilon: float = 1e-5,
    ) -> None:
        super().__init__()

        self.num_heads = n_head
        self.embed_size = embed_size
        # Context window = current token + history_size previous tokens.
        self.context_size = history_size + 1
        # Holds the learned per-head positional weights; its weight matrix
        # is reshaped in forward()/forward_step() rather than being applied
        # as a regular linear layer.
        self.pos_embed = torch.nn.Linear(
            embed_size * self.context_size, self.num_heads, bias=bias
        )
        self.embed = nn.Embedding(voca_size, self.embed_size)
        self.embed_dropout = nn.Dropout(p=embed_dropout)
        self.ffn = nn.Linear(self.embed_size, self.embed_size)
        self.norm = nn.LayerNorm(self.embed_size, eps=layer_norm_epsilon)
        # NOTE: attribute name keeps the original spelling ("activatoin");
        # renaming could break external references/checkpoints.
        self.activatoin = get_activation(activation)

    def init_state(
        self, batch_size: int, device: torch.device, method: str = "zero"
    ) -> List[torch.Tensor]:
        """Return the initial all-zero history of context_size - 1 frames."""
        assert batch_size > 0
        _ = method  # only zero initialization is implemented
        return [
            torch.zeros(
                batch_size, self.context_size - 1, self.embed_size, device=device
            ),
        ]

    def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        """Split a batched cache into per-utterance caches.

        Args:
            cache : [history]
                history: [bs, ...]
        Returns:
            new_cache : [[history_1], [history_2], [history_3]...]
        """
        assert len(cache) == 1
        cache_0 = cache[0]
        history: List[List[torch.Tensor]] = []
        for h in torch.split(cache_0, 1, dim=0):
            history.append([h])
        return history

    def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """Merge per-utterance caches into one batched cache.

        Args:
            cache : [[history_1], [history_2], [history_3]...]

        Returns:
            new_cache: [history],
                history: [bs, ...]
        """
        history = torch.cat([h[0] for h in cache], dim=0)
        return [history]

    def forward(self, input: torch.Tensor, cache: Optional[List[torch.Tensor]] = None):
        """forward for training

        Args:
            input (torch.Tensor): [batch, max_time] token ids
            cache: optional [history]; history is
                [batch, context_size - 1, embed_size]. A zero history is
                created when None.
        Returns:
            output of shape [batch, max_time, embed_size]
        """
        input = self.embed(input)  # [bs, t, emb]
        input = self.embed_dropout(input)
        if cache is None:
            zeros = self.init_state(input.size(0), device=input.device)[0]
        else:
            assert len(cache) == 1
            zeros = cache[0]

        # Prepend history so the first token also sees a full context.
        input = torch.cat(
            (zeros, input), dim=1
        )  # [bs, t + context_size - 1, emb]

        # Sliding windows of context_size tokens: [bs, t, context_size, emb]
        input = input.unfold(1, self.context_size, 1).permute(
            0, 1, 3, 2
        )

        # Reinterpret pos_embed's weight matrix as per-head positional
        # weights: [num_heads, embed_size, context_size]
        multi_head_pos = self.pos_embed.weight.view(
            self.num_heads, self.embed_size, self.context_size
        )

        input_expand = input.unsqueeze(2)  # [bs, t, 1, context_size, emb]
        multi_head_pos = multi_head_pos.permute(
            0, 2, 1
        )  # [num_heads, context_size, emb]

        # Scalar weight per (head, context position), via elementwise
        # product then sum over the embedding axis.
        weight = input_expand * multi_head_pos  # [bs, t, n_head, context, emb]
        weight = weight.sum(dim=-1, keepdim=False).unsqueeze(
            3
        )  # [bs, t, n_head, 1, context]
        # Weighted sum of the context embeddings for each head.
        output = weight.matmul(input_expand).squeeze(
            dim=3
        )  # [bs, t, n_head, emb]
        output = output.sum(dim=2)  # sum over heads -> [bs, t, emb]
        output = output / (self.num_heads * self.context_size)

        output = self.ffn(output)
        output = self.norm(output)
        output = self.activatoin(output)
        return output

    def forward_step(
        self,
        input: torch.Tensor,
        padding: torch.Tensor,
        cache: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """forward step for inference
        Args:
            input (torch.Tensor): [batch_size, time_step=1]
            padding (torch.Tensor): [batch_size,1], 1 is padding value
            cache: for embedding predictor, cache[0] == history
        """
        # NOTE(review): `padding` is accepted but never applied here,
        # unlike RNNPredictor.forward_step — confirm this is intended.
        assert input.size(1) == 1
        assert len(cache) == 1
        history = cache[0]
        assert history.size(1) == self.context_size - 1
        input = self.embed(input)  # [bs, 1, emb]
        input = self.embed_dropout(input)
        context_input = torch.cat((history, input), dim=1)  # [bs, context, emb]
        input_expand = context_input.unsqueeze(1).unsqueeze(
            2
        )  # [bs, 1, 1, context_size, emb]

        # Same per-head positional weighting as in forward(), for a single
        # time step.
        multi_head_pos = self.pos_embed.weight.view(
            self.num_heads, self.embed_size, self.context_size
        )

        multi_head_pos = multi_head_pos.permute(
            0, 2, 1
        )  # [num_heads, context_size, emb]

        weight = input_expand * multi_head_pos  # [bs, 1, n_head, context, emb]
        weight = weight.sum(dim=-1, keepdim=False).unsqueeze(
            3
        )  # [bs, 1, n_head, 1, context]
        output = weight.matmul(input_expand).squeeze(dim=3)  # [bs, 1, n_head, emb]
        output = output.sum(dim=2)
        output = output / (self.num_heads * self.context_size)

        output = self.ffn(output)
        output = self.norm(output)
        output = self.activatoin(output)
        # Slide the history window: drop the oldest frame, keep the rest.
        new_cache = context_input[:, 1:, :]
        return (output, [new_cache])
|
|
|
|
|
class ConvPredictor(PredictorBase):
    """Convolutional transducer predictor.

    embed -> dropout -> depthwise Conv1d over the last
    ``history_size + 1`` tokens -> layer norm -> activation
    """

    def __init__(
        self,
        voca_size: int,
        embed_size: int,
        embed_dropout: float,
        history_size: int = 2,
        activation: str = "relu",
        bias: bool = False,
        layer_norm_epsilon: float = 1e-5,
    ) -> None:
        super().__init__()

        assert history_size >= 0
        self.embed_size = embed_size
        # Context window = current token + history_size previous tokens.
        self.context_size = history_size + 1
        self.embed = nn.Embedding(voca_size, self.embed_size)
        self.embed_dropout = nn.Dropout(p=embed_dropout)
        # Depthwise convolution (groups == channels): each embedding
        # dimension is filtered independently over the context window.
        self.conv = nn.Conv1d(
            in_channels=embed_size,
            out_channels=embed_size,
            kernel_size=self.context_size,
            padding=0,
            groups=embed_size,
            bias=bias,
        )
        self.norm = nn.LayerNorm(embed_size, eps=layer_norm_epsilon)
        # NOTE: attribute name keeps the original spelling ("activatoin")
        # so external references and checkpoints stay compatible.
        self.activatoin = get_activation(activation)

    def init_state(
        self, batch_size: int, device: torch.device, method: str = "zero"
    ) -> List[torch.Tensor]:
        """Return the initial all-zero history of context_size - 1 frames."""
        assert batch_size > 0
        assert method == "zero"
        shape = (batch_size, self.context_size - 1, self.embed_size)
        return [torch.zeros(*shape, device=device)]

    def cache_to_batch(self, cache: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """Merge per-utterance caches into one batched cache.

        Args:
            cache : [[history_1], [history_2], [history_3]...]

        Returns:
            new_cache: [history],
                history: [bs, ...]
        """
        return [torch.cat([entry[0] for entry in cache], dim=0)]

    def batch_to_cache(self, cache: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        """Split a batched cache into per-utterance caches.

        Args:
            cache : [history]
                history: [bs, ...]
        Returns:
            new_cache : [[history_1], [history_2], [history_3]...]
        """
        assert len(cache) == 1
        batched = cache[0]
        return [[h] for h in torch.split(batched, 1, dim=0)]

    def forward(self, input: torch.Tensor, cache: Optional[List[torch.Tensor]] = None):
        """Forward for training.

        Args:
            input (torch.Tensor): [batch, max_time] token ids
            cache: optional [history]; history is
                [batch, context_size - 1, embed_size]. A zero history is
                created when None.
        Returns:
            output of shape [batch, max_time, embed_size]
        """
        embedded = self.embed_dropout(self.embed(input))
        if cache is None:
            history = self.init_state(embedded.size(0), device=embedded.device)[0]
        else:
            assert len(cache) == 1
            history = cache[0]

        # Prepend history so the first token also sees a full context;
        # Conv1d expects [batch, channels, time].
        padded = torch.cat((history, embedded), dim=1)
        conv_out = self.conv(padded.permute(0, 2, 1)).permute(0, 2, 1)
        return self.activatoin(self.norm(conv_out))

    def forward_step(
        self, input: torch.Tensor, padding: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Single decoding step for the conv predictor.

        Args:
            input (torch.Tensor): [batch_size, time_step=1]
            padding (torch.Tensor): [batch_size,1], 1 is padding value
            cache: for the conv predictor, cache[0] == history
        """
        # NOTE(review): `padding` is unused here, unlike
        # RNNPredictor.forward_step — confirm this is intended.
        assert input.size(1) == 1
        assert len(cache) == 1
        history = cache[0]
        assert history.size(1) == self.context_size - 1
        embedded = self.embed_dropout(self.embed(input))
        context_input = torch.cat((history, embedded), dim=1)
        conv_out = self.conv(context_input.permute(0, 2, 1)).permute(0, 2, 1)
        out = self.activatoin(self.norm(conv_out))

        # Slide the history window: drop the oldest frame, keep the rest.
        new_cache = context_input[:, 1:, :]

        return (out, [new_cache])
|
|