import streamlit as st
import torch
import torch.nn as nn
import pytorch_lightning as lg
from transformers import AutoTokenizer, AutoModel

ATTENTION_SIZE = 10
HIDDEN_SIZE = 300
INPUT_SIZE = 312  # embedding width of cointegrated/rubert-tiny2


class RomanAttention(nn.Module):
    """Scores LSTM states, keeps ATTENTION_SIZE of them plus a weighted mean of the rest."""

    def __init__(self, hidden_size: int = HIDDEN_SIZE) -> None:
        super().__init__()
        self.clf = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def forward(self, hidden, final_hidden):
        # hidden: (batch, seq_len, hidden); final_hidden (c_n): (1, batch, hidden)
        final_hidden = final_hidden.squeeze(0).unsqueeze(1)  # (batch, 1, hidden)
        cat = torch.cat((hidden, final_hidden), dim=1)       # (batch, seq_len + 1, hidden)
        scores = self.clf(cat)                               # (batch, seq_len + 1, 1)

        # Sort positions by score (ascending, as in the original) and split them into
        # the ATTENTION_SIZE "selected" positions and the remaining ones.
        order = torch.argsort(scores, descending=False, dim=1)
        index = order[:, :ATTENTION_SIZE].squeeze(2)
        index1 = order[:, ATTENTION_SIZE:].squeeze(2)

        batch_idx = torch.arange(index.size(0)).unsqueeze(1)
        selected_values = cat[batch_idx, index]
        select_scores = scores[batch_idx, index1]
        # Unselected states are downweighted by their squared score and averaged.
        unselected_values = cat[batch_idx, index1] * select_scores * select_scores
        mean_unselected = torch.mean(unselected_values, dim=1)

        # Output: (batch, ATTENTION_SIZE + 1, hidden)
        return torch.cat((selected_values, mean_unselected.unsqueeze(1)), dim=1)


@st.cache_resource
def load_model():
    m = AutoModel.from_pretrained("cointegrated/rubert-tiny2")
    emb = m.embeddings
    # emb.dropout = nn.Dropout(0)
    # Freeze the pretrained embedding layer; only the LSTM, attention, and classifier are trained.
    for param in emb.parameters():
        param.requires_grad = False
    tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
    return emb, tokenizer


emb, tokenizer = load_model()


def tokenize(text):
    # Pad/truncate the input to a fixed length of 300 tokens.
    t = tokenizer(
        text,
        padding=True,
        truncation=True,
        pad_to_multiple_of=300,
        max_length=300,
    )["input_ids"]
    # Fallback padding in case the tokenizer returns a shorter sequence.
    if len(t) < 30:
        t += [0] * (30 - len(t))
    return t


class MyModel(lg.LightningModule):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(INPUT_SIZE, HIDDEN_SIZE, batch_first=True)
        self.attn = RomanAttention(HIDDEN_SIZE)
        self.clf = nn.Sequential(
            nn.Linear(HIDDEN_SIZE * (ATTENTION_SIZE + 1), 100),
            nn.Dropout(),
            nn.ReLU(),
            nn.Linear(100, 3),
        )
        self.criterion = nn.CrossEntropyLoss()
        # Note: this callback only takes effect if passed to the Trainer,
        # e.g. Trainer(callbacks=[model.early_stopping]). The monitored metric
        # matches the 'val_accuracy' key logged in validation_step.
        self.early_stopping = lg.callbacks.EarlyStopping(
            monitor="val_accuracy",
            min_delta=0.01,
            patience=2,
            verbose=True,
            mode="max",
        )
        self.verbose = False

    def forward(self, x):
        if isinstance(x, str):
            x = torch.tensor([tokenize(x)])
        embeddings = emb(x)                          # (batch, seq_len, INPUT_SIZE)
        output, (h_n, c_n) = self.lstm(embeddings)   # output: (batch, seq_len, HIDDEN_SIZE)
        attention = self.attn(output, c_n)           # (batch, ATTENTION_SIZE + 1, HIDDEN_SIZE)
        out = attention  # an alternative, torch.cat((output, attention), dim=1), was commented out
        out = out.flatten(1)
        return self.clf(out)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = self.criterion(y_pred, y)
        accuracy = (torch.argmax(y_pred, dim=1) == y).float().mean()
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        self.log("train_accuracy", accuracy, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = self.criterion(y_pred, y)
        accuracy = (torch.argmax(y_pred, dim=1) == y).float().mean()
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        self.log("val_accuracy", accuracy, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)
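

# --- Usage sketch (not part of the original script; illustrative assumptions) ---
# A minimal forward pass of the untrained MyModel on a sample Russian sentence,
# just to show the expected input/output shapes. The sample text and the
# three-class interpretation are hypothetical; in practice the model would be
# restored from a trained Lightning checkpoint before inference.
if __name__ == "__main__":
    model = MyModel()
    model.eval()
    with torch.no_grad():
        logits = model("Пример текста для классификации")  # shape: (1, 3)
    print(logits.softmax(dim=-1))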