File size: 4,368 Bytes
991f07c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import json
import tempfile
import sys
import datetime
import re
import string
sys.path.append('mtool')

import torch

from model.model import Model
from data.dataset import Dataset
from config.params import Params
from utility.initialize import initialize
from data.batch import Batch
from mtool.main import main as mtool_main


from tqdm import tqdm

class PredictionModel(torch.nn.Module):
    """Inference wrapper around a trained graph-parsing model.

    Loads a checkpoint together with its saved hyper-parameters, rebuilds the
    model/dataset pair, and exposes ``predict``, which turns raw sentences into
    structured sentiment-graph output via mtool's ``norec`` writer.
    """

    def __init__(self, checkpoint_path=os.path.join('models', 'checkpoint.bin'), default_mrp_path=os.path.join('models', 'default.mrp'), verbose=False):
        """Restore the model and its configuration from a checkpoint.

        Args:
            checkpoint_path: Path to the saved checkpoint (contains both the
                hyper-parameters under ``'params'`` and weights under ``'model'``).
            default_mrp_path: Placeholder MRP file assigned to the train/val/test
                data paths that the Dataset constructor requires even at inference.
            verbose: When True, show a tqdm progress bar during prediction.
        """
        super().__init__()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # Bug fix: the path was previously hard-coded to './models/checkpoint.bin',
        # silently ignoring the checkpoint_path argument.
        self.checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.verbose = verbose
        self.args = Params().load_state_dict(self.checkpoint['params'])
        self.args.log_wandb = False

        # Inference only: point all dataset splits at the placeholder MRP file.
        self.args.training_data = default_mrp_path
        self.args.validation_data = default_mrp_path
        self.args.test_data = default_mrp_path
        self.args.only_train = False
        self.args.encoder = os.path.join('models', 'encoder')
        initialize(self.args, init_wandb=False)
        self.dataset = Dataset(self.args, verbose=False)
        self.model = Model(self.dataset, self.args).to(self.device)
        # strict=False: tolerate keys in the checkpoint that the rebuilt model
        # does not declare (and vice versa) — TODO confirm this is intentional.
        self.model.load_state_dict(self.checkpoint["model"], strict=False)
        self.model.eval()

    def _mrp_to_text(self, mrp_list, graph_mode='labeled-edge'):
        """Convert MRP graph dicts into textual output via mtool.

        Args:
            mrp_list: Iterable of MRP graph dictionaries (JSON-serializable).
            graph_mode: Either ``'labeled-edge'`` or ``'node-centric'``.

        Returns:
            The JSON structure mtool writes for the ``norec`` framework.

        Raises:
            ValueError: If ``graph_mode`` is not one of the two known modes.
        """
        framework = 'norec'
        # mtool works on files, so round-trip through two named temp files;
        # delete=False keeps them alive after the `with` on all platforms.
        with tempfile.NamedTemporaryFile(delete=False, mode='w') as output_text_file:
            output_text_filename = output_text_file.name

        with tempfile.NamedTemporaryFile(delete=False, mode='w') as mrp_file:
            mrp_file.write('\n'.join(json.dumps(entry) for entry in mrp_list))
            mrp_filename = mrp_file.name

        try:
            flags = ['--strings', '--ids', '--read', 'mrp', '--write', framework]
            if graph_mode == 'node-centric':
                flags.insert(0, '--node_centric')
            elif graph_mode != 'labeled-edge':
                raise ValueError(f'Unknown graph mode: {graph_mode}')
            mtool_main(flags + [mrp_filename, output_text_filename])

            with open(output_text_filename) as f:
                texts = json.load(f)
        finally:
            # Bug fix: previously the temp files leaked whenever mtool or the
            # JSON parse raised (or the graph mode was unknown).
            os.unlink(output_text_filename)
            os.unlink(mrp_filename)

        return texts

    def clean_texts(self, texts):
        """Pad every punctuation character with spaces, then collapse space runs.

        E.g. ``'a,b'`` becomes ``'a , b'`` — a whitespace tokenization the
        downstream dataset loader appears to expect.
        """
        punct_group = f'([{re.escape(string.punctuation)}])'
        texts = [re.sub(punct_group, r' \1 ', t) for t in texts]
        return [re.sub(r' +', ' ', t) for t in texts]

    def _predict_to_mrp(self, texts, graph_mode='labeled-edge'):
        """Run the model over ``texts`` and collect per-sentence MRP dicts.

        NOTE(review): ``graph_mode`` is accepted but unused here — the mode only
        matters in ``_mrp_to_text``; kept for signature symmetry.

        Returns:
            Dict mapping stringified sentence index to an MRP dict with
            ``input``/``id``/``time``/``framework``/``language`` metadata plus
            whatever fields (nodes, edges, tops, ...) the model predicts.
        """
        texts = self.clean_texts(texts)
        framework, language = self.args.framework, self.args.language
        data = self.dataset.load_sentences(texts, self.args)
        res_sentences = {f"{i}": {'input': sentence} for i, sentence in enumerate(texts)}
        date_str = datetime.datetime.now().date().isoformat()
        for key, value_dict in res_sentences.items():
            value_dict['id'] = key
            value_dict['time'] = date_str
            value_dict['framework'], value_dict['language'] = framework, language
            # Pre-seed empty graph fields so sentences with no prediction still
            # form valid MRP entries.
            value_dict['nodes'], value_dict['edges'], value_dict['tops'] = [], [], []
        for batch in tqdm(data) if self.verbose else data:
            with torch.no_grad():
                predictions = self.model(Batch.to(batch, self.device), inference=True)
                for prediction in predictions:
                    # Predictions carry their own 'id'; merge fields into the
                    # matching pre-seeded sentence entry.
                    for key, value in prediction.items():
                        res_sentences[prediction['id']][key] = value
        return res_sentences

    def predict(self, text_list, graph_mode='labeled-edge', language='no'):
        """End-to-end prediction: raw sentences -> structured output.

        NOTE(review): ``language`` is accepted but never used — the language
        actually comes from the checkpoint's saved args; confirm and either
        wire it through or drop it.
        """
        mrp_predictions = self._predict_to_mrp(text_list, graph_mode)
        return self._mrp_to_text(mrp_predictions.values(), graph_mode)

    def forward(self, x):
        """nn.Module entry point; delegates to ``predict`` with default modes."""
        return self.predict(x)