Spaces:
Sleeping
Sleeping
File size: 4,122 Bytes
bea74aa b5ac54b 6e257b4 3019ade bea74aa 3019ade bea74aa 7216ad1 bea74aa 7216ad1 bea74aa 50239d2 7216ad1 50239d2 7216ad1 bea74aa f7d5b05 bea74aa f7d5b05 7216ad1 f7d5b05 7216ad1 f7d5b05 7216ad1 f7d5b05 bea74aa 435431a bea74aa 435431a bea74aa 435431a 8225fca 435431a bea74aa 435431a b5ac54b bea74aa 435431a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer
import gradio as gr
# Inference device: both classifiers and every tokenized input batch are moved here.
device = torch.device("cpu")
class RaceClassifier(nn.Module):
    """BERTweet encoder followed by dropout and a linear classification head.

    Despite the name, the class is generic over the number of classes; this
    file instantiates it for both the race and the orientation models.
    """

    def __init__(self, n_classes):
        """Build the encoder and a head that emits ``n_classes`` logits."""
        super().__init__()
        self.bert = AutoModel.from_pretrained("vinai/bertweet-base")
        # Dropout probability is fixed here; it may be tuned in the future.
        self.drop = nn.Dropout(p=0.3)
        # Projects the encoder's hidden size onto the class logits.
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return the class logits for a batch of tokenized inputs.

        The first element of the encoder output is sliced at sequence
        position 0 (the leading token of every sequence), passed through
        dropout and then through the linear head.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        first_token_state = encoded[0][:, 0]
        return self.out(self.drop(first_token_state))
# Class index -> display name for the 4-way model.
# NOTE(review): the order presumably mirrors the label encoding used when the
# checkpoint was trained — confirm before reordering.
race_labels = dict(enumerate(("African American", "Asian", "Latin", "White")))

# Class index -> display name for the 2-way model (same caveat as above).
orientation_labels = dict(enumerate(("Heterosexual", "LGBTQ")))
def _load_classifier(n_classes, weights_path):
    """Build a classifier, move it to ``device`` and load its saved weights."""
    classifier = RaceClassifier(n_classes=n_classes)
    classifier.to(device)
    # NOTE: torch.load unpickles the file, so only trusted checkpoints belong here.
    state = torch.load(weights_path, map_location=torch.device('cpu'))
    classifier.load_state_dict(state)
    return classifier


# Both checkpoints share one architecture and differ only in the number of classes.
model_race = _load_classifier(4, 'best_model_race.pt')
model_orientation = _load_classifier(2, 'best_model_orientation_last.pt')
def evaluate(model, input, mask):
    """Run ``model`` in eval mode without gradients and return its predictions.

    Returns a pair ``(probs, predictions)``: the softmax of the model output
    along dimension 1, and the arg-max class per row as a NumPy array.
    """
    model.eval()
    with torch.no_grad():
        logits = model(input, mask)
        class_probs = torch.softmax(logits, dim=1)
        # Arg-max is taken on the raw logits, then moved to the CPU for NumPy.
        predicted = logits.argmax(dim=1).cpu().numpy()
    return class_probs, predicted
def write_output(probs, predictions, title, labels):
    """Format one classifier's result as a text report.

    Only the first row of ``probs`` and the first entry of ``predictions``
    are read. Each class line is also echoed to stdout.

    ``labels`` maps a class index to its display name; ``title`` is
    upper-cased and used as the section heading.
    """
    report = [f"{title.upper()}\n Probabilities:\n"]
    for class_idx, class_prob in enumerate(probs[0]):
        line = f"{labels[class_idx]} = {round(class_prob.item() * 100, 2)}%"
        print(line)
        report.append(line + "\n")
    report.append(f"Predicted as: {labels[predictions[0]]}\n")
    return "".join(report)
# Lazily created by _get_tokenizer() and reused across requests.
_tokenizer = None


def _get_tokenizer():
    """Return the shared BERTweet tokenizer, loading it on first use only."""
    global _tokenizer
    if _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base", normalization=True)
    return _tokenizer


def _user_level_prediction(model, input_ids, attention_mask):
    """Average the per-tweet class probabilities into one row and pick its arg-max."""
    tweet_probs, _ = evaluate(model, input_ids, attention_mask)
    # keepdim=True keeps a (1, n_classes) row, which is what write_output indexes.
    user_probs = tweet_probs.mean(dim=0, keepdim=True)
    user_prediction = torch.argmax(user_probs, dim=1).cpu().numpy()
    return user_probs, user_prediction


def predict(*text):
    """Classify the entered tweets and return a combined text report.

    Args:
        *text: One value per UI textbox. Empty, ``None`` and whitespace-only
            entries are ignored.

    Returns:
        A report with the race section followed by the sexual-orientation
        section. When no usable tweet was entered, a short prompt asking for
        input is returned instead of running the models.

    All tweets are encoded as a single batch. The per-tweet probabilities are
    averaged so that every tweet contributes to the reported profile, instead
    of only the first one.
    """
    tweets = [tweet for tweet in text if tweet and tweet.strip()]
    if not tweets:
        # An empty batch cannot be tokenized; answer with a hint instead of crashing.
        return "Please enter at least one tweet.\n"

    encoded_sentences = _get_tokenizer()(
        tweets,
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=128,
    )
    input_ids = encoded_sentences["input_ids"].to(device)
    attention_mask = encoded_sentences["attention_mask"].to(device)

    race_probs, race_predictions = _user_level_prediction(
        model_race, input_ids, attention_mask
    )
    orientation_probs, orientation_predictions = _user_level_prediction(
        model_orientation, input_ids, attention_mask
    )

    sections = (
        write_output(race_probs, race_predictions, "race", race_labels),
        write_output(
            orientation_probs,
            orientation_predictions,
            "sexual orientation",
            orientation_labels,
        ),
    )
    # Each section is followed by one blank line, as in the original layout.
    return "".join(section + "\n" for section in sections)
# Number of tweet textboxes created in the UI; also the upper bound of the slider.
max_textboxes = 20
def update_textboxes(k):
    """Return one visibility update per textbox, showing only the first ``k``.

    ``None`` is treated as zero, which hides every textbox.
    """
    shown = 0 if k is None else k
    return [gr.update(visible=(slot < shown)) for slot in range(max_textboxes)]
def clear_textboxes():
    """Return one update per textbox that resets its value to the empty string."""
    resets = [gr.update(value='') for _slot in range(max_textboxes)]
    return resets
# Gradio UI: inputs in the left column, the prediction report in the right one.
# NOTE(review): the nesting below was reconstructed from a flattened listing —
# confirm it against the deployed layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # The slider selects how many tweet textboxes are visible.
            s = gr.Slider(1, max_textboxes, value=1, step=1, label="How many tweets do you want to enter:")
            # All textboxes are created up front; only the first starts visible.
            textboxes = [gr.Textbox(label=f"Tweet {i + 1}", visible=(i == 0)) for i in range(max_textboxes)]
            s.change(fn=update_textboxes, inputs=s, outputs=textboxes)
            btn = gr.Button("Predict")
            btn_clear = gr.Button("Clear")
        with gr.Column(scale=1):
            output = gr.Textbox(label="Profile of User")
    # Every textbox value is passed to predict() as a separate positional argument.
    btn.click(fn=predict, inputs=textboxes, outputs=output)
    btn_clear.click(fn=clear_textboxes, outputs=textboxes)
# Start the web server for the app.
demo.launch()
|