# Page header captured together with the code ("Spaces: Sleeping"); not part of the program.
import torch | |
from transformers import BertTokenizer, BertModel | |
from huggingface_hub import PyTorchModelHubMixin | |
import numpy as np | |
import gradio as gr | |
import nltk | |
nltk.download('stopwords') | |
from nltk.corpus import stopwords | |
import re | |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# (The bare `device` expression that used to follow was a notebook-style echo
# and did nothing when executed as a script.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class BERTClass(torch.nn.Module, PyTorchModelHubMixin):
    """BERT encoder followed by dropout and a linear classification head.

    The head maps a 1024-wide pooled encoder output to 11 raw scores; no
    activation is applied inside the module.
    """

    def __init__(self):
        """Load the base encoder and create the dropout and linear layers."""
        super().__init__()
        # Attribute names are left untouched: they are the parameter-name
        # prefixes of this module's state dict.
        self.bert_model = BertModel.from_pretrained(
            'digitalepidemiologylab/covid-twitter-bert-v2',
            return_dict=True,
        )
        self.dropout = torch.nn.Dropout(0.3)
        self.linear = torch.nn.Linear(1024, 11)

    def forward(self, input_ids, attn_mask, token_type_ids):
        """Return the linear head's output for the encoded input.

        Args:
            input_ids: Token ids, passed positionally to the encoder.
            attn_mask: Forwarded to the encoder as ``attention_mask``.
            token_type_ids: Forwarded to the encoder as ``token_type_ids``.

        Returns:
            The result of applying dropout and then the linear layer to the
            encoder's ``pooler_output``.
        """
        encoded = self.bert_model(
            input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids,
        )
        pooled = self.dropout(encoded.pooler_output)
        return self.linear(pooled)
# Build the classifier straight from the fine-tuned checkpoint.
# `from_pretrained` constructs the model itself, so creating a throw-away
# BERTClass() first only loaded the large base encoder a second time.
model = BERTClass.from_pretrained("Asutosh2003/ct-bert-v2-vaccine-concern")
model.to(device)

# Tokenizer of the same base checkpoint that BERTClass loads as its encoder.
tokenizer = BertTokenizer.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')

# Token length that inputs are padded/truncated to when they are encoded.
MAX_LEN = 256
def rmTrash(raw_string, remuser, remstop, remurls):
    """Normalise a piece of text before it is tokenised.

    The text is always lower-cased and stripped of every character that is
    neither a word character nor whitespace. The three flags switch on the
    optional clean-up steps.

    Args:
        raw_string: Text to clean.
        remuser: If truthy, drop every whitespace-separated token containing
            '@'. The surviving tokens are re-joined with a single space in
            front of each one (so the result starts with a space).
        remstop: If truthy, drop English stop words (NLTK list). The
            surviving tokens are re-joined with a single space in front of
            each one.
        remurls: If truthy, remove every run of non-space characters that
            starts at "http".

    Returns:
        The cleaned string. Whitespace is not trimmed, so leading/trailing
        spaces may be present.
    """
    if remuser:
        text = ''.join(
            ' ' + token for token in raw_string.split() if '@' not in token
        )
    else:
        text = raw_string

    text = re.sub(r'[^\w\s]', '', text.lower())

    if remurls:
        # Punctuation has already been stripped, so a URL is by now a single
        # run of word characters beginning with "http".
        text = re.sub(r'http\S+', '', text)

    if not remstop:
        return text

    # Load the stop-word list once and use a set: the list used to be
    # re-read and scanned linearly for every single token.
    stop_words = set(stopwords.words('english'))
    return ''.join(
        ' ' + token for token in text.split() if token not in stop_words
    )
def return_vec(text):
    """Score one piece of text with the classifier.

    The text is cleaned with ``rmTrash`` (all three clean-up flags on),
    encoded to ``MAX_LEN`` tokens, and run through ``model`` in eval mode
    without gradient tracking.

    Args:
        text: Raw input text.

    Returns:
        A list with the sigmoid of each model output for the first (and only)
        row of the batch.
    """
    cleaned = rmTrash(text, True, True, True)
    batch = tokenizer.encode_plus(
        cleaned,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        return_token_type_ids=True,
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    def on_device(key):
        # Move one encoded field to the inference device as int64.
        return batch[key].to(device, dtype=torch.long)

    model.eval()
    with torch.no_grad():
        ids = on_device('input_ids')
        mask = on_device('attention_mask')
        segments = on_device('token_type_ids')
        logits = model(ids, mask, segments)
        rows = torch.sigmoid(logits).cpu().detach().numpy().tolist()
    return list(rows[0])
def filter_threshold_lst(vector, threshold_list):
    """Binarise scores against per-position thresholds.

    Args:
        vector: Sequence of scores.
        threshold_list: Sequence of thresholds, matched to ``vector`` by
            position.

    Returns:
        A new list with 1 where the score is greater than or equal to its
        threshold and 0 otherwise. Its length is that of the shorter input,
        because positions without a partner are ignored.
    """
    # The previous version appended the result list to itself, which produced
    # a self-referential list with one spurious trailing element.
    return [
        1 if val >= threshold else 0
        for val, threshold in zip(vector, threshold_list)
    ]
def predict(text, threshold_lst):
    """Predict the vaccine-concern labels expressed by a piece of text.

    Args:
        text: Raw input text.
        threshold_lst: Per-label decision thresholds, in the same order as
            the label tuple below.

    Returns:
        A tuple ``(pred_lbl_lst, prob_lst)``: the names of the labels whose
        score reached its threshold (``['none']`` when no label did), and the
        scores returned by ``return_vec``.
    """
    labels = ('side-effect', 'ineffective', 'rushed', 'pharma', 'mandatory', 'unnecessary', 'political', 'ingredients', 'conspiracy', 'country', 'religious')
    prob_lst = return_vec(text)
    flags = filter_threshold_lst(prob_lst, threshold_lst)

    # zip stops at the shorter sequence, so a flag can never be looked up
    # past the end of `labels`. The comparison is an explicit `== 1` so that
    # only an exact 1 flag selects a label.
    pred_lbl_lst = [label for label, flag in zip(labels, flags) if flag == 1]
    if not pred_lbl_lst:
        pred_lbl_lst = ['none']
    return pred_lbl_lst, prob_lst
def gr_predict(text):
    """Gradio callback: return the predicted labels as one comma-separated string.

    Args:
        text: Raw input text from the UI.

    Returns:
        The label names produced by ``predict`` joined with ',' (no spaces).
    """
    # Fixed decision thresholds, passed to predict() position by position.
    thres = [0.616, 0.212, 0.051, 0.131, 0.212, 0.111, 0.071, 0.566, 0.061, 0.02, 0.081]
    predicted_labels, _ = predict(text, thres)
    return ','.join(predicted_labels)
# Markdown description passed to the Gradio interface as its `description`.
# The text is user-facing and is kept verbatim.
descr = """
This app uses [Covid-twitter-BERT-v2](https://huggingface.co/digitalepidemiologylab/covid-twitter-bert-v2)
fine tuned on a custom subset of [Caves dataset](https://arxiv.org/abs/2204.13746) sent by [FIRE 2023](http://fire.irsi.res.in/fire/2023/home)
conference to do multi-label classification of tweets expressing concerns towards vaccines. The different concerns/classes are
('side-effect', 'ineffective', 'rushed', 'pharma', 'mandatory', 'unnecessary', 'political', 'ingredients', 'conspiracy', 'country', 'religious').
Each tweet can be expressing multiple of these concerns. If a tweet is not expressing any concern falling into any of these categories
it will be classified as 'None'.\n
[Source files](https://github.com/Ranjit246/AISoME_FIRE_2023)\n
Try it out with some ridiculous statements about vaccines. You can use the examples below as a start.
"""
# Sample inputs offered to the user in the UI.
_EXAMPLE_TWEETS = [
    "This vaccine gave me mumps",
    "Chinese vaccine will infect our brain",
    "Trump is gonna use these vaccines to control us and become the president",
]

# Gradio interface: a text box in, a Label widget out, backed by gr_predict.
iface = gr.Interface(
    fn=gr_predict,
    inputs=gr.Textbox(),
    outputs=gr.Label(),
    examples=_EXAMPLE_TWEETS,
    title="Vaccine Concerns ML",
    description=descr,
)

# Start serving the app (debug mode on).
iface.launch(debug=True)