import torch
import torch.nn as nn
import pytorch_lightning as lg
import streamlit as st
from transformers import AutoTokenizer, AutoModel

ATTENTION_SIZE = 10  # number of positions kept verbatim by the attention block
HIDDEN_SIZE = 300    # LSTM hidden size
INPUT_SIZE = 312     # embedding dimension of rubert-tiny2

class RomanAttention(nn.Module):
    """Keeps ATTENTION_SIZE positions of the LSTM output, selected by a learned
    score, and summarises the remaining positions as a score-weighted mean."""

    def __init__(self, hidden_size: int = HIDDEN_SIZE) -> None:
        super().__init__()
        # Scores each position (one hidden_size vector) with a single scalar.
        self.clf = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def forward(self, hidden, final_hidden):
        # hidden: (B, T, H) per-token LSTM outputs.
        # final_hidden: (1, B, H) final LSTM state -> reshape to (B, 1, H).
        final_hidden = final_hidden.squeeze(0).unsqueeze(1)

        # Treat the final state as one extra position: (B, T+1, H).
        cat = torch.cat((hidden, final_hidden), dim=1)
        scores = self.clf(cat)  # (B, T+1, 1): one scalar score per position
        # Ascending argsort: the first ATTENTION_SIZE indices are the
        # lowest-scored positions and are kept verbatim.
        order = torch.argsort(scores, descending=False, dim=1)
        keep_idx = order[:, :ATTENTION_SIZE].squeeze(2)
        rest_idx = order[:, ATTENTION_SIZE:].squeeze(2)
        batch = torch.arange(keep_idx.size(0)).unsqueeze(1)
        selected_values = cat[batch, keep_idx]
        rest_scores = scores[batch, rest_idx]
        # Remaining positions are scaled by their squared score and averaged.
        unselected_values = cat[batch, rest_idx] * rest_scores * rest_scores
        mean_unselected = torch.mean(unselected_values, dim=1)
        # Output: (B, ATTENTION_SIZE + 1, H).
        return torch.cat((selected_values, mean_unselected.unsqueeze(1)), dim=1)
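
# Shape trace through RomanAttention (illustrative; assumes the fixed sequence
# length T=300 produced by tokenize() below, so there are T+1=301 positions):
#   hidden        (B, 300, 300)   per-token LSTM outputs
#   final_hidden  (1, B, 300) -> (B, 1, 300)
#   cat           (B, 301, 300)
#   scores        (B, 301, 1)
#   return        (B, ATTENTION_SIZE + 1, 300)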



@st.cache_resource
def load_model():
    # Only the frozen embedding layer of rubert-tiny2 is used; the LSTM in
    # MyModel is trained on top of it from scratch.
    m = AutoModel.from_pretrained("cointegrated/rubert-tiny2")
    emb = m.embeddings
    for param in emb.parameters():
        param.requires_grad = False
    tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
    return emb, tokenizer

emb, tokenizer = load_model()

def tokenize(text):
    # padding=True together with pad_to_multiple_of=300 rounds the padded
    # length up to 300 ids, and max_length=300 truncates longer texts, so the
    # output length is a fixed 300.
    t = tokenizer(text, padding=True, truncation=True,
                  pad_to_multiple_of=300, max_length=300)['input_ids']
    if len(t) < 30:  # safety net for unexpectedly short outputs
        t += [0] * (30 - len(t))
    return t
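
# Example (illustrative): every input maps to a fixed-length id list, so
# downstream batch shapes stay stable:
#   ids = tokenize("пример текста")   # list of 300 input_ids
#   x = torch.tensor([ids])           # (1, 300) batch for MyModel.forward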


class MyModel(lg.LightningModule):
    def __init__(self):
        super().__init__()

        self.lstm = nn.LSTM(INPUT_SIZE, HIDDEN_SIZE, batch_first=True)
        self.attn = RomanAttention(HIDDEN_SIZE)
        # Classifier over the flattened attention output:
        # (ATTENTION_SIZE + 1) positions * HIDDEN_SIZE features -> 3 classes.
        self.clf = nn.Sequential(
            nn.Linear(HIDDEN_SIZE * (ATTENTION_SIZE + 1), 100),
            nn.Dropout(),
            nn.ReLU(),
            nn.Linear(100, 3),
        )

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        # Monitors 'val_accuracy', the key actually logged in validation_step
        # ('val_acc' is never logged anywhere in this module).
        self.early_stopping = lg.callbacks.EarlyStopping(
            monitor='val_accuracy',
            min_delta=0.01,
            patience=2,
            verbose=True,
            mode='max',
        )
        self.verbose = False
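
        # Usage note (sketch, not in the original code): Lightning applies
        # callbacks only when they are passed to the Trainer, e.g.
        #   trainer = lg.Trainer(callbacks=[model.early_stopping], max_epochs=10)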

    def forward(self, x):
        # Accept either a raw string or a pre-tokenized batch of ids.
        if isinstance(x, str):
            x = torch.tensor([tokenize(x)])
        embeddings = emb(x)  # frozen rubert-tiny2 embeddings: (B, T, INPUT_SIZE)
        output, (h_n, c_n) = self.lstm(embeddings)
        # Note: the cell state c_n (not h_n) is passed as the summary vector.
        attention = self.attn(output, c_n)  # (B, ATTENTION_SIZE + 1, HIDDEN_SIZE)
        out = nn.Flatten()(attention)
        return self.clf(out)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = self.criterion(y_pred, y)

        accuracy = (torch.argmax(y_pred, dim=1) == y).float().mean()
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        self.log('train_accuracy', accuracy, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = self.criterion(y_pred, y)
        accuracy = (torch.argmax(y_pred, dim=1) == y).float().mean()
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_accuracy', accuracy, on_epoch=True, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        return self.optimizer
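

# Minimal inference sketch (not part of the original file): how this module
# could back a Streamlit page. "model.ckpt" is a hypothetical checkpoint path
# and the label names are assumed; only the 3-class output size is given above.
if __name__ == "__main__":
    LABELS = ["negative", "neutral", "positive"]  # assumed class order
    model = MyModel()  # or MyModel.load_from_checkpoint("model.ckpt")
    model.eval()
    text = st.text_input("Enter text")
    if text:
        with torch.no_grad():
            logits = model(text)  # forward() tokenizes raw strings itself
        st.write(LABELS[int(torch.argmax(logits, dim=1))])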