File size: 4,882 Bytes
fa0a93c
 
 
 
 
 
 
 
 
 
b34e98b
 
fa0a93c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import numpy as np
from math import exp
import torch
from torch import nn
from transformers import BertTokenizer, BertForNextSentencePrediction
import utils
from maddog import Extractor
import spacy
import constant

import spacy.cli
# Module-level setup: everything below runs once, at import time.
# The spaCy pipeline download is requested on every import before loading it;
# `nlp` is used in this file only to split sentences into tokens.
spacy.cli.download("en_core_web_sm")
nlp = spacy.load("en_core_web_sm")
# Rule-based acronym/long-form extractor from the local `maddog` module;
# it is always invoked together with constant.RULES in this file.
ruleExtractor = Extractor()
# Acronym knowledge base, loaded through utils.load_acronym_kb from 'acronym_kb.json'.
kb = utils.load_acronym_kb('acronym_kb.json')
# Checkpoint file whose state dict is loaded into the module-level model below.
model_path='acrobert.pt'

class AcronymBERT(nn.Module):
    """Wrap a BERT next-sentence-prediction model together with its tokenizer.

    Attributes:
        device: Device that `forward` moves the token ids to.
        model: `BertForNextSentencePrediction` loaded from `model_name`.
        tokenizer: `BertTokenizer` loaded from `model_name`.
    """

    def __init__(self, model_name="bert-base-uncased", device='cpu'):
        super().__init__()
        self.device = device
        self.model = BertForNextSentencePrediction.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def forward(self, sentence):
        """Return, for each input row, the softmax probability of logit index 0.

        Args:
            sentence: Text (or batch of texts) handed directly to the tokenizer.

        Returns:
            A 1-D tensor with one score per tokenized row.
        """
        # NOTE(review): only `input_ids` are forwarded, so padded positions are
        # not masked out — confirm this matches how the checkpoint was trained.
        encoded = self.tokenizer(
            sentence,
            padding=True,
            return_tensors='pt',
            truncation=True,
        )
        input_ids = encoded["input_ids"].to(self.device)
        logits = self.model(input_ids).logits
        probabilities = nn.functional.softmax(logits, dim=1)
        return probabilities[:, 0]

# Build the shared CPU model at import time and replace its weights with the
# state dict stored at `model_path`. This instance is also the default `model`
# argument of acronym_linker().
# NOTE(review): depending on the torch version, torch.load may unpickle
# arbitrary objects — only point `model_path` at a trusted checkpoint.
model = AcronymBERT(device='cpu')
model.load_state_dict(torch.load(model_path, map_location='cpu'))

def softmax(elements):
    """Return the softmax probability of the FIRST element of `elements`.

    Despite the name, this does not return the full distribution: only
    ``exp(elements[0]) / sum(exp(e) for e in elements)`` is computed.

    The maximum is subtracted from every value before exponentiating, which
    leaves the result mathematically unchanged but prevents ``OverflowError``
    from ``math.exp`` on large logits.

    Args:
        elements: Non-empty 1-D sequence of real numbers (e.g. a row of logits).

    Returns:
        The probability of index 0 as a float.

    Raises:
        IndexError: If `elements` is empty.
    """
    first = elements[0]  # raises IndexError on empty input, as before
    peak = max(elements)
    total = sum(exp(e - peak) for e in elements)
    return exp(first - peak) / total


def predict(topk, model, short_form, context, batch_size, acronym_kb, device):
    """Rank knowledge-base long forms for `short_form` against `context`.

    Up to 10 candidates are requested from utils.get_candidate. They are
    lower-cased for scoring, but the returned names keep their original form.

    Args:
        topk: Maximum number of names to return.
        model: Object exposing `.model` and `.tokenizer` (see AcronymBERT).
        short_form: Acronym to look up in `acronym_kb`.
        context: Text the candidates are scored against.
        batch_size: Number of candidates scored per forward pass.
        acronym_kb: Knowledge base passed to utils.get_candidate.
        device: Device the encoded batches are moved to.

    Returns:
        List of candidate long forms ordered from highest to lowest score.
    """
    candidates = utils.get_candidate(acronym_kb, short_form, can_num=10)
    lowered = list(map(str.lower, candidates))
    scores = cal_score(model.model, model.tokenizer, lowered, context, batch_size, device)
    keep = min(len(scores), topk)
    # argsort is ascending; reverse it so the best-scoring candidate comes first.
    ranked = np.argsort(scores)[::-1]
    return [candidates[pos] for pos in ranked[:keep]]


def cal_score(model, tokenizer, long_forms, contexts, batch_size, device):
    """Score each candidate long form against one context, in batches.

    Every long form is paired with the same `contexts` text, encoded as a
    sentence pair (truncated to 400 tokens) and passed through `model`. The
    score is the softmax probability of logit index 0 for that pair.

    Args:
        model: Callable model whose output has a `.logits` tensor.
        tokenizer: Tokenizer accepting two parallel lists of texts.
        long_forms: List of candidate long forms.
        contexts: The single context text shared by every candidate.
        batch_size: Number of candidates per forward pass.
        device: Device the encoded batch is moved to.

    Returns:
        List of float scores, one per entry of `long_forms`, in the same order.
    """
    ps = []
    # Inference only: disabling autograd avoids building a graph for every
    # batch, which saves memory and time without changing the logits.
    with torch.no_grad():
        for start in range(0, len(long_forms), batch_size):
            batch_lf = long_forms[start:start + batch_size]
            batch_ctx = [contexts] * len(batch_lf)
            encoding = tokenizer(
                batch_lf,
                batch_ctx,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=400,
            ).to(device)
            outputs = model(**encoding)
            logits = outputs.logits.cpu().detach().numpy()
            ps.extend(softmax(lg) for lg in logits)
    return ps


def dog_extract(sentence):
    """Run the rule-based acronym extractor on `sentence`.

    The sentence is tokenized with the spaCy pipeline, whitespace-only tokens
    are dropped, and the remaining token texts are passed to the extractor
    together with constant.RULES.

    Returns:
        Whatever ruleExtractor.extract returns. Other functions in this file
        treat it as a mapping from acronym to a sequence whose first item is
        the long form, with '' apparently meaning none was found.
    """
    words = []
    for token in nlp(sentence):
        if token.text.strip():
            words.append(token.text)
    return ruleExtractor.extract(words, constant.RULES)


def acrobert(sentence, model, device):
    """Link every acronym found in `sentence` to a long form.

    Acronyms whose long form was found by the rule-based extractor keep that
    long form. For the others, the best-scoring knowledge-base candidate from
    predict() is used.

    Args:
        sentence: Input text.
        model: Object exposing `.model`, `.tokenizer` and `.to()` (see AcronymBERT).
        device: Device the model and the encoded batches are moved to.

    Returns:
        List of (acronym, long_form) tuples. An acronym for which predict()
        returns no candidate is omitted instead of raising IndexError.
    """
    model.to(device)

    results = []
    # Reuse the shared extraction helper rather than duplicating its logic.
    for acronym, pair in dog_extract(sentence).items():
        long_form = pair[0]
        if long_form != '':
            results.append((acronym, long_form))
            continue
        pred = predict(1, model, acronym, sentence, batch_size=10, acronym_kb=kb, device=device)
        # predict() returns an empty list when there is nothing to rank.
        if len(pred) > 0:
            results.append((acronym, pred[0]))
    return results


def popularity(sentence):
    """Link every acronym in `sentence` without running the neural model.

    Acronyms whose long form was found by the rule-based extractor keep that
    long form. For the others, the first candidate returned by
    utils.get_candidate(kb, acronym, can_num=1) is used.

    Args:
        sentence: Input text.

    Returns:
        List of (acronym, long_form) tuples. An acronym with no knowledge-base
        candidate is omitted instead of raising IndexError.
    """
    results = []
    # Reuse the shared extraction helper rather than duplicating its logic.
    for acronym, pair in dog_extract(sentence).items():
        long_form = pair[0]
        if long_form != '':
            results.append((acronym, long_form))
            continue
        pred = utils.get_candidate(kb, acronym, can_num=1)
        # Guard against an empty candidate list before indexing.
        if len(pred) > 0:
            results.append((acronym, pred[0]))
    return results


def acronym_linker(sentence, mode='acrobert', model=model, device='cpu'):
    """Link the acronyms in `sentence` using the selected strategy.

    Args:
        sentence: Input text.
        mode: 'acrobert' to rank candidates with the model, or 'pop' to take
            the first knowledge-base candidate.
        model: Model used in 'acrobert' mode; defaults to the module-level
            instance that exists when this function is defined.
        device: Device used in 'acrobert' mode.

    Returns:
        List of (acronym, long_form) tuples.

    Raises:
        ValueError: If `mode` is not one of the supported names. ValueError is
            a subclass of Exception, so existing handlers keep working.
    """
    if mode == 'acrobert':
        return acrobert(sentence, model, device)
    if mode == 'pop':
        return popularity(sentence)
    raise ValueError('mode name should in this list [acrobert, pop]')


if __name__ == '__main__':
    # Manual smoke test: link the acronyms of one sample sentence and print them.
    #
    # Alternative sample inputs kept for experimentation:
    #   "This new genome assembly and the annotation are tagged as a RefSeq genome by NCBI and thus provide substantially enhanced genomic resources for future research involving S. scovelli."
    #
    #   """ There have been initiated several projects to modernize the network of ECB
    #   corridors, financed from ispa funds and state-guaranteed loans from international
    #   financial institutions."""
    #
    #   """A whistleblower like monologist Mike Daisey gets targeted as a scapegoat who must
    #   be discredited and diminished in the public ’s eye. More often than not, PR is
    #   a preemptive process. Celebrity publicists are paid lots of money to keep certain
    #   stories out of the news."""
    demo_sentence = "AI is the ability of a digital computer or computer-controlled robot to perform tasks commonly associated with intelligent beings, including NLP that processes text or document"
    print(acronym_linker(demo_sentence))