# Author: Larisa Kolesnichenko
# Make processing of punctuation consistent with train data: frame each symbol with spaces
# (commit 099a2f3)
import os
import json
import tempfile
import sys
import datetime
import re
import string
sys.path.append('mtool')
import torch
from model.model import Model
from data.dataset import Dataset
from config.params import Params
from utility.initialize import initialize
from data.batch import Batch
from mtool.main import main as mtool_main
from tqdm import tqdm
class PredictionModel:
    """End-to-end wrapper around a trained graph-based sentiment model.

    Loads a serialized checkpoint together with its training-time parameters,
    rebuilds the model/dataset pair, and exposes ``predict``, which turns raw
    input strings into per-sentence graph predictions serialized through mtool
    (framework 'norec').
    """

    def __init__(self, checkpoint_path=os.path.join('models', 'checkpoint.bin'),
                 default_mrp_path=os.path.join('models', 'default.mrp'),
                 verbose=False):
        """Restore the model from a checkpoint and put it in eval mode.

        checkpoint_path: torch checkpoint containing 'params' and 'model'.
        default_mrp_path: placeholder .mrp file wired into the train/val/test
            slots that the Dataset constructor requires even at inference time.
        verbose: when True, wrap prediction batches in a tqdm progress bar.
        """
        self.verbose = verbose
        # Bug fix: the caller-supplied checkpoint_path was previously ignored
        # in favour of a hard-coded './models/checkpoint.bin'.
        self.checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.args = Params().load_state_dict(self.checkpoint['params'])
        self.args.log_wandb = False
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # All three data slots must point at a valid .mrp file for Dataset to
        # build its vocabularies; the default placeholder serves that purpose.
        self.args.training_data = default_mrp_path
        self.args.validation_data = default_mrp_path
        self.args.test_data = default_mrp_path
        self.args.only_train = False
        self.args.encoder = os.path.join('models', 'encoder')
        initialize(self.args, init_wandb=False)
        self.dataset = Dataset(self.args, verbose=False)
        self.model = Model(self.dataset, self.args).to(self.device)
        self.model.load_state_dict(self.checkpoint["model"])
        self.model.eval()

    def _mrp_to_text(self, mrp_list, graph_mode='labeled-edge'):
        """Convert MRP graph dicts to textual output via mtool.

        mrp_list: iterable of JSON-serializable graph dicts (one per sentence).
        graph_mode: 'labeled-edge' or 'node-centric'; selects mtool's reading.
        Returns whatever JSON structure mtool writes for the 'norec' framework.
        Raises ValueError for an unknown graph_mode.
        """
        framework = 'norec'
        # Validate the mode first so a bad value cannot leak temp files.
        if graph_mode == 'labeled-edge':
            mode_args = []
        elif graph_mode == 'node-centric':
            mode_args = ['--node_centric']
        else:
            raise ValueError(f'Unknown graph mode: {graph_mode}')
        # mtool operates on file paths, so stage input and output as named
        # temp files (delete=False: the paths must outlive the `with` blocks).
        with tempfile.NamedTemporaryFile(delete=False, mode='w') as output_text_file:
            output_text_filename = output_text_file.name
        with tempfile.NamedTemporaryFile(delete=False, mode='w') as mrp_file:
            mrp_file.write('\n'.join(json.dumps(entry) for entry in mrp_list))
            mrp_filename = mrp_file.name
        try:
            mtool_main(mode_args + [
                '--strings',
                '--ids',
                '--read', 'mrp',
                '--write', framework,
                mrp_filename, output_text_filename,
            ])
            with open(output_text_filename) as f:
                texts = json.load(f)
        finally:
            # Clean up even if mtool or json.load raises.
            os.unlink(output_text_filename)
            os.unlink(mrp_filename)
        return texts

    def clean_texts(self, texts):
        """Frame each punctuation symbol with spaces, matching the train data.

        E.g. 'Hello, world!' -> 'Hello , world ! '. Runs of spaces produced by
        the substitution are collapsed back to a single space.
        """
        # re.escape builds the same safely-escaped character class that the
        # previous per-character '\\{s}' loop produced by hand.
        punct_class = re.escape(string.punctuation)
        texts = [re.sub(f'([{punct_class}])', ' \\1 ', t) for t in texts]
        texts = [re.sub(r' +', ' ', t) for t in texts]
        return texts

    def _predict_to_mrp(self, texts, graph_mode='labeled-edge'):
        """Run the model over cleaned texts and assemble MRP-style dicts.

        Returns a dict keyed by stringified sentence index; each value carries
        the input sentence, metadata, and the model's predicted graph fields.
        NOTE(review): graph_mode is currently unused here — the graph style is
        fixed by the loaded checkpoint; the parameter is kept for API symmetry.
        """
        texts = self.clean_texts(texts)
        framework, language = self.args.framework, self.args.language
        data = self.dataset.load_sentences(texts, self.args)
        date_str = datetime.datetime.now().date().isoformat()
        # Pre-populate one skeleton entry per sentence; predictions fill in
        # (or overwrite) nodes/edges/tops below.
        res_sentences = {
            str(i): {
                'input': sentence,
                'id': str(i),
                'time': date_str,
                'framework': framework,
                'language': language,
                'nodes': [],
                'edges': [],
                'tops': [],
            }
            for i, sentence in enumerate(texts)
        }
        batches = tqdm(data) if self.verbose else data
        for batch in batches:
            with torch.no_grad():
                predictions = self.model(Batch.to(batch, self.device), inference=True)
            for prediction in predictions:
                # Merge every predicted field into the matching skeleton entry.
                res_sentences[prediction['id']].update(prediction)
        return res_sentences

    def predict(self, text_list, graph_mode='labeled-edge', language='no'):
        """Predict sentiment graphs for a list of raw texts.

        text_list: raw input sentences.
        graph_mode: forwarded to prediction and mtool conversion.
        NOTE(review): `language` is accepted but unused — the language comes
        from the checkpoint's stored args; kept for backward compatibility.
        """
        mrp_predictions = self._predict_to_mrp(text_list, graph_mode)
        predictions = self._mrp_to_text(mrp_predictions.values(), graph_mode)
        return predictions