Spaces:
Sleeping
Sleeping
File size: 2,183 Bytes
bea74aa b5ac54b 6e257b4 3019ade bea74aa 3019ade bea74aa 3bf30df bea74aa 435431a bea74aa 435431a bea74aa 435431a bea74aa 435431a bea74aa 435431a b5ac54b bea74aa 435431a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer
import gradio as gr
# Torch device used for the model in this file; fixed to CPU.
device = torch.device("cpu")
class RaceClassifier(nn.Module):
    """Classifier made of a pretrained "vinai/bertweet-base" encoder and a linear head.

    Args:
        n_classes: Number of output classes, i.e. the width of the final
            linear layer.
    """

    def __init__(self, n_classes):
        super().__init__()
        # The attribute names (bert / drop / out) become the state_dict key
        # prefixes, so they must stay unchanged for saved weights to load.
        self.bert = AutoModel.from_pretrained("vinai/bertweet-base")
        self.drop = nn.Dropout(p=0.3)  # dropout probability; can be changed in future
        encoder_width = self.bert.config.hidden_size
        # Linear output layer: encoder hidden size -> number of classes.
        self.out = nn.Linear(encoder_width, n_classes)

    def forward(self, input_ids, attention_mask):
        """Run the encoder and return the linear head's output for each input.

        Args:
            input_ids: Token ids, passed to the encoder unchanged.
            attention_mask: Attention mask, passed to the encoder unchanged.

        Returns:
            The output of the final linear layer (one score per class).
        """
        encoder_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Element 0 of the encoder output is the last hidden state; keep only
        # index 0 along its second axis.
        # assumes layout (batch, sequence, hidden) - TODO confirm
        first_position = encoder_outputs[0][:, 0]
        return self.out(self.drop(first_position))
# Display name for each integer class id (0-3).
# NOTE(review): presumably indexed by the classifier's predicted class; the
# consumer is not visible in this part of the file - confirm.
labels = dict(enumerate(("African American", "Asian", "Latin", "White")))
# Build the 4-class classifier and restore its saved weights.
model_race = RaceClassifier(n_classes=4)
model_race.to(device)
# map_location remaps tensors that were saved on another device onto `device`.
# NOTE(review): torch.load unpickles the checkpoint file; only load
# checkpoints that come from a trusted source.
model_race.load_state_dict(torch.load('best_model_race.pt', map_location=device))
# Modules start in training mode, which would leave the classifier's Dropout
# layer active and make predictions random. Switch to eval mode once here.
model_race.eval()
# Number of tweet textboxes created in the UI; also the slider's upper bound.
max_textboxes = 10
def update_textboxes(k):
    """Return one visibility update per tweet textbox.

    Args:
        k: Number of textboxes to show; ``None`` is treated as 0.

    Returns:
        A list of ``max_textboxes`` ``gr.update`` objects, visible for
        indices below ``k`` and hidden for the rest.
    """
    shown = 0 if k is None else k
    return [gr.update(visible=(slot < shown)) for slot in range(max_textboxes)]
def clear_textboxes():
    """Return ``max_textboxes`` updates, each setting a textbox's value to an empty string."""
    resets = []
    for _ in range(max_textboxes):
        # A fresh update object per textbox, never a shared one.
        resets.append(gr.update(value=''))
    return resets
# Build the Gradio interface: a slider that controls how many tweet textboxes
# are shown, a Predict button, an output textbox, and a Clear button.
# NOTE(review): indentation is missing in this copy of the file, so how the
# statements below nest under the Row/Column contexts cannot be confirmed here.
with gr.Blocks() as demo:
with gr.Row():
with gr.Column(scale=1):
# Slider from 1 to max_textboxes in steps of 1, starting at 1.
s = gr.Slider(1, max_textboxes, value=1, step=1, label="How many tweets do you want to enter:")
# All max_textboxes inputs are created up front; only the first starts visible.
textboxes = [gr.Textbox(label=f"Tweet {i + 1}", visible=(i == 0)) for i in range(max_textboxes)]
# Changing the slider recomputes the visibility of every textbox.
s.change(fn=update_textboxes, inputs=s, outputs=textboxes)
btn = gr.Button("Predict")
with gr.Column(scale=1):
output = gr.Textbox(label="Profile of User")
# NOTE(review): `predict` is not defined in the visible part of this file -
# confirm it exists. All max_textboxes textboxes are wired in as its inputs.
btn.click(fn=predict, inputs=textboxes, outputs=output)
btn_clear = gr.Button("Clear")
# Clear has no inputs; it writes clear_textboxes' updates to every textbox.
btn_clear.click(fn=clear_textboxes, outputs=textboxes)
# Launch the app.
demo.launch()
|