# Streamlit app: classify Code of Conduct sentences with a linear head on frozen DistilBERT embeddings.
import os

import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import lightning.pytorch as pl
from annotated_text import annotated_text
from transformers import AutoTokenizer, DistilBertModel


class Classifier(pl.LightningModule):
    """Linear classifier over flattened DistilBERT last-hidden-state embeddings (3 classes)."""

    def __init__(self):
        super().__init__()
        self.ln1 = torch.nn.Linear(512 * 768, 3)
        self.criterion = nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        # Each row of x holds 512 input ids followed by the 512 attention-mask values.
        x, y = batch
        with torch.no_grad():
            x = get_bert()(
                input_ids=x[:, :512], attention_mask=x[:, 512:]
            ).last_hidden_state.reshape(-1, 512 * 768)
        # L2-normalise each flattened sequence embedding before the linear layer.
        x = x / torch.linalg.norm(x, 2, dim=1, keepdim=True)
        x = self.ln1(x)
        loss = self.criterion(x, y)
        self.log("my_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)

    def preprocess(self, x):
        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=True)
        # Pad/truncate to 512 tokens so the flattened embedding matches the linear layer's input size.
        return tokenizer(x, padding="max_length", truncation=True, max_length=512, return_tensors="pt")

    def forward(self, x):
        with torch.no_grad():
            x = get_bert()(**x).last_hidden_state.reshape(-1, 512 * 768)
        x = x / torch.linalg.norm(x, 2, dim=1, keepdim=True)
        return self.ln1(x)


@st.cache(allow_output_mutation=True)
def get_bert():
    return DistilBertModel.from_pretrained("distilbert-base-uncased")


@st.cache(allow_output_mutation=True)
def get_classifier():
    # Download the trained checkpoint from Google Drive, then restore the LightningModule.
    os.system("gdown 1GxhHvg3lwlGpA7So06v3l43U8pSASy9L")
    return Classifier.load_from_checkpoint(f"{os.getcwd()}/model_params")


def get_annotated_text(text):
    model = get_classifier()
    labels = {0: "Leadership", 1: "Diversity", 2: "Integrity"}
    annotated = []
    # Classify the text sentence by sentence and pair each sentence with its predicted label.
    for sentence in text.split("."):
        if sentence.strip() == "":
            continue
        c = model(model.preprocess([sentence])).argmax().item()
        annotated.append((sentence, labels[c]))
        annotated.append(".")
    return tuple(annotated)


st.title("Code of Conduct Classifier")
input_text = st.text_area("Enter code of conduct text")
st.title("Annotated text")
annotated_text(*get_annotated_text(input_text))