CatoEr commited on
Commit
f7d5b05
1 Parent(s): 435431a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -1
app.py CHANGED
@@ -37,8 +37,43 @@ model_race = RaceClassifier(n_classes=4)
37
  model_race.to(device)
38
  model_race.load_state_dict(torch.load('best_model_race.pt', map_location=torch.device('cpu')))
39
 
 
 
 
 
40
 
41
- max_textboxes = 10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
 
44
  def update_textboxes(k):
 
37
  model_race.to(device)
38
  model_race.load_state_dict(torch.load('best_model_race.pt', map_location=torch.device('cpu')))
39
 
40
+ def predict(*text):
41
+ tweets = [tweet for tweet in text if tweet]
42
+ print(tweets)
43
+ sentences = tweets
44
 
45
+ tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base", normalization=True)
46
+
47
+ encoded_sentences = tokenizer(
48
+ sentences,
49
+ padding=True,
50
+ truncation=True,
51
+ return_tensors='pt',
52
+ max_length=128,
53
+ )
54
+
55
+ input_ids = encoded_sentences["input_ids"].to(device)
56
+ attention_mask = encoded_sentences["attention_mask"].to(device)
57
+
58
+ model_race.eval()
59
+ with torch.no_grad():
60
+ outputs = model_race(input_ids, attention_mask)
61
+ probs = torch.nn.functional.softmax(outputs, dim=1)
62
+ predictions = torch.argmax(outputs, dim=1)
63
+ predictions = predictions.cpu().numpy()
64
+
65
+ output_string = "RACE\n Probabilities:\n"
66
+ for i, prob in enumerate(probs[0]):
67
+ print(f"{labels[i]} = {round(prob.item() * 100, 2)}%")
68
+ output_string += f"{labels[i]} = {round(prob.item() * 100, 2)}%\n"
69
+
70
+ print(labels[predictions[0]])
71
+ output_string += f"Predicted as: {labels[predictions[0]]}"
72
+
73
+ return output_string
74
+
75
+
76
+ max_textboxes = 20
77
 
78
 
79
  def update_textboxes(k):