File size: 4,368 Bytes
991f07c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import os
import json
import tempfile
import sys
import datetime
import re
import string
sys.path.append('mtool')
import torch
from model.model import Model
from data.dataset import Dataset
from config.params import Params
from utility.initialize import initialize
from data.batch import Batch
from mtool.main import main as mtool_main
from tqdm import tqdm
class PredictionModel(torch.nn.Module):
    """End-to-end inference wrapper around a trained structured-sentiment model.

    Rebuilds the Dataset/Model pair described by a saved checkpoint and
    exposes ``predict``, which turns raw input texts into MRP graph
    predictions and renders them back to text via mtool.
    """

    def __init__(self, checkpoint_path=os.path.join('models', 'checkpoint.bin'), default_mrp_path=os.path.join('models', 'default.mrp'), verbose=False):
        """Load the checkpoint and restore the model it describes.

        Args:
            checkpoint_path: path to the trained checkpoint file.
            default_mrp_path: placeholder .mrp file assigned to the config's
                training/validation/test data paths (the Dataset requires
                them even for pure inference).
            verbose: if True, show a tqdm progress bar during batched inference.
        """
        super().__init__()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # BUG FIX: the original hard-coded './models/checkpoint.bin' here and
        # silently ignored the `checkpoint_path` argument.
        self.checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.verbose = verbose
        self.args = Params().load_state_dict(self.checkpoint['params'])
        self.args.log_wandb = False
        # Point all data paths at the bundled placeholder .mrp file; nothing
        # is trained here, the paths only satisfy Dataset's constructor.
        self.args.training_data = default_mrp_path
        self.args.validation_data = default_mrp_path
        self.args.test_data = default_mrp_path
        self.args.only_train = False
        self.args.encoder = os.path.join('models', 'encoder')
        initialize(self.args, init_wandb=False)
        self.dataset = Dataset(self.args, verbose=False)
        self.model = Model(self.dataset, self.args).to(self.device)
        # strict=False: tolerate key mismatches between checkpoint and model.
        self.model.load_state_dict(self.checkpoint["model"], strict=False)
        self.model.eval()

    def _mrp_to_text(self, mrp_list, graph_mode='labeled-edge'):
        """Render MRP graph dicts back to text via mtool.

        Args:
            mrp_list: iterable of JSON-serializable MRP graph dicts.
            graph_mode: 'labeled-edge' or 'node-centric'.

        Returns:
            The JSON object mtool writes to its output file.

        Raises:
            Exception: if `graph_mode` is not recognized.
        """
        framework = 'norec'
        # Create an empty temp file for mtool's output; only its name is used.
        with tempfile.NamedTemporaryFile(delete=False, mode='w') as output_text_file:
            output_text_filename = output_text_file.name
        with tempfile.NamedTemporaryFile(delete=False, mode='w') as mrp_file:
            mrp_file.write('\n'.join(json.dumps(entry) for entry in mrp_list))
            mrp_filename = mrp_file.name
        # Validate the mode up front, then build the argument vector once
        # instead of duplicating the whole mtool invocation per mode.
        if graph_mode not in ('labeled-edge', 'node-centric'):
            raise Exception(f'Unknown graph mode: {graph_mode}')
        mtool_args = [
            '--strings',
            '--ids',
            '--read', 'mrp',
            '--write', framework,
            mrp_filename, output_text_filename
        ]
        if graph_mode == 'node-centric':
            mtool_args.insert(0, '--node_centric')
        try:
            mtool_main(mtool_args)
            with open(output_text_filename) as f:
                texts = json.load(f)
        finally:
            # Remove the temp files even if mtool or JSON parsing fails
            # (the original leaked both files on error).
            os.unlink(output_text_filename)
            os.unlink(mrp_filename)
        return texts

    def clean_texts(self, texts):
        """Surround every punctuation character with spaces and collapse
        runs of spaces, matching the tokenization the model expects."""
        punctuation = ''.join([f'\\{s}' for s in string.punctuation])
        texts = [re.sub(f'([{punctuation}])', ' \\1 ', t) for t in texts]
        texts = [re.sub(r' +', ' ', t) for t in texts]
        return texts

    def _predict_to_mrp(self, texts, graph_mode='labeled-edge'):
        """Run the model over `texts` and assemble MRP-style result dicts.

        Args:
            texts: list of raw input strings.
            graph_mode: accepted for interface symmetry with `_mrp_to_text`;
                not consulted here (the output format is fixed by the model).

        Returns:
            dict mapping sentence id (stringified index) to an MRP dict
            ('input', 'id', 'time', 'framework', 'language', 'nodes',
            'edges', 'tops'), updated in place with model predictions.
        """
        texts = self.clean_texts(texts)
        framework, language = self.args.framework, self.args.language
        data = self.dataset.load_sentences(texts, self.args)
        date_str = datetime.datetime.now().date().isoformat()
        # Build each sentence's skeleton in a single pass (the original
        # created the dicts and then mutated them in a second loop).
        res_sentences = {
            str(i): {
                'input': sentence,
                'id': str(i),
                'time': date_str,
                'framework': framework,
                'language': language,
                'nodes': [],
                'edges': [],
                'tops': [],
            }
            for i, sentence in enumerate(texts)
        }
        batches = tqdm(data) if self.verbose else data
        for batch in batches:
            with torch.no_grad():
                predictions = self.model(Batch.to(batch, self.device), inference=True)
            for prediction in predictions:
                # Merge every predicted field into the matching sentence dict,
                # keyed by the id the model echoes back.
                for key, value in prediction.items():
                    res_sentences[prediction['id']][key] = value
        return res_sentences

    def predict(self, text_list, graph_mode='labeled-edge', language='no'):
        """Predict sentiment graphs for `text_list` and render them to text.

        Args:
            text_list: list of raw input strings.
            graph_mode: 'labeled-edge' or 'node-centric'.
            language: kept for backward compatibility; currently unused —
                the language comes from the checkpoint's params.
        """
        mrp_predictions = self._predict_to_mrp(text_list, graph_mode)
        return self._mrp_to_text(mrp_predictions.values(), graph_mode)

    def forward(self, x):
        """nn.Module entry point; delegates to `predict` with default modes."""
        return self.predict(x)
|