from typing import Any, Dict, List, Optional

import torch
from scipy.special import softmax
from transformers import BertTokenizer

from utils import (
    BertInputBuilder,
    MultiHeadModel,
    clean_str,
    clean_str_nopunct,
    get_num_words,
)

MODEL_CHECKPOINT = 'ddemszky/uptake-model'


class EndpointHandler():
    """Inference handler that scores pairs of utterances with a BERT-based model.

    The handler loads a ``MultiHeadModel`` with a single two-way ``nsp`` head
    and, for each input pair, returns the softmax probability at index 1 of
    that head's logits.
    """

    def __init__(self, path="."):
        """Load the tokenizer, input builder and model.

        Args:
            path: Location passed to ``MultiHeadModel.from_pretrained``.
        """
        print("Loading models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.input_builder = BertInputBuilder(tokenizer=self.tokenizer)
        self.model = MultiHeadModel.from_pretrained(path, head2size={"nsp": 2})
        self.model.to(self.device)
        # Maximum length handed to the input builder for every pair.
        self.max_length = 120

    def get_clean_text(self, text, remove_punct=False):
        """Normalize ``text`` with ``clean_str`` or ``clean_str_nopunct``.

        Args:
            text: Raw utterance text.
            remove_punct: When true, use ``clean_str_nopunct`` instead of
                ``clean_str``.

        Returns:
            Whatever the selected cleaning helper returns.
        """
        if remove_punct:
            return clean_str_nopunct(text)
        return clean_str(text)

    def get_prediction(self, instance):
        """Run the model on a single built instance.

        ``instance`` is updated in place: its ``input_ids``,
        ``token_type_ids`` and ``attention_mask`` entries are replaced by
        tensors located on ``self.device``.

        Args:
            instance: Mapping with at least ``input_ids`` and
                ``token_type_ids`` entries.

        Returns:
            The model output for the instance.
        """
        # NOTE(review): the mask is wrapped in an extra list, so after
        # ``unsqueeze(0)`` it has one more leading dimension than
        # ``input_ids``. Preserved as-is; presumably the model broadcasts
        # it -- confirm against MultiHeadModel.
        instance["attention_mask"] = [[1] * len(instance["input_ids"])]
        for key in ["input_ids", "token_type_ids", "attention_mask"]:
            batched = torch.tensor(instance[key]).unsqueeze(0)  # Batch size = 1
            # ``Tensor.to`` returns a new tensor rather than moving in place,
            # so the result must be stored or the inputs stay on the CPU.
            instance[key] = batched.to(self.device)

        output = self.model(
            input_ids=instance["input_ids"],
            attention_mask=instance["attention_mask"],
            token_type_ids=instance["token_type_ids"],
            return_pooler_output=False,
        )
        return output

    def get_uptake_score(self, utterances, speakerA, speakerB):
        """Score one utterance pair.

        Args:
            utterances: Container holding both utterances of the pair.
            speakerA: Key (or index) of the first utterance in ``utterances``.
            speakerB: Key (or index) of the second utterance in
                ``utterances``.

        Returns:
            The softmax probability at index 1 of the ``nsp`` logits.
        """
        textA = self.get_clean_text(utterances[speakerA], remove_punct=False)
        textB = self.get_clean_text(utterances[speakerB], remove_punct=False)

        instance = self.input_builder.build_inputs(
            [textA],
            textB,
            max_length=self.max_length,
            input_str=True,
        )
        output = self.get_prediction(instance)
        uptake_score = softmax(output["nsp_logits"][0].tolist())[1]
        return uptake_score

    def __call__(self, data: Dict[str, Any]) -> List[Optional[float]]:
        """
        Score every utterance pair in the request.

        data args:
            inputs (:obj: `list`): utterance pairs; each pair is indexed with
                the ``speaker_A`` / ``speaker_B`` values from ``parameters``.
            parameters (:obj: `dict`): must provide ``speaker_A``,
                ``speaker_B`` and ``student_min_words``.
        Return:
            A :obj:`list` with one entry per input pair: the score, or
            ``None`` when the ``speaker_A`` utterance has fewer than
            ``student_min_words`` words. Will be serialized and returned.
        Raises:
            ValueError: if there are inputs to score but no ``parameters``.
        """
        # get inputs (falls back to the whole payload when "inputs" is absent)
        utterances = data.pop("inputs", data)
        params = data.pop("parameters", None)

        if params is None and utterances:
            # Fail with an actionable message instead of the opaque
            # "'NoneType' object is not subscriptable" raised further down.
            raise ValueError(
                'Missing "parameters": expected keys "speaker_A", '
                '"speaker_B" and "student_min_words".'
            )

        print("EXAMPLES")
        for utt_pair in utterances[:3]:
            print("speaker A: %s" % utt_pair[params["speaker_A"]])
            print("speaker B: %s" % utt_pair[params["speaker_B"]])
            print("----")

        print("Running inference on %d examples..." % len(utterances))
        self.model.eval()
        uptake_scores = []
        with torch.no_grad():
            for utt in utterances:
                prev_num_words = get_num_words(utt[params["speaker_A"]])
                if prev_num_words < params["student_min_words"]:
                    # Too short to score; keep the slot so that outputs stay
                    # aligned with inputs.
                    uptake_scores.append(None)
                    continue
                uptake_score = self.get_uptake_score(
                    utterances=utt,
                    speakerA=params["speaker_A"],
                    speakerB=params["speaker_B"],
                )
                uptake_scores.append(uptake_score)

        return uptake_scores