import os
import json
import tempfile
import sys
import datetime
import re
import string

sys.path.append('mtool')

import torch

from model.model import Model
from data.dataset import Dataset
from config.params import Params
from utility.initialize import initialize
from data.batch import Batch
from mtool.main import main as mtool_main
from tqdm import tqdm


class PredictionModel:
    """Inference wrapper that turns raw sentences into ``norec``-format output.

    The pipeline is:

    1. Load a trained checkpoint (parameters + model weights).
    2. Run the model over the input sentences to obtain MRP graphs.
    3. Convert those graphs to the ``norec`` text format with the bundled
       ``mtool`` converter.
    """

    # Matches any single ASCII punctuation character. Every character is
    # backslash-escaped so it is taken literally inside the character class.
    # Compiled once instead of on every ``clean_texts`` call.
    _PUNCTUATION_RE = re.compile(
        '([' + ''.join(f'\\{symbol}' for symbol in string.punctuation) + '])'
    )
    # One or more consecutive spaces (collapsed to a single space).
    _SPACES_RE = re.compile(r' +')

    def __init__(self, checkpoint_path=os.path.join('models', 'checkpoint.bin'),
                 default_mrp_path=os.path.join('models', 'default.mrp'),
                 verbose=False, encoder_path=os.path.join('models', 'encoder')):
        """Load the checkpoint and build the model in evaluation mode.

        Args:
            checkpoint_path: File holding a dict with ``'params'`` and
                ``'model'`` entries, as written by ``torch.save``.
            default_mrp_path: MRP file assigned to the training, validation
                and test data paths. Presumably ``Dataset`` needs these paths
                to build its fields/vocabulary even at inference time. TODO
                confirm.
            verbose: If true, show a ``tqdm`` progress bar during prediction.
            encoder_path: Directory assigned to ``args.encoder``.
        """
        self.verbose = verbose
        # Load onto CPU first so a checkpoint saved on GPU also opens on a
        # CPU-only host; the model is moved to ``self.device`` below.
        # NOTE(review): torch.load unpickles the file, so ``checkpoint_path``
        # must come from a trusted source.
        self.checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.args = Params().load_state_dict(self.checkpoint['params'])
        self.args.log_wandb = False
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.args.training_data = default_mrp_path
        self.args.validation_data = default_mrp_path
        self.args.test_data = default_mrp_path
        self.args.only_train = False
        self.args.encoder = encoder_path

        initialize(self.args, init_wandb=False)
        self.dataset = Dataset(self.args, verbose=False)
        self.model = Model(self.dataset, self.args).to(self.device)
        self.model.load_state_dict(self.checkpoint["model"])
        self.model.eval()

    def _mrp_to_text(self, mrp_list, graph_mode='labeled-edge'):
        """Convert MRP graph dicts to ``norec`` output via ``mtool``.

        Args:
            mrp_list: Iterable of JSON-serialisable MRP graph dicts.
            graph_mode: ``'labeled-edge'`` or ``'node-centric'``. The latter
                passes ``--node_centric`` to ``mtool``.

        Returns:
            The JSON document that ``mtool`` writes, parsed with ``json.load``.

        Raises:
            ValueError: If ``graph_mode`` is not recognised. This is a
                subclass of ``Exception``, which was raised previously.
        """
        framework = 'norec'
        # Validate before touching the filesystem so a bad mode cannot leave
        # temporary files behind.
        if graph_mode == 'labeled-edge':
            mode_flags = []
        elif graph_mode == 'node-centric':
            mode_flags = ['--node_centric']
        else:
            raise ValueError(f'Unknown graph mode: {graph_mode}')

        output_text_filename = None
        mrp_filename = None
        try:
            # Only reserve a path for mtool to write into; the file is closed
            # right away so the converter can open it (also on Windows).
            with tempfile.NamedTemporaryFile(delete=False, mode='w',
                                             encoding='utf-8') as output_text_file:
                output_text_filename = output_text_file.name

            # One JSON object per line (JSON Lines), the layout mtool reads
            # with ``--read mrp``.
            with tempfile.NamedTemporaryFile(delete=False, mode='w',
                                             encoding='utf-8') as mrp_file:
                mrp_filename = mrp_file.name
                mrp_file.write('\n'.join(json.dumps(entry) for entry in mrp_list))

            mtool_main([
                *mode_flags, '--strings', '--ids', '--read', 'mrp', '--write', framework,
                mrp_filename, output_text_filename
            ])

            with open(output_text_filename, encoding='utf-8') as f:
                texts = json.load(f)
        finally:
            # Best-effort cleanup that also runs when mtool or json.load fails.
            # A failed unlink must not mask the original exception.
            for path in (output_text_filename, mrp_filename):
                if path is not None:
                    try:
                        os.unlink(path)
                    except OSError:
                        pass
        return texts

    def clean_texts(self, texts):
        """Space-separate ASCII punctuation and collapse repeated spaces.

        Every character in ``string.punctuation`` gets a space on each side.
        Any resulting run of spaces is then reduced to a single space.
        Leading and trailing spaces are not stripped.

        Args:
            texts: Iterable of strings.

        Returns:
            A new list of cleaned strings, in input order.
        """
        spaced = [self._PUNCTUATION_RE.sub(' \\1 ', t) for t in texts]
        return [self._SPACES_RE.sub(' ', t) for t in spaced]

    def _predict_to_mrp(self, texts, graph_mode='labeled-edge'):
        """Run the model over ``texts`` and collect one MRP dict per sentence.

        Args:
            texts: Iterable of raw sentence strings. They are cleaned with
                ``clean_texts`` first.
            graph_mode: Accepted for interface symmetry with ``predict``; not
                used here.

        Returns:
            Dict mapping the stringified input index (``"0"``, ``"1"``, …) to
            an MRP dict. Each dict is pre-filled with ``input``, ``id``,
            ``time``, ``framework``, ``language`` and empty ``nodes``,
            ``edges`` and ``tops``. It is then updated with every key of the
            matching model prediction.
        """
        texts = self.clean_texts(texts)
        framework, language = self.args.framework, self.args.language
        data = self.dataset.load_sentences(texts, self.args)

        date_str = datetime.datetime.now().date().isoformat()
        res_sentences = {}
        for i, sentence in enumerate(texts):
            key = f"{i}"
            # Key order matches the original construction, so json.dumps
            # output downstream is unchanged.
            res_sentences[key] = {
                'input': sentence,
                'id': key,
                'time': date_str,
                'framework': framework,
                'language': language,
                'nodes': [],
                'edges': [],
                'tops': [],
            }

        batches = tqdm(data) if self.verbose else data
        # Inference only: disable autograd once for the whole loop.
        with torch.no_grad():
            for batch in batches:
                predictions = self.model(Batch.to(batch, self.device), inference=True)
                for prediction in predictions:
                    # Assumes prediction['id'] is one of the string keys
                    # built above. An unknown id raises KeyError.
                    res_sentences[prediction['id']].update(prediction)
        return res_sentences

    def predict(self, text_list, graph_mode='labeled-edge', language='no'):
        """Predict structures for ``text_list`` and return ``norec`` output.

        Args:
            text_list: Iterable of raw sentence strings.
            graph_mode: ``'labeled-edge'`` or ``'node-centric'``; forwarded to
                both prediction and conversion.
            language: Currently unused. The language is taken from the
                checkpoint's params. Kept for backward compatibility.

        Returns:
            Whatever ``_mrp_to_text`` parses from mtool's output.
        """
        mrp_predictions = self._predict_to_mrp(text_list, graph_mode)
        return self._mrp_to_text(list(mrp_predictions.values()), graph_mode)