import torch import torch.nn as nn from .lstm_hsm import LSTMHardSigmoid from . import encode, decode from typing import Union, List class Shakkala(nn.Module): def __init__(self, dim_input: int=149, dim_output: int=28, sd_path: str=None): super().__init__() self.emb_input = nn.Embedding(dim_input, 288) self.lstm0 = LSTMHardSigmoid(288, hidden_size=288, bidirectional=True, batch_first=True) self.bn0 = nn.BatchNorm1d(576, momentum=0.01, eps=0.001) self.lstm1 = LSTMHardSigmoid(576, hidden_size=144, bidirectional=True, batch_first=True) self.lstm2 = LSTMHardSigmoid(288, hidden_size=96, bidirectional=True, batch_first=True) self.dense0 = nn.Linear(192, dim_output) self.eval() self.max_sentence = None if sd_path is not None: self.load_state_dict(torch.load(sd_path)) def forward(self, x: torch.Tensor): x = self.emb_input(x) x, _ = self.lstm0(x) x = self.bn0(x.transpose(1,2)).transpose(1,2) x, _ = self.lstm1(x) x, _ = self.lstm2(x) x = self.dense0(x) x = nn.Softmax(dim=-1)(x) return x @torch.inference_mode() def infer(self, x: torch.Tensor): return self.forward(x) def _predict_list(self, input_list: List[str], return_probs: bool=False): output_list = [] probs_list = [] for input_text in input_list: if return_probs: output_text, probs = self._predict_single(input_text, return_probs=True) output_list.append(output_text) probs_list.append(probs) else: output_list.append(self._predict_single(input_text)) if return_probs: return output_list, return_probs return output_list def _predict_single(self, input_text: str, return_probs: bool=False): input_ids_pad, input_letters_ids = encode(input_text, self.max_sentence) input = torch.LongTensor(input_ids_pad)[None].to(self.emb_input.weight.device) probs = self.infer(input).cpu() output = decode(probs, input_text, input_letters_ids) if return_probs: return output, probs return output def predict(self, input: Union[str, List[str]], return_probs: bool=False): if isinstance(input, str): return self._predict_single(input, return_probs=return_probs) return self._predict_list(input, return_probs=return_probs)