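"""Prediction helper for the ERCBCM model.

prepare() loads a trained checkpoint into a module-level model, and predict()
classifies whether a sentence is CALLING or merely MENTIONING a given name.
"""
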
import os, sys

# Make the project root (two levels up) importable so the `modules` package resolves.
myPath = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, myPath + '/../../')

# ==========

import torch

from modules.prediction.model_loader import load_checkpoint
from modules.prediction.ERCBCM import ERCBCM
from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID

# Location of the trained checkpoint directory (expects model.pt inside).
erc_root_folder = './model'

# Use the GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ==========

model_for_evaluate = ERCBCM().to(device)

def prepare():
    """Load the trained weights into the module-level model; call once before predict()."""
    load_checkpoint(erc_root_folder + '/model.pt', model_for_evaluate, device)
    model_for_evaluate.eval()  # inference mode: disables dropout and other training-only behavior

def predict(sentence, name):
    """Classify whether `sentence` is CALLING or merely MENTIONING `name`."""
    # The forward pass expects a label tensor; a dummy value satisfies the
    # signature, and the first returned value is ignored below.
    label = torch.tensor([0], dtype=torch.long).to(device)
    # Normalize and tokenize the sentence, then truncate/pad to the fixed
    # length of 128 tokens expected by the model.
    text = tokenizer.encode(normalize_v2(sentence, name))
    text = text[:128]
    text += [PAD_TOKEN_ID] * (128 - len(text))
    text = torch.tensor([text], dtype=torch.long).to(device)
    # Inference only: no gradients needed.
    with torch.no_grad():
        _, output = model_for_evaluate(text, label)
    pred = torch.argmax(output, 1).tolist()[0]
    return 'CALLING' if pred == 1 else 'MENTIONING'
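
# ==========

if __name__ == '__main__':
    # Minimal usage sketch (an assumption, not part of the original module):
    # load the checkpoint once, then classify sentences. The example sentences
    # and the name 'Alex' below are hypothetical; each call returns either
    # 'CALLING' or 'MENTIONING'.
    prepare()
    print(predict('Hey Alex, can you come over here?', 'Alex'))
    print(predict('I talked to Alex about it yesterday.', 'Alex'))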