wasmdashai's picture
first commit
7694c84
raw
history blame
2.44 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import encode, decode
from typing import Union, List
class Shakkelha(nn.Module):
def __init__(self,
dim_input: int=91,
dim_output: int=19,
sd_path: str=None):
super().__init__()
self.emb0 = nn.Embedding(dim_input, 25)
self.lstm0 = nn.LSTM(25, 256, batch_first=True, bidirectional=True)
self.lstm1 = nn.LSTM(512, 256, batch_first=True, bidirectional=True)
self.dropout = nn.Dropout(p=0.5)
self.dense0 = nn.Linear(512, 512)
self.dense1 = nn.Linear(512, 512)
self.dense2 = nn.Linear(512, dim_output)
self.eval()
if sd_path is not None:
self.load_state_dict(torch.load(sd_path))
def forward(self, x: torch.Tensor):
x = self.emb0(x)
x, _ = self.lstm0(x)
x = self.dropout(x)
x, _ = self.lstm1(x)
x = self.dropout(x)
x = F.relu(self.dense0(x))
x = F.relu(self.dense1(x))
x = F.softmax(self.dense2(x), dim=-1)
return x
@torch.inference_mode()
def infer(self, x: torch.Tensor):
return self.forward(x)
def _predict_list(self, input_list: List[str], return_probs: bool=False):
output_list = []
probs_list = []
for input_text in input_list:
if return_probs:
output_text, probs = self._predict_single(input_text, return_probs=True)
output_list.append(output_text)
probs_list.append(probs)
else:
output_list.append(self._predict_single(input_text))
if return_probs:
return output_list, probs_list
return output_list
def _predict_single(self, input_text: str, return_probs: bool=False):
ids = encode(input_text)
input = torch.LongTensor(ids)[None].to(self.emb0.weight.device)
probs = self.infer(input).cpu()
output = decode(probs, input_text)
if return_probs:
return output, probs
return output
def predict(self, input: Union[str, List[str]], return_probs: bool=False):
if isinstance(input, str):
return self._predict_single(input, return_probs=return_probs)
return self._predict_list(input, return_probs=return_probs)