"""Streamlit helper that scores text for toxicity.

Text is encoded with the ``cointegrated/rubert-tiny-toxicity`` model (with its
classification head removed) and the resulting vector is passed to a CatBoost
classifier restored from disk.
"""

import streamlit as st
import torch
import torch.nn as nn
from catboost import CatBoostClassifier
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Serialized CatBoost model; it consumes the vectors produced in `predict`.
CATBOOST_MODEL_PATH = "pages/anti_toxic/dont_be_toxic.pt"
# Hugging Face checkpoint used for tokenization and feature extraction.
MODEL_CHECKPOINT = "cointegrated/rubert-tiny-toxicity"


@st.cache_resource
def load_model():
    """Load the CatBoost classifier, the tokenizer and the encoder model.

    Decorated with ``st.cache_resource`` so Streamlit builds these objects once
    and reuses them across script reruns.

    Returns:
        tuple: ``(catboost_model, tokenizer, model)``.
    """
    catboost_model = CatBoostClassifier(random_seed=42, eval_metric="Accuracy")
    # The fitted model state is restored from the file.
    catboost_model.load_model(CATBOOST_MODEL_PATH)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_CHECKPOINT)
    # Replace the classification head with a pass-through so the forward pass
    # returns the features that would have been fed to it (NOTE(review):
    # presumably the pooled sentence embedding for this BERT-style checkpoint)
    # rather than toxicity logits.
    model.classifier = nn.Identity()
    # p=0 dropout never zeroes anything, so the extracted features are
    # deterministic regardless of train/eval mode.
    model.dropout = nn.Dropout(0)
    return catboost_model, tokenizer, model


catboost_model, tokenizer, model = load_model()


def predict(text):
    """Return CatBoost class probabilities for ``text``.

    Args:
        text: The string to score. If a list of strings is passed, only the
            first one's features are used.

    Returns:
        The output of ``catboost_model.predict_proba`` for a single feature
        vector.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        # With the head replaced, `.logits` holds the extracted features;
        # keep only the first row of the batch.
        features = model(**encoded).logits[0].tolist()
    return catboost_model.predict_proba(features)