arkmartov commited on
Commit
5ca6967
1 Parent(s): 5157910

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -6,16 +6,15 @@ def Model():
6
  from transformers import DebertaTokenizer, DebertaForSequenceClassification
7
  tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
8
  model = DebertaForSequenceClassification.from_pretrained("microsoft/deberta-base", num_labels=8)
9
- bn_state_dict = torch.load('model_weights.pt')
10
  model.load_state_dict(bn_state_dict)
11
  return model, tokenizer
12
 
13
  def Predict(model, tokenizer, text):
14
  res = tokenizer(s, padding=True, truncation=True, return_tensors="pt", max_length=512)
15
- res.to("cuda:0")
16
  res = model(**res)
17
  logits = res.logits.softmax(dim=1)
18
- logits = logits.cpu().detach().numpy()[0]
19
  return logits
20
 
21
  def Print(logits, dictionary):
 
6
  from transformers import DebertaTokenizer, DebertaForSequenceClassification
7
  tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
8
  model = DebertaForSequenceClassification.from_pretrained("microsoft/deberta-base", num_labels=8)
9
+ bn_state_dict = torch.load('model_weights.pt', map_location=torch.device('cpu'))
10
  model.load_state_dict(bn_state_dict)
11
  return model, tokenizer
12
 
13
  def Predict(model, tokenizer, text):
14
  res = tokenizer(s, padding=True, truncation=True, return_tensors="pt", max_length=512)
 
15
  res = model(**res)
16
  logits = res.logits.softmax(dim=1)
17
+ logits = logits.numpy()[0]#.cpu().detach().numpy()[0]
18
  return logits
19
 
20
  def Print(logits, dictionary):