ddemszky committed on
Commit
7800c33
•
1 Parent(s): 062ea44

add custom handler

__pycache__/handler.cpython-39.pyc ADDED
Binary file (3.15 kB).
 
__pycache__/utils.cpython-39.pyc ADDED
Binary file (6.51 kB).
 
handler.py ADDED
@@ -0,0 +1,83 @@
+ from typing import Dict, List, Any
+ from scipy.special import softmax
+
+ from utils import clean_str, clean_str_nopunct
+ import torch
+ from transformers import BertTokenizer
+ from utils import MultiHeadModel, BertInputBuilder, get_num_words
+
+ MODEL_CHECKPOINT = 'ddemszky/uptake-model'
+
+ class EndpointHandler():
+     def __init__(self, path="."):
+         print("Loading models...")
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+         self.input_builder = BertInputBuilder(tokenizer=self.tokenizer)
+         self.model = MultiHeadModel.from_pretrained(path, head2size={"nsp": 2})
+         self.model.to(self.device)
+         self.max_length = 120
+
+     def get_clean_text(self, text, remove_punct=False):
+         if remove_punct:
+             return clean_str_nopunct(text)
+         return clean_str(text)
+
+     def get_prediction(self, instance):
+         instance["attention_mask"] = [[1] * len(instance["input_ids"])]
+         for key in ["input_ids", "token_type_ids", "attention_mask"]:
+             instance[key] = torch.tensor(instance[key]).unsqueeze(0)  # Batch size = 1
+             instance[key] = instance[key].to(self.device)  # .to() is not in-place; reassign so inputs land on the model's device
+
+         output = self.model(input_ids=instance["input_ids"],
+                             attention_mask=instance["attention_mask"],
+                             token_type_ids=instance["token_type_ids"],
+                             return_pooler_output=False)
+         return output
+
+     def get_uptake_score(self, utterances, speakerA, speakerB):
+
+         textA = self.get_clean_text(utterances[speakerA], remove_punct=False)
+         textB = self.get_clean_text(utterances[speakerB], remove_punct=False)
+
+         instance = self.input_builder.build_inputs([textA], textB,
+                                                    max_length=self.max_length,
+                                                    input_str=True)
+         output = self.get_prediction(instance)
+         uptake_score = softmax(output["nsp_logits"][0].tolist())[1]
+         return uptake_score
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         data args:
+             inputs (:obj:`list`)
+             parameters (:obj:`dict`)
+         Return:
+             A :obj:`list` | `dict`: will be serialized and returned
+         """
+         # get inputs
+         inputs = data.pop("inputs", data)
+         params = data.pop("parameters", None)
+
+         utterances = inputs
+         print("EXAMPLES")
+         for utt_pair in utterances[:3]:
+             print("speaker A: %s" % utt_pair[params["speaker_A"]])
+             print("speaker B: %s" % utt_pair[params["speaker_B"]])
+             print("----")
+
+         print("Running inference on %d examples..." % len(utterances))
+         self.model.eval()
+         uptake_scores = []
+         with torch.no_grad():
+             for i, utt in enumerate(utterances):
+                 prev_num_words = get_num_words(utt[params["speaker_A"]])
+                 if prev_num_words < params["student_min_words"]:
+                     uptake_scores.append(None)
+                     continue
+                 uptake_score = self.get_uptake_score(utterances=utt,
+                                                      speakerA=params["speaker_A"],
+                                                      speakerB=params["speaker_B"])
+                 uptake_scores.append(uptake_score)
+
+         return uptake_scores
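
Note: once this repository is deployed as a Hugging Face Inference Endpoint, the handler above receives the JSON request body as its data argument. The sketch below is illustrative only; the endpoint URL and token are placeholders, and the payload simply mirrors the one used in test.py.

import requests

API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder, filled in after deployment
HEADERS = {"Authorization": "Bearer <hf_token>"}  # placeholder access token

payload = {
    "inputs": [{"speaker_A": "I am quite excited how this will turn out",
                "speaker_B": "I'm excited, too"}],
    "parameters": {"speaker_A": "speaker_A",
                   "speaker_B": "speaker_B",
                   "student_min_words": 5},
}

response = requests.post(API_URL, headers=HEADERS, json=payload)
print(response.json())  # a list of uptake scores, one per utterance pair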
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ cleantext==1.1.4
+ num2words==0.5.10
+ scipy==1.7.3
+ torch==1.10.2
+ transformers==4.25.1
test.py ADDED
@@ -0,0 +1,15 @@
+ from handler import EndpointHandler
+
+ # init handler
+ my_handler = EndpointHandler(path=".")
+
+ # prepare sample payload
+ example = {"inputs": [{"speaker_A": "I am quite excited how this will turn out",
+                        "speaker_B": "I'm excited, too"}],
+            "parameters": {"speaker_A": "speaker_A",
+                           "speaker_B": "speaker_B",
+                           "student_min_words": 5,
+                           }}
+
+ # test the handler
+ print(my_handler(example))
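
If the weights load correctly, this script should print a one-element list holding a single uptake probability between 0 and 1; an entry comes back as None whenever speaker A's turn has fewer than student_min_words words (5 in this payload).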
utils.py ADDED
@@ -0,0 +1,192 @@
+ import torch
+ from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel
+ from torch import nn
+ from itertools import chain
+ from torch.nn import MSELoss, CrossEntropyLoss
+ from cleantext import clean
+ from num2words import num2words
+ import re
+ import string
+
+ punct_chars = list((set(string.punctuation) | {'’', '‘', '–', '—', '~', '|', '“', '”', '…', "'", "`", '_'}))
+ punct_chars.sort()
+ punctuation = ''.join(punct_chars)
+ replace = re.compile('[%s]' % re.escape(punctuation))
+
+ def get_num_words(text):
+     if not isinstance(text, str):
+         print("%s is not a string" % text)
+     text = replace.sub(' ', text)
+     text = re.sub(r'\s+', ' ', text)
+     text = text.strip()
+     text = re.sub(r'\[.+\]', " ", text)
+     return len(text.split())
+
+ def number_to_words(num):
+     try:
+         return num2words(re.sub(",", "", num))
+     except:
+         return num
+
+
+ clean_str = lambda s: clean(s,
+                             fix_unicode=True,  # fix various unicode errors
+                             to_ascii=True,  # transliterate to closest ASCII representation
+                             lower=True,  # lowercase text
+                             no_line_breaks=True,  # fully strip line breaks as opposed to only normalizing them
+                             no_urls=True,  # replace all URLs with a special token
+                             no_emails=True,  # replace all email addresses with a special token
+                             no_phone_numbers=True,  # replace all phone numbers with a special token
+                             no_numbers=True,  # replace all numbers with a special token
+                             no_digits=False,  # replace all digits with a special token
+                             no_currency_symbols=False,  # replace all currency symbols with a special token
+                             no_punct=False,  # fully remove punctuation
+                             replace_with_url="<URL>",
+                             replace_with_email="<EMAIL>",
+                             replace_with_phone_number="<PHONE>",
+                             replace_with_number=lambda m: number_to_words(m.group()),
+                             replace_with_digit="0",
+                             replace_with_currency_symbol="<CUR>",
+                             lang="en"
+                             )
+
+ clean_str_nopunct = lambda s: clean(s,
+                                     fix_unicode=True,  # fix various unicode errors
+                                     to_ascii=True,  # transliterate to closest ASCII representation
+                                     lower=True,  # lowercase text
+                                     no_line_breaks=True,  # fully strip line breaks as opposed to only normalizing them
+                                     no_urls=True,  # replace all URLs with a special token
+                                     no_emails=True,  # replace all email addresses with a special token
+                                     no_phone_numbers=True,  # replace all phone numbers with a special token
+                                     no_numbers=True,  # replace all numbers with a special token
+                                     no_digits=False,  # replace all digits with a special token
+                                     no_currency_symbols=False,  # replace all currency symbols with a special token
+                                     no_punct=True,  # fully remove punctuation
+                                     replace_with_url="<URL>",
+                                     replace_with_email="<EMAIL>",
+                                     replace_with_phone_number="<PHONE>",
+                                     replace_with_number=lambda m: number_to_words(m.group()),
+                                     replace_with_digit="0",
+                                     replace_with_currency_symbol="<CUR>",
+                                     lang="en"
+                                     )
+
+
+
+ class MultiHeadModel(BertPreTrainedModel):
+     """Pre-trained BERT model that uses our loss functions"""
+
+     def __init__(self, config, head2size):
+         super(MultiHeadModel, self).__init__(config, head2size)
+         config.num_labels = 1
+         self.bert = BertModel(config)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         module_dict = {}
+         for head_name, num_labels in head2size.items():
+             module_dict[head_name] = nn.Linear(config.hidden_size, num_labels)
+         self.heads = nn.ModuleDict(module_dict)
+
+         self.init_weights()
+
+     def forward(self, input_ids, token_type_ids=None, attention_mask=None,
+                 head2labels=None, return_pooler_output=False, head2mask=None,
+                 nsp_loss_weights=None):
+
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         # Get logits
+         output = self.bert(
+             input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
+             output_attentions=False, output_hidden_states=False, return_dict=True)
+         pooled_output = self.dropout(output["pooler_output"]).to(device)
+
+         head2logits = {}
+         return_dict = {}
+         for head_name, head in self.heads.items():
+             head2logits[head_name] = self.heads[head_name](pooled_output)
+             head2logits[head_name] = head2logits[head_name].float()
+             return_dict[head_name + "_logits"] = head2logits[head_name]
+
+
+         if head2labels is not None:
+             for head_name, labels in head2labels.items():
+                 num_classes = head2logits[head_name].shape[1]
+
+                 # Regression (e.g. for politeness)
+                 if num_classes == 1:
+
+                     # Only consider positive examples
+                     if head2mask is not None and head_name in head2mask:
+                         num_positives = head2labels[head2mask[head_name]].sum()  # use certain labels as mask
+                         if num_positives == 0:
+                             return_dict[head_name + "_loss"] = torch.tensor([0]).to(device)
+                         else:
+                             loss_fct = MSELoss(reduction='none')
+                             loss = loss_fct(head2logits[head_name].view(-1), labels.float().view(-1))
+                             return_dict[head_name + "_loss"] = loss.dot(head2labels[head2mask[head_name]].float().view(-1)) / num_positives
+                     else:
+                         loss_fct = MSELoss()
+                         return_dict[head_name + "_loss"] = loss_fct(head2logits[head_name].view(-1), labels.float().view(-1))
+                 else:
+                     loss_fct = CrossEntropyLoss(weight=nsp_loss_weights.float())
+                     return_dict[head_name + "_loss"] = loss_fct(head2logits[head_name], labels.view(-1))
+
+
+         if return_pooler_output:
+             return_dict["pooler_output"] = output["pooler_output"]
+
+         return return_dict
+
+ class InputBuilder(object):
+     """Base class for building inputs from segments."""
+
+     def __init__(self, tokenizer):
+         self.tokenizer = tokenizer
+         self.mask = [tokenizer.mask_token_id]
+
+     def build_inputs(self, history, reply, max_length):
+         raise NotImplementedError
+
+     def mask_seq(self, sequence, seq_id):
+         sequence[seq_id] = self.mask
+         return sequence
+
+     @classmethod
+     def _combine_sequence(self, history, reply, max_length, flipped=False):
+         # Trim all inputs to max_length
+         history = [s[:max_length] for s in history]
+         reply = reply[:max_length]
+         if flipped:
+             return [reply] + history
+         return history + [reply]
+
+
+ class BertInputBuilder(InputBuilder):
+     """Processor for BERT inputs"""
+
+     def __init__(self, tokenizer):
+         InputBuilder.__init__(self, tokenizer)
+         self.cls = [tokenizer.cls_token_id]
+         self.sep = [tokenizer.sep_token_id]
+         self.model_inputs = ["input_ids", "token_type_ids", "attention_mask"]
+         self.padded_inputs = ["input_ids", "token_type_ids"]
+         self.flipped = False
+
+
+     def build_inputs(self, history, reply, max_length, input_str=True):
+         """See base class."""
+         if input_str:
+             history = [self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(t)) for t in history]
+             reply = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(reply))
+         sequence = self._combine_sequence(history, reply, max_length, self.flipped)
+         sequence = [s + self.sep for s in sequence]
+         sequence[0] = self.cls + sequence[0]
+
+         instance = {}
+         instance["input_ids"] = list(chain(*sequence))
+         last_speaker = 0
+         other_speaker = 1
+         seq_length = len(sequence)
+         instance["token_type_ids"] = [last_speaker if ((seq_length - i) % 2 == 1) else other_speaker
+                                       for i, s in enumerate(sequence) for _ in s]
+         return instance
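
For readers tracing the input format, here is a minimal sketch of what BertInputBuilder produces for a single exchange. It assumes the bert-base-uncased tokenizer can be downloaded and uses made-up utterances; the tokens shown in the comments are only indicative.

from transformers import BertTokenizer
from utils import BertInputBuilder

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
builder = BertInputBuilder(tokenizer=tokenizer)

# One history turn (speaker A) and one reply (speaker B), both plain strings.
instance = builder.build_inputs(["how does this work"], "let me explain",
                                max_length=120, input_str=True)

print(tokenizer.convert_ids_to_tokens(instance["input_ids"]))
# e.g. ['[CLS]', 'how', 'does', 'this', 'work', '[SEP]', 'let', 'me', 'explain', '[SEP]']
print(instance["token_type_ids"])
# Segments alternate; the final (reply) segment gets id 0, the turn before it gets id 1.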