import torch
import torch.nn as nn
import pytorch_lightning as lg
import streamlit as st
from transformers import AutoTokenizer, AutoModel

ATTENTION_SIZE = 10   # number of timesteps kept by the attention block
HIDDEN_SIZE = 300     # LSTM hidden size
INPUT_SIZE = 312      # rubert-tiny2 embedding dimension

class RomanAttention(nn.Module):
    """Keeps the ATTENTION_SIZE lowest-scoring timesteps and pools the rest."""

    def __init__(self, hidden_size: int = HIDDEN_SIZE) -> None:
        super().__init__()
        # small MLP that scores every timestep of the LSTM output
        self.clf = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def forward(self, hidden, final_hidden):
        # hidden: (batch, seq_len, hidden_size); final_hidden: (1, batch, hidden_size)
        final_hidden = final_hidden.squeeze(0).unsqueeze(1)     # (batch, 1, hidden_size)
        cat = torch.cat((hidden, final_hidden), dim=1)          # append final state as an extra timestep
        clf = self.clf(cat)                                     # per-timestep score, (batch, seq_len + 1, 1)
        vals = torch.argsort(clf, descending=False, dim=1)      # timestep indices by ascending score
        index = vals[:, :ATTENTION_SIZE].squeeze(2)             # kept timesteps
        index1 = vals[:, ATTENTION_SIZE:].squeeze(2)            # remaining timesteps
        rows = torch.arange(index.size(0)).unsqueeze(1)
        selected_values = cat[rows, index]
        select_clf = clf[rows, index1]
        # weight the remaining timesteps by their squared score and average them
        unselected_values = cat[rows, index1] * select_clf * select_clf
        mean_unselected = torch.mean(unselected_values, dim=1)
        return torch.cat((selected_values, mean_unselected.unsqueeze(1)), dim=1)

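# Illustrative shape check (added sketch, not part of the original app): verifies what
# RomanAttention returns, assuming it receives the LSTM output and the cell state c_n
# exactly as MyModel.forward passes them below. The name `_demo_attention_shapes` and the
# batch/sequence sizes are hypothetical.
def _demo_attention_shapes():
    attn = RomanAttention(HIDDEN_SIZE)
    hidden = torch.randn(2, 300, HIDDEN_SIZE)   # (batch, seq_len, hidden) LSTM output
    c_n = torch.randn(1, 2, HIDDEN_SIZE)        # (num_layers, batch, hidden) final cell state
    out = attn(hidden, c_n)
    # ATTENTION_SIZE selected timesteps plus one pooled "remainder" vector
    assert out.shape == (2, ATTENTION_SIZE + 1, HIDDEN_SIZE)
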
@st.cache_resource
def load_model():
    # load rubert-tiny2 and reuse only its embedding layer, frozen
    m = AutoModel.from_pretrained("cointegrated/rubert-tiny2")
    emb = m.embeddings
    for param in emb.parameters():
        param.requires_grad = False
    tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
    return emb, tokenizer


emb, tokenizer = load_model()

def tokenize(text):
    # tokenize a single string, padded/truncated to a fixed length of 300
    t = tokenizer(text, padding=True, truncation=True,
                  pad_to_multiple_of=300, max_length=300)['input_ids']
    # fallback: make sure the sequence is at least 30 tokens long
    if len(t) < 30:
        t += [0] * (30 - len(t))
    return t

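# Illustrative usage (added sketch, not part of the original app): shows that tokenize()
# returns a flat list of input ids at least 30 tokens long. The example string and the
# name `_demo_tokenize` are hypothetical; the tokenizer is downloaded on first use.
def _demo_tokenize():
    ids = tokenize("пример текста")   # "example text"
    assert isinstance(ids, list) and len(ids) >= 30
    return ids
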
class MyModel(lg.LightningModule):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(INPUT_SIZE, HIDDEN_SIZE, batch_first=True)
        self.attn = RomanAttention(HIDDEN_SIZE)
        # classifier head over the ATTENTION_SIZE selected timesteps plus the pooled remainder
        self.clf = nn.Sequential(
            nn.Linear(HIDDEN_SIZE * (ATTENTION_SIZE + 1), 100),
            nn.Dropout(),
            nn.ReLU(),
            nn.Linear(100, 3),
        )
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        # early stopping on the validation accuracy logged below (pass this callback to the Trainer)
        self.early_stopping = lg.callbacks.EarlyStopping(
            monitor='val_accuracy',
            min_delta=0.01,
            patience=2,
            verbose=True,
            mode='max',
        )
        self.verbose = False

    def forward(self, x):
        # accept either a raw string or a batch of token-id tensors
        if isinstance(x, str):
            x = torch.tensor([tokenize(x)])
        embeddings = emb(x)
        output, (h_n, c_n) = self.lstm(embeddings)
        attention = self.attn(output, c_n)
        out = nn.Flatten()(attention)
        return self.clf(out)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = self.criterion(y_pred, y)
        accuracy = (torch.argmax(y_pred, dim=1) == y).float().mean()
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        self.log('train_accuracy', accuracy, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = self.criterion(y_pred, y)
        accuracy = (torch.argmax(y_pred, dim=1) == y).float().mean()
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_accuracy', accuracy, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return self.optimizer
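
# Illustrative inference call (added sketch, not part of the original app): how the
# Streamlit side of the app might run a prediction on raw text, assuming a trained
# checkpoint is available. The path "model.ckpt" and the name `_demo_predict` are
# hypothetical placeholders.
def _demo_predict(text: str) -> int:
    model = MyModel()   # or MyModel.load_from_checkpoint("model.ckpt") for trained weights
    model.eval()
    with torch.no_grad():
        logits = model(text)   # forward() tokenizes raw strings itself
    return int(torch.argmax(logits, dim=1).item())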