romnatall
cache BERT model
6a3534a
raw
history blame
1.34 kB
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.linear_model import LogisticRegression
import streamlit as st
import pickle
import streamlit as st
@st.cache_resource
def get_model():
    """Load the ``cointegrated/rubert-tiny2`` encoder together with its tokenizer.

    The ``st.cache_resource`` decorator memoizes the returned objects, so
    Streamlit reruns reuse them instead of calling ``from_pretrained`` again.

    Returns:
        tuple: ``(model, tokenizer)``, in that order.
    """
    checkpoint = "cointegrated/rubert-tiny2"
    encoder = AutoModel.from_pretrained(checkpoint)
    tok = AutoTokenizer.from_pretrained(checkpoint)
    return encoder, tok


# Module-level handles that predict_bert reads.
model, tokenizer = get_model()
@st.cache_resource
def _load_log_reg():
    """Unpickle the logistic-regression classifier used on top of the BERT features.

    Cached with ``st.cache_resource`` so the file is read and deserialized once,
    rather than on every call to ``predict_bert``.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file. This is acceptable only because the path is a fixed artifact shipped
    with the app. Never point it at user-supplied files.
    """
    with open('pages/film_review/model/log_reg_bert.pkl', 'rb') as f:
        return pickle.load(f)


def predict_bert(input_text):
    """Classify ``input_text`` using rubert-tiny2 embeddings and a pickled classifier.

    The text is tokenized with special tokens and truncated to ``MAX_LEN`` ids.
    It is then right-padded with id 0 to exactly ``MAX_LEN`` and encoded by the
    module-level ``model``. The hidden state at position 0 is fed to the cached
    logistic-regression classifier.

    Args:
        input_text: Text to classify.

    Returns:
        The first (and only) element of the classifier's ``predict`` output.
    """
    MAX_LEN = 300
    token_ids = tokenizer.encode(input_text, add_special_tokens=True, truncation=True, max_length=MAX_LEN)
    padded_input = np.array(token_ids + [0] * (MAX_LEN - len(token_ids)))
    # Id 0 is treated as padding and masked out.
    # Presumably this is the tokenizer's [PAD] id -- TODO confirm.
    attention_mask = np.where(padded_input != 0, 1, 0)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    with torch.no_grad():
        # unsqueeze(0) adds a batch dimension of size 1.
        input_tensor = torch.tensor(padded_input).unsqueeze(0).to(device)
        attention_mask_tensor = torch.tensor(attention_mask).unsqueeze(0).to(device)
        last_hidden_states = model(input_tensor, attention_mask=attention_mask_tensor)[0]
    # Take the vector at position 0 of the single sequence.
    # With add_special_tokens=True this is presumably the [CLS] token.
    features = last_hidden_states[:, 0, :].cpu().numpy()

    # The classifier is loaded once and cached, rather than re-read from disk per call.
    prediction = _load_log_reg().predict(features)
    return prediction[0]