Kithogue commited on
Commit
bde5804
Β·
1 Parent(s): a3be615

Add app file

Browse files
Files changed (1) hide show
  1. app.py +31 -0
app.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import torch
6
+ import tqdm
7
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
+ from baseline_BERT import id2label
9
+ import gradio as gr
10
+
11
+ model_ckpt = "Kithogue/2-lvl-events-multilingual"
12
+ tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
13
+
14
+
15
+ def get_inference(sample):
16
+ model_hf = AutoModelForSequenceClassification.from_pretrained(model_ckpt)
17
+ encoding = tokenizer(sample, return_tensors="pt")
18
+ encoding = {k: v.to('cuda') for k, v in encoding.items()}
19
+ outputs = model_hf(**encoding)
20
+ logits = outputs.logits
21
+ # apply sigmoid + threshold
22
+ sigmoid = torch.nn.Sigmoid()
23
+ probs = sigmoid(logits.squeeze().cpu())
24
+ predictions = np.zeros(probs.shape)
25
+ predictions[np.where(probs >= 0.4)] = 1
26
+ # turn predicted id's into actual label names
27
+ predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
28
+ return predicted_labels
29
+
30
+
31
+ gr.Interface(fn=get_inference, inputs=["text"], outputs=["text"]).launch(share=True)