Spaces:
Runtime error
Runtime error
File size: 2,105 Bytes
9047480 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import functools
import random

import pandas as pd
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Toxicity labels produced by the classifiers, in display order. Score tensors
# are paired with these names positionally, so this order is assumed to match
# the order of the fine-tuned models' output logits (TODO: confirm against
# the training code).
classifiers = [
    'toxic',
    'severe_toxic',
    'obscene',
    'threat',
    'insult',
    'identity_hate',
]
def reset_scores():
    """Rebind the module-level ``scores_df`` to an empty score table.

    The new table has no rows; its columns are ``Comment`` followed by one
    column per label in ``classifiers``.
    """
    global scores_df
    header = ['Comment'] + classifiers
    scores_df = pd.DataFrame(columns=header)
def get_score(model_base, text):
if model_base == "bert-base-cased":
model_dir = "./bert/_bert_model"
elif model_base == "distilbert-base-cased":
model_dir = "./distilbert/_distilbert_model"
else:
model_dir = "./roberta/_roberta_model"
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_base)
inputs = tokenizer.encode_plus(
text, max_length=512, truncation=True, padding=True, return_tensors='pt')
outputs = model(**inputs)
predictions = torch.sigmoid(outputs.logits)
return predictions
st.title("Toxic Comment Classifier")
model_base = st.selectbox("Select a pretrained model",
                          ["roberta-base", "bert-base-cased", "distilbert-base-cased"])
text_input = st.text_input("Enter text for toxicity classification",
                           "")
submit_btn = st.button("Submit")

# Score the user's text and show each label as a whole-number percentage.
if submit_btn and text_input:
    result = get_score(model_base, text_input)
    # Format straight from the raw probabilities. The old round(2) before the
    # percent format rounded the same digits twice, and applymap is deprecated.
    df = pd.DataFrame([[f"{p:.0%}" for p in result[0].tolist()]],
                      columns=classifiers)
    st.table(df)

# Score a few random comments from the Jigsaw test split.
# NOTE(review): the full CSV is re-read on every rerun (every widget
# interaction); consider caching it if latency matters.
test_df = pd.read_csv(
    "./jigsaw-toxic-comment-classification-challenge/test.csv")
sample_df = test_df.sample(n=3)
# Build the table in one step so pandas infers numeric dtypes for the score
# columns. Appending rows one at a time to an empty, untyped frame can leave
# them object-dtype (pandas-version dependent), and DataFrame.round leaves
# non-numeric columns untouched.
rows = [
    [comment] + get_score(model_base, comment)[0].tolist()
    for comment in sample_df['comment_text']
]
scores_df = pd.DataFrame(rows, columns=['Comment'] + classifiers).round(2)

st.subheader("Toxicity Scores for Random Comments")
# Clicking a Streamlit button reruns the whole script, so fresh comments have
# already been sampled and scored above by the time this returns True.
# Resetting the table here (as before) emptied it and showed a blank table
# beside the success message.
if st.button("Refresh"):
    st.success("New tweets have been loaded!")
st.table(scores_df)
|