import torch
import torch.nn as nn
import pytorch_lightning as lg
import streamlit as st
from transformers import AutoTokenizer, AutoModel

ATTENTION_SIZE = 10   # number of timesteps kept by the attention layer
HIDDEN_SIZE = 300     # LSTM hidden size
INPUT_SIZE = 312      # rubert-tiny2 embedding dimension
class RomanAttention(nn.Module):
    """Selects ATTENTION_SIZE timesteps by a learned score and pools the rest."""

    def __init__(self, hidden_size: int = HIDDEN_SIZE) -> None:
        super().__init__()
        # Small MLP that assigns a scalar score to every timestep.
        self.clf = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )
    def forward(self, hidden, final_hidden):
        # hidden:       (batch, seq_len, hidden_size) LSTM outputs
        # final_hidden: (1, batch, hidden_size) final LSTM cell state
        final_hidden = final_hidden.squeeze(0).unsqueeze(1)
        # Append the final state to the sequence as an extra timestep.
        cat = torch.cat((hidden, final_hidden), dim=1)
        # Score every timestep and sort the indices by score (ascending).
        clf = self.clf(cat)
        vals = torch.argsort(clf, descending=False, dim=1)
        index = vals[:, :ATTENTION_SIZE].squeeze(2)   # kept timesteps
        index1 = vals[:, ATTENTION_SIZE:].squeeze(2)  # remaining timesteps
        batch_idx = torch.arange(index.size(0)).unsqueeze(1)
        selected_values = cat[batch_idx, index]
        # Weight the remaining timesteps by their squared score and average them.
        select_clf = clf[batch_idx, index1]
        unselected_values = cat[batch_idx, index1] * select_clf * select_clf
        mean_unselected = torch.mean(unselected_values, dim=1)
        # Output shape: (batch, ATTENTION_SIZE + 1, hidden_size).
        return torch.cat((selected_values, mean_unselected.unsqueeze(1)), dim=1)
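
# Illustrative shape check for RomanAttention, not part of the original app;
# the batch size and sequence length below are arbitrary assumptions.
def _demo_attention_shapes() -> None:
    attn = RomanAttention()
    hidden = torch.randn(2, 30, HIDDEN_SIZE)   # fake LSTM outputs
    final = torch.randn(1, 2, HIDDEN_SIZE)     # fake final cell state c_n
    out = attn(hidden, final)
    assert out.shape == (2, ATTENTION_SIZE + 1, HIDDEN_SIZE)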
def load_model():
    """Load the frozen rubert-tiny2 embedding layer and its tokenizer."""
    m = AutoModel.from_pretrained("cointegrated/rubert-tiny2")
    emb = m.embeddings
    # emb.dropout = nn.Dropout(0)
    for param in emb.parameters():
        param.requires_grad = False  # keep the embeddings frozen
    tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
    return emb, tokenizer

emb, tokenizer = load_model()
def tokenize(text):
    # Tokenize to at most 300 ids, then pad to a minimum length of 30.
    t = tokenizer(text, padding=True, truncation=True,
                  pad_to_multiple_of=300, max_length=300)["input_ids"]
    if len(t) < 30:
        t += [0] * (30 - len(t))
    return t
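
# Illustrative use of tokenize, not part of the original app; the Russian
# sample string is an arbitrary assumption.
def _demo_tokenize() -> None:
    ids = tokenize("пример текста")
    assert isinstance(ids, list) and len(ids) >= 30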
class MyModel(lg.LightningModule):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(INPUT_SIZE, HIDDEN_SIZE, batch_first=True)
        self.attn = RomanAttention(HIDDEN_SIZE)
        # Classifier over the ATTENTION_SIZE selected timesteps plus the pooled rest.
        self.clf = nn.Sequential(
            nn.Linear(HIDDEN_SIZE * (ATTENTION_SIZE + 1), 100),
            nn.Dropout(),
            nn.ReLU(),
            nn.Linear(100, 3),
        )
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        # Pass this callback to the Trainer (callbacks=[model.early_stopping]) for it
        # to take effect; it monitors the 'val_accuracy' metric logged below.
        self.early_stopping = lg.callbacks.EarlyStopping(
            monitor='val_accuracy',
            min_delta=0.01,
            patience=2,
            verbose=True,
            mode='max',
        )
        self.verbose = False
    def forward(self, x):
        # Accept a raw string for convenience during inference.
        if isinstance(x, str):
            x = torch.tensor([tokenize(x)])
        embeddings = emb(x)
        output, (h_n, c_n) = self.lstm(embeddings)
        attention = self.attn(output, c_n)
        out = attention  # alternative: torch.cat((output, attention), dim=1)
        out = nn.Flatten()(out)
        out_clf = self.clf(out)
        return out_clf
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = self.criterion(y_pred, y)
        accuracy = (torch.argmax(y_pred, dim=1) == y).float().mean()
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        self.log('train_accuracy', accuracy, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = self.criterion(y_pred, y)
        accuracy = (torch.argmax(y_pred, dim=1) == y).float().mean()
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_accuracy', accuracy, on_epoch=True, prog_bar=True)
        return loss
    def configure_optimizers(self):
        return self.optimizer
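
# Illustrative inference sketch, not part of the original app. It uses untrained
# weights; in practice load trained weights first (e.g. via
# MyModel.load_from_checkpoint). The sample text is an arbitrary assumption.
def _demo_predict(text: str = "пример отзыва") -> int:
    model = MyModel()
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(text), dim=1)  # (1, 3) class probabilities
    return int(probs.argmax(dim=1).item())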