File size: 3,392 Bytes
7800c33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from typing import Dict, List, Any
from scipy.special import softmax

from utils import clean_str, clean_str_nopunct
import torch
from transformers import BertTokenizer
from utils import MultiHeadModel, BertInputBuilder, get_num_words

# Identifier of the published uptake model (presumably a Hugging Face Hub id).
# NOTE(review): not referenced in the visible handler, which loads weights
# from the ``path`` argument instead — confirm whether it is still needed.
MODEL_CHECKPOINT = "ddemszky/uptake-model"

class EndpointHandler:
    """Inference endpoint that scores "uptake" between two speakers' utterances.

    Each request carries a list of utterance pairs (dicts) plus parameters that
    name which keys hold speaker A's and speaker B's text. For every pair the
    handler returns the softmax probability of class 1 of the model's ``"nsp"``
    head, or ``None`` when speaker A's text has too few words to be scored.
    """

    # Request parameters that must be present whenever there is input to score.
    _REQUIRED_PARAMS = ("speaker_A", "speaker_B", "student_min_words")

    def __init__(self, path=".", max_length=120):
        """Load the tokenizer and the multi-head model onto the best device.

        Args:
            path: Model location passed to ``MultiHeadModel.from_pretrained``.
            max_length: Forwarded to ``BertInputBuilder.build_inputs``
                (presumably the token budget per input — TODO confirm).
        """
        print("Loading models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.input_builder = BertInputBuilder(tokenizer=self.tokenizer)
        self.model = MultiHeadModel.from_pretrained(path, head2size={"nsp": 2})
        self.model.to(self.device)
        self.max_length = max_length

    def get_clean_text(self, text, remove_punct=False):
        """Normalize ``text`` with ``clean_str`` (or ``clean_str_nopunct`` if ``remove_punct``)."""
        if remove_punct:
            return clean_str_nopunct(text)
        return clean_str(text)

    def get_prediction(self, instance):
        """Run the model on one pre-built input and return its raw output.

        Args:
            instance: Dict with ``"input_ids"`` and ``"token_type_ids"`` as
                returned by ``BertInputBuilder.build_inputs`` (assumed to be
                flat lists of ints — TODO confirm). It is mutated in place:
                an ``"attention_mask"`` entry is added and all three entries
                are replaced by batch-of-one tensors on ``self.device``.

        Returns:
            The model output; callers read its ``"nsp_logits"`` entry.
        """
        # Every position is attended to. The extra list nesting gives the mask
        # one more dimension than input_ids after unsqueeze; presumably the
        # model broadcasts it — NOTE(review): confirm before changing.
        instance["attention_mask"] = [[1] * len(instance["input_ids"])]
        for key in ("input_ids", "token_type_ids", "attention_mask"):
            # Tensor.to() returns a new tensor rather than moving in place, so
            # the result must be stored; otherwise inputs stay on the CPU while
            # the model lives on self.device.
            instance[key] = torch.tensor(instance[key]).unsqueeze(0).to(self.device)

        return self.model(input_ids=instance["input_ids"],
                          attention_mask=instance["attention_mask"],
                          token_type_ids=instance["token_type_ids"],
                          return_pooler_output=False)

    def get_uptake_score(self, utterances, speakerA, speakerB):
        """Return the uptake probability for a single utterance pair.

        Args:
            utterances: Mapping that holds both speakers' raw text.
            speakerA: Key of speaker A's text in ``utterances``.
            speakerB: Key of speaker B's text in ``utterances``.

        Returns:
            Softmax probability of class 1 of the ``"nsp"`` head, as a float.
        """
        textA = self.get_clean_text(utterances[speakerA], remove_punct=False)
        textB = self.get_clean_text(utterances[speakerB], remove_punct=False)

        # Speaker A's text is passed as a one-element list, speaker B's as a str.
        instance = self.input_builder.build_inputs([textA], textB,
                                                   max_length=self.max_length,
                                                   input_str=True)
        output = self.get_prediction(instance)
        return float(softmax(output["nsp_logits"][0].tolist())[1])

    @classmethod
    def _check_params(cls, params):
        """Raise ValueError naming any required request parameters that are absent."""
        missing = [key for key in cls._REQUIRED_PARAMS if key not in params]
        if missing:
            raise ValueError("Missing required parameters: %s" % ", ".join(missing))

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Score every utterance pair in a request.

        Args:
            data: Request payload (mutated: ``"inputs"``/``"parameters"`` are popped).
                inputs (list): utterance-pair dicts; if absent, the payload
                    itself is used.
                parameters (dict): ``speaker_A`` and ``speaker_B`` (keys of the
                    two speakers' text in each pair) and ``student_min_words``
                    (pairs whose speaker-A text has fewer words are skipped).

        Returns:
            One entry per input pair: the uptake probability (float), or
            ``None`` if speaker A's text has fewer than ``student_min_words``
            words. Serialized by the serving runtime.

        Raises:
            ValueError: if there are inputs but required parameters are missing.
        """
        inputs = data.pop("inputs", data)
        # A missing "parameters" key used to surface as an opaque TypeError.
        params = data.pop("parameters", None) or {}

        utterances = inputs
        if utterances:
            self._check_params(params)

        print("EXAMPLES")
        for utt_pair in utterances[:3]:
            print("speaker A: %s" % utt_pair[params["speaker_A"]])
            print("speaker B: %s" % utt_pair[params["speaker_B"]])
            print("----")

        print("Running inference on %d examples..." % len(utterances))
        self.model.eval()
        uptake_scores = []
        with torch.no_grad():
            for utt in utterances:
                prev_num_words = get_num_words(utt[params["speaker_A"]])
                if prev_num_words < params["student_min_words"]:
                    uptake_scores.append(None)
                    continue
                uptake_score = self.get_uptake_score(utterances=utt,
                                                     speakerA=params["speaker_A"],
                                                     speakerB=params["speaker_B"])
                uptake_scores.append(uptake_score)

        return uptake_scores