|
import os |
|
import json |
|
import tempfile |
|
import sys |
|
import datetime |
|
import re |
|
import string |
|
sys.path.append('mtool') |
|
|
|
import torch |
|
|
|
from model.model import Model |
|
from data.dataset import Dataset |
|
from config.params import Params |
|
from utility.initialize import initialize |
|
from data.batch import Batch |
|
from mtool.main import main as mtool_main |
|
|
|
|
|
from tqdm import tqdm |
|
|
|
class PredictionModel(torch.nn.Module):
    """Inference wrapper around a trained structured-sentiment graph parser.

    Loads a saved checkpoint and its stored hyper-parameters, rebuilds the
    dataset/model pair, and exposes ``predict``/``forward`` that map raw
    sentences to sentiment annotations.  The model's graph output (MRP format)
    is converted back to text spans with ``mtool``.
    """

    def __init__(self, checkpoint_path=os.path.join('models', 'checkpoint.bin'), default_mrp_path=os.path.join('models', 'default.mrp'), verbose=False):
        """Load checkpoint, rebuild dataset and model, switch to eval mode.

        Args:
            checkpoint_path: path to the ``torch.save``d checkpoint containing
                ``'params'`` (training args) and ``'model'`` (state dict).
            default_mrp_path: placeholder .mrp file used for the train/val/test
                paths the stored args expect (the data itself is not used).
            verbose: when True, show a tqdm progress bar during prediction.
        """
        super().__init__()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # BUG FIX: the original ignored `checkpoint_path` and always loaded
        # './models/checkpoint.bin'; honor the parameter instead.
        self.checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.verbose = verbose
        self.args = Params().load_state_dict(self.checkpoint['params'])
        self.args.log_wandb = False

        # Re-point the stored training-time paths at bundled defaults so
        # Dataset/initialize do not look for the original training files.
        self.args.training_data = default_mrp_path
        self.args.validation_data = default_mrp_path
        self.args.test_data = default_mrp_path
        self.args.only_train = False
        self.args.encoder = os.path.join('models', 'encoder')
        initialize(self.args, init_wandb=False)
        self.dataset = Dataset(self.args, verbose=False)
        self.model = Model(self.dataset, self.args).to(self.device)
        # strict=False: tolerate keys added/removed since the checkpoint was saved.
        self.model.load_state_dict(self.checkpoint["model"], strict=False)
        self.model.eval()

    def _mrp_to_text(self, mrp_list, graph_mode='labeled-edge'):
        """Convert MRP graph dicts to textual predictions via mtool.

        Writes the graphs to a temp file, runs mtool's 'norec' writer, and
        returns the parsed JSON it produces.

        Args:
            mrp_list: iterable of MRP graph dicts (one per sentence).
            graph_mode: 'labeled-edge' or 'node-centric'.

        Raises:
            Exception: if ``graph_mode`` is not one of the supported modes.
        """
        framework = 'norec'
        # mtool works on files, so round-trip through two named temp files.
        with tempfile.NamedTemporaryFile(delete=False, mode='w') as output_text_file:
            output_text_filename = output_text_file.name

        with tempfile.NamedTemporaryFile(delete=False, mode='w') as mrp_file:
            mrp_file.write('\n'.join(json.dumps(entry) for entry in mrp_list))
            mrp_filename = mrp_file.name

        try:
            if graph_mode == 'labeled-edge':
                mode_flags = []
            elif graph_mode == 'node-centric':
                mode_flags = ['--node_centric']
            else:
                raise Exception(f'Unknown graph mode: {graph_mode}')
            mtool_main(mode_flags + [
                '--strings',
                '--ids',
                '--read', 'mrp',
                '--write', framework,
                mrp_filename, output_text_filename
            ])
            with open(output_text_filename) as f:
                texts = json.load(f)
        finally:
            # The temp files were created with delete=False; remove them even
            # when mtool or JSON parsing fails (the original leaked them).
            os.unlink(output_text_filename)
            os.unlink(mrp_filename)

        return texts

    def clean_texts(self, texts):
        """Pad every punctuation character with spaces, then collapse runs
        of spaces to a single space.  Returns a new list of strings."""
        punctuation = ''.join([f'\\{s}' for s in string.punctuation])
        texts = [re.sub(f'([{punctuation}])', ' \\1 ', t) for t in texts]
        texts = [re.sub(r' +', ' ', t) for t in texts]
        return texts

    def _predict_to_mrp(self, texts, graph_mode='labeled-edge'):
        """Run the model on raw sentences and assemble MRP-formatted graphs.

        Args:
            texts: list of raw input sentences.
            graph_mode: accepted for interface symmetry with ``_mrp_to_text``;
                the model's output format is fixed by the checkpoint.

        Returns:
            Dict mapping sentence index (as str) to its MRP graph dict.
        """
        texts = self.clean_texts(texts)
        framework, language = self.args.framework, self.args.language
        data = self.dataset.load_sentences(texts, self.args)
        date_str = datetime.datetime.now().date().isoformat()
        # Pre-populate one record per sentence; the str(index) doubles as the
        # graph id that the model echoes back in each prediction.
        res_sentences = {
            str(i): {
                'input': sentence,
                'id': str(i),
                'time': date_str,
                'framework': framework,
                'language': language,
                'nodes': [],
                'edges': [],
                'tops': [],
            }
            for i, sentence in enumerate(texts)
        }
        batches = tqdm(data) if self.verbose else data
        for batch in batches:
            with torch.no_grad():
                predictions = self.model(Batch.to(batch, self.device), inference=True)
            for prediction in predictions:
                # Merge the model's fields into the matching sentence record.
                res_sentences[prediction['id']].update(prediction)
        return res_sentences

    def predict(self, text_list, graph_mode='labeled-edge', language='no'):
        """Predict sentiment annotations for a list of raw sentences.

        Args:
            text_list: list of input sentences.
            graph_mode: 'labeled-edge' or 'node-centric'.
            language: accepted for API compatibility only — the language
                actually used comes from the checkpoint's stored args.

        Returns:
            mtool's textual predictions, one entry per input sentence.
        """
        mrp_predictions = self._predict_to_mrp(text_list, graph_mode)
        predictions = self._mrp_to_text(mrp_predictions.values(), graph_mode)
        return predictions

    def forward(self, x):
        """nn.Module entry point; delegates to :meth:`predict` with defaults."""
        return self.predict(x)
|
|
|
|