Spaces:
Runtime error
Runtime error
File size: 1,039 Bytes
154ca7b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import os, sys
# Directory containing this file; used to put the project root on sys.path.
myPath = os.path.dirname(os.path.abspath(__file__))
# Prepend the directory two levels above this file so that the `modules`
# package imported below can be resolved regardless of the working directory.
sys.path.insert(0, myPath + '/../../')
# ==========
import torch
from modules.prediction.model_loader import load_checkpoint
from modules.prediction.ERCBCM import ERCBCM
from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID
# Folder holding the trained checkpoint (`model.pt`). NOTE(review): this is
# relative to the current working directory, not to this file — confirm.
erc_root_folder = './model'

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# ==========
# Module-level model instance shared by prepare() and predict(); weights are
# only loaded when prepare() is called.
model_for_evaluate = ERCBCM()
model_for_evaluate = model_for_evaluate.to(device)
def prepare():
    """Load the trained checkpoint into ``model_for_evaluate`` for inference.

    Reads ``<erc_root_folder>/model.pt`` through ``load_checkpoint`` (onto
    ``device``) and then switches the model to evaluation mode.
    """
    checkpoint_path = os.path.join(erc_root_folder, 'model.pt')
    load_checkpoint(checkpoint_path, model_for_evaluate, device)
    # No eval() call is visible anywhere else; without it, any dropout or
    # batch-norm layers in ERCBCM (assumed to be a torch.nn.Module, since it
    # is used with .to(device)) would stay in training mode and make
    # predictions non-deterministic. eval() only sets the training flag, so
    # it is harmless if load_checkpoint already did this.
    model_for_evaluate.eval()
def predict(sentence, name):
    """Classify how ``name`` is used in ``sentence``.

    The sentence is normalized with ``normalize_v2(sentence, name)``, encoded
    with the module-level tokenizer, right-padded with ``PAD_TOKEN_ID`` to 128
    ids, and fed to ``model_for_evaluate`` together with a placeholder label.

    Args:
        sentence: Input text to classify.
        name: The name whose usage in ``sentence`` is being classified.

    Returns:
        ``'CALLING'`` if the model's highest-scoring class index is 1,
        otherwise ``'MENTIONING'``.
    """
    # The model's forward takes a label alongside the text. A dummy label of
    # 0 is supplied, and the model's first output is discarded below
    # (presumably a loss — TODO confirm).
    label = torch.tensor([0], dtype=torch.long, device=device)

    token_ids = tokenizer.encode(normalize_v2(sentence, name))
    # Right-pad to 128 ids. NOTE(review): encodings already longer than 128
    # get no padding and are passed through untruncated; confirm the model
    # accepts such lengths.
    token_ids += [PAD_TOKEN_ID] * (128 - len(token_ids))
    text = torch.tensor([token_ids], dtype=torch.long, device=device)

    # Inference only: do not build an autograd graph on every call.
    with torch.no_grad():
        _, output = model_for_evaluate(text, label)

    pred = torch.argmax(output, 1)[0].item()
    return 'CALLING' if pred == 1 else 'MENTIONING'