File size: 2,612 Bytes
7694c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import torch.nn as nn
from .lstm_hsm import LSTMHardSigmoid
from . import encode, decode
from typing import Union, List


class Shakkala(nn.Module):
    """Stacked bidirectional-LSTM sequence tagger.

    Pipeline: embedding -> ``lstm0`` -> batch norm -> ``lstm1`` -> ``lstm2``
    -> linear -> softmax over the last axis. The module is put in eval mode
    at construction time.

    NOTE(review): ``LSTMHardSigmoid``, ``encode`` and ``decode`` are
    project-local; comments about them below are assumptions based on how
    they are called here (``LSTMHardSigmoid`` is presumably an ``nn.LSTM``
    look-alike with a hard-sigmoid gate activation) -- confirm against
    their definitions.

    Args:
        dim_input: Size of the input vocabulary (number of embedding rows).
        dim_output: Number of output classes produced per position.
        sd_path: Optional path to a state dict to load into the model.
    """

    def __init__(self,
                 dim_input: int = 149,
                 dim_output: int = 28,
                 sd_path: Union[str, None] = None):
        super().__init__()
        self.emb_input = nn.Embedding(dim_input, 288)

        # Each bidirectional layer is fed the previous layer's output width,
        # which is 2 * hidden_size of that layer (288 -> 576 -> 288 -> 192).
        self.lstm0 = LSTMHardSigmoid(288, hidden_size=288, bidirectional=True, batch_first=True)
        self.bn0 = nn.BatchNorm1d(576, momentum=0.01, eps=0.001)
        self.lstm1 = LSTMHardSigmoid(576, hidden_size=144, bidirectional=True, batch_first=True)
        self.lstm2 = LSTMHardSigmoid(288, hidden_size=96, bidirectional=True, batch_first=True)

        self.dense0 = nn.Linear(192, dim_output)

        # Inference-only by default: batch norm then uses its running stats.
        self.eval()
        # Forwarded to `encode` as its second argument on every prediction.
        self.max_sentence = None

        if sd_path is not None:
            # SECURITY: torch.load unpickles the file; only load checkpoints
            # from a trusted source.
            # map_location="cpu" lets checkpoints saved on a GPU load on
            # CPU-only hosts; load_state_dict copies the values into the
            # already-allocated parameters.
            self.load_state_dict(torch.load(sd_path, map_location="cpu"))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-position class probabilities for a batch of id sequences.

        Args:
            x: Integer ids indexing the embedding table; ``(batch, seq)`` as
                used by ``_predict_single``.

        Returns:
            Softmax probabilities over the last axis (size ``dim_output``).
        """
        x = self.emb_input(x)

        x, _ = self.lstm0(x)
        # BatchNorm1d normalises dim 1, so move the 576 features there and back.
        x = self.bn0(x.transpose(1, 2)).transpose(1, 2)
        x, _ = self.lstm1(x)
        x, _ = self.lstm2(x)

        x = self.dense0(x)
        return torch.softmax(x, dim=-1)

    @torch.inference_mode()
    def infer(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``forward`` under ``torch.inference_mode`` (no autograd tracking)."""
        return self.forward(x)

    def _predict_list(self, input_list: List[str], return_probs: bool = False):
        """Run ``_predict_single`` on every string of ``input_list``, in order.

        Args:
            input_list: Texts to process.
            return_probs: When true, also collect each text's probabilities.

        Returns:
            The list of outputs, or ``(outputs, probs_list)`` when
            ``return_probs`` is true; both lists are aligned with
            ``input_list``.
        """
        if not return_probs:
            return [self._predict_single(text) for text in input_list]

        output_list = []
        probs_list = []
        for text in input_list:
            output_text, probs = self._predict_single(text, return_probs=True)
            output_list.append(output_text)
            probs_list.append(probs)

        # Return the collected probabilities (not the boolean flag).
        return output_list, probs_list

    def _predict_single(self, input_text: str, return_probs: bool = False):
        """Encode one text, run the model on it and decode the result.

        Args:
            input_text: Text to process.
            return_probs: When true, also return the model probabilities.

        Returns:
            The value produced by ``decode``, or ``(decoded, probs)`` when
            ``return_probs`` is true. ``probs`` is a CPU tensor with a
            leading batch axis of size 1.
        """
        input_ids_pad, input_letters_ids = encode(input_text, self.max_sentence)
        # Add a batch axis and place the ids on the same device as the weights.
        device = self.emb_input.weight.device
        input_ids = torch.LongTensor(input_ids_pad)[None].to(device)
        probs = self.infer(input_ids).cpu()
        output = decode(probs, input_text, input_letters_ids)

        if return_probs:
            return output, probs

        return output

    def predict(self, input: Union[str, List[str]], return_probs: bool = False):
        """Predict for a single string or for a list of strings.

        Args:
            input: One text, or a list of texts.
            return_probs: When true, also return the model probabilities.

        Returns:
            For a ``str``: the decoded output, or ``(output, probs)``.
            For a list: the list of outputs, or ``(outputs, probs_list)``.
        """
        if isinstance(input, str):
            return self._predict_single(input, return_probs=return_probs)

        return self._predict_list(input, return_probs=return_probs)