Update app.py
Browse files
app.py
CHANGED
@@ -6,16 +6,15 @@ def Model():
|
|
6 |
from transformers import DebertaTokenizer, DebertaForSequenceClassification
|
7 |
tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
|
8 |
model = DebertaForSequenceClassification.from_pretrained("microsoft/deberta-base", num_labels=8)
|
9 |
-
bn_state_dict = torch.load('model_weights.pt')
|
10 |
model.load_state_dict(bn_state_dict)
|
11 |
return model, tokenizer
|
12 |
|
13 |
def Predict(model, tokenizer, text):
|
14 |
res = tokenizer(s, padding=True, truncation=True, return_tensors="pt", max_length=512)
|
15 |
-
res.to("cuda:0")
|
16 |
res = model(**res)
|
17 |
logits = res.logits.softmax(dim=1)
|
18 |
-
logits = logits.cpu().detach().numpy()[0]
|
19 |
return logits
|
20 |
|
21 |
def Print(logits, dictionary):
|
|
|
6 |
from transformers import DebertaTokenizer, DebertaForSequenceClassification
|
7 |
tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
|
8 |
model = DebertaForSequenceClassification.from_pretrained("microsoft/deberta-base", num_labels=8)
|
9 |
+
bn_state_dict = torch.load('model_weights.pt', map_location=torch.device('cpu'))
|
10 |
model.load_state_dict(bn_state_dict)
|
11 |
return model, tokenizer
|
12 |
|
13 |
def Predict(model, tokenizer, text):
|
14 |
res = tokenizer(s, padding=True, truncation=True, return_tensors="pt", max_length=512)
|
|
|
15 |
res = model(**res)
|
16 |
logits = res.logits.softmax(dim=1)
|
17 |
+
logits = logits.numpy()[0]#.cpu().detach().numpy()[0]
|
18 |
return logits
|
19 |
|
20 |
def Print(logits, dictionary):
|