File size: 2,183 Bytes
bea74aa
 
 
b5ac54b
 
6e257b4
 
3019ade
 
bea74aa
3019ade
bea74aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bf30df
bea74aa
 
435431a
bea74aa
 
435431a
 
 
 
 
 
 
 
 
 
bea74aa
 
435431a
 
bea74aa
 
435431a
 
 
 
 
 
 
 
 
bea74aa
435431a
 
 
b5ac54b
 
bea74aa
435431a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer
import gradio as gr


# Torch device used for the model and its weights; this app runs on CPU only.
device = torch.device(type="cpu")


class RaceClassifier(nn.Module):
    """Pretrained transformer encoder followed by a dropout + linear classification head.

    The hidden state at the first sequence position is used as the pooled
    representation of each input; it is passed through dropout and a linear
    layer that produces one unnormalized score (logit) per class.
    """

    def __init__(self, n_classes, *, pretrained_model_name="vinai/bertweet-base", dropout=0.3):
        """Create the encoder and the classification head.

        Args:
            n_classes: Number of output classes (output size of the linear layer).
            pretrained_model_name: Model id or local path handed to
                ``AutoModel.from_pretrained``. Defaults to the BERTweet base
                checkpoint this app was built around.
            dropout: Dropout probability applied to the pooled representation
                before the linear layer (only active in training mode).
        """
        super().__init__()
        self.bert = AutoModel.from_pretrained(pretrained_model_name)
        self.drop = nn.Dropout(p=dropout)
        # Linear head: encoder hidden size -> one logit per class.
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return class logits for a batch of tokenized inputs.

        Args:
            input_ids: Token ids for the batch; presumably shaped
                (batch, seq_len) — TODO confirm against the tokenizer call site.
            attention_mask: Attention mask passed straight to the encoder;
                presumably the same shape as ``input_ids``.

        Returns:
            Tensor of logits whose last dimension is ``n_classes``.
        """
        bert_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Element 0 of the encoder output is the last hidden state; index 0 on
        # dim 1 picks the first position (the leading special token) of each sequence.
        last_hidden_state = bert_output[0]
        pooled_output = last_hidden_state[:, 0]
        return self.out(self.drop(pooled_output))


# Maps classifier output index -> human-readable class name. The order
# presumably matches the label order used when the checkpoint was trained;
# verify before reordering.
labels = dict(enumerate(("African American", "Asian", "Latin", "White")))
# Build the 4-way classifier (one output per entry in `labels`) and load the
# fine-tuned weights from the local checkpoint file.
model_race = RaceClassifier(n_classes=4)
model_race.to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads in this CPU-only app.
# NOTE: torch.load unpickles the file; only load trusted checkpoints
# (torch>=1.13 also accepts weights_only=True to restrict this).
model_race.load_state_dict(torch.load('best_model_race.pt', map_location=device))
# Inference only: eval() disables the Dropout layer, which is active in the
# default training mode and would make predictions for the same tweets vary.
model_race.eval()


# Number of tweet input boxes the UI creates; also the maximum of the
# "how many tweets" slider. update_textboxes shows only the first k of them.
max_textboxes = 10


def update_textboxes(k):
    """Build visibility updates that reveal the first ``k`` tweet boxes.

    Used as the slider's change handler. Returns ``max_textboxes`` Gradio
    updates, in textbox order: visible for positions below ``k``, hidden for
    the rest. A ``None`` slider value is treated as 0 (all hidden).
    """
    shown = 0 if k is None else k
    return [gr.update(visible=bool(idx < shown)) for idx in range(max_textboxes)]


def clear_textboxes():
    """Return one Gradio update per tweet box that resets its text to empty."""
    updates = []
    for _ in range(max_textboxes):
        updates.append(gr.update(value=''))
    return updates


# Build the Gradio UI: a slider choosing how many tweet boxes are visible, the
# tweet boxes, a Predict button, an output box, and a Clear button.
# Component creation order inside the context managers determines the layout.
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: inputs and the Predict button.
        with gr.Column(scale=1):
            s = gr.Slider(1, max_textboxes, value=1, step=1, label="How many tweets do you want to enter:")
            # All max_textboxes boxes are created up front; only the first starts
            # visible, and moving the slider toggles the rest via update_textboxes.
            textboxes = [gr.Textbox(label=f"Tweet {i + 1}", visible=(i == 0)) for i in range(max_textboxes)]
            s.change(fn=update_textboxes, inputs=s, outputs=textboxes)
            btn = gr.Button("Predict")
        # Right column: prediction result.
        with gr.Column(scale=1):
            output = gr.Textbox(label="Profile of User")

    # NOTE(review): `predict` is not defined anywhere in this file, so this line
    # raises NameError when the script runs unless it is supplied elsewhere.
    # It would receive one positional argument per entry in `textboxes`
    # (presumably including hidden boxes) and its return value is shown in
    # `output`. The unused AutoTokenizer import suggests it was meant to
    # tokenize the tweets for model_race. TODO add/confirm its definition.
    btn.click(fn=predict, inputs=textboxes, outputs=output)
    btn_clear = gr.Button("Clear")
    # Clear resets only the tweet boxes; the output box keeps its last value.
    btn_clear.click(fn=clear_textboxes, outputs=textboxes)

# Start the Gradio server. This runs at module level (no __main__ guard), so
# importing this module also launches the app.
demo.launch()