ekolasky commited on
Commit
f7a2749
1 Parent(s): f5b20aa

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +115 -0
handler.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import AutoModel, AutoTokenizer
3
+ import torch
4
+
5
+
6
+ class EndpointHandler():
7
+ def __init__(self, path=""):
8
+ # load the optimized model
9
+ self.model = AutoModel.from_pretrained(path, trust_remote_code=True)
10
+ self.model.eval()
11
+ self.tokenizer = AutoTokenizer.from_pretrained('allenai/led-base-16384')
12
+
13
+ # create inference pipeline
14
+ #self.pipeline = pipeline("token-classification", model=model, tokenizer=tokenizer)
15
+
16
+
17
+ def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
18
+ """
19
+ Args:
20
+ data (:obj:):
21
+ includes the input data and the parameters for the inference.
22
+ Return:
23
+ A :obj:`list`:. The object returned should be a list of one list like [[{"label": 0.9939950108528137}]] containing :
24
+ - "label": A string representing what the label/class is. There can be multiple labels.
25
+ - "score": A score between 0 and 1 describing how confident the model is for this label/class.
26
+ """
27
+ text = data.pop("text", data)
28
+ label_tolerance = data.pop("label_tolerance", 0)
29
+ backup_tolerance = data.pop("backup_tolerance", None)
30
+
31
+ # Return labeled results and backup results based on tolerances
32
+ inputs = self.preprocess_text(text)
33
+ outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
34
+
35
+ # Extract labeled results
36
+ predictions = self.extract_results(input_ids=inputs['input_ids'][0].tolist(), offset_mapping=inputs['offset_mapping'], logits=outputs['logits'],
37
+ label_tolerance=label_tolerance, backup_tolerance=backup_tolerance)
38
+
39
+ return predictions
40
+
41
+ def preprocess_text(self, text):
42
+
43
+ inputs = self.tokenizer(text, return_offsets_mapping=True)
44
+ input_ids = torch.tensor([inputs["input_ids"]])#, dtype=torch.fp32)
45
+ attention_mask = torch.tensor([inputs["attention_mask"]])#, dtype=torch.fp32)
46
+
47
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "offset_mapping": inputs["offset_mapping"]}
48
+
49
+ def extract_results(self, input_ids, offset_mapping, logits, label_tolerance=0, backup_tolerance=None):
50
+
51
+ def convert_indices_to_result_obj(indices_array):
52
+ result_array = []
53
+ if (indices_array):
54
+ for result_indices in indices_array:
55
+ name = self.tokenizer.decode(input_ids[result_indices[0]:result_indices[-1]])
56
+ indices = [offset_mapping[result_indices[0]][0], offset_mapping[result_indices[-1]][0]]
57
+ result_array.append({'name': name, 'indices': indices})
58
+ return result_array
59
+
60
+
61
+ # Extract labeled results first
62
+ labeled_result_indices = []
63
+ result_indices = []
64
+ for index, token_logits in enumerate(logits.tolist()[0]):
65
+
66
+ if (len(result_indices) > 0):
67
+ if token_logits[2] > label_tolerance:
68
+ result_indices.append(index)
69
+ else:
70
+ print("Adding result")
71
+ labeled_result_indices.append(result_indices)
72
+ result_indices = []
73
+
74
+ elif (token_logits[1] > label_tolerance):
75
+ result_indices.append(index)
76
+
77
+ if (len(result_indices) > 0):
78
+ labeled_result_indices.append(result_indices)
79
+
80
+ print(labeled_result_indices)
81
+
82
+
83
+ # Extract backup results, avoiding overlapping with labeled results
84
+ backup_result_indices = []
85
+ result_indices = []
86
+ if (backup_tolerance):
87
+ for index, token_logits in enumerate(logits.tolist()[0]):
88
+
89
+ if (len(result_indices) > 0):
90
+ if token_logits[2] > backup_tolerance:
91
+ result_indices.append(index)
92
+ else:
93
+ # Check if backup result overlaps at all with any labeled result. If it does just ignore it
94
+ overlaps_labeled_result = False
95
+ if (len(labeled_result_indices) > 0):
96
+ for index in result_indices:
97
+ for group in labeled_result_indices:
98
+ for labeled_index in group:
99
+ if (index == labeled_index):
100
+ overlaps_labeled_result = True
101
+ if (not overlaps_labeled_result):
102
+ backup_result_indices.append(result_indices)
103
+
104
+ result_indices = []
105
+
106
+ elif (token_logits[1] > backup_tolerance):
107
+ result_indices.append(index)
108
+
109
+ # Convert both labeled results and backup results to {name: "", indices: []}
110
+ labeled_results = convert_indices_to_result_obj(labeled_result_indices)
111
+ backup_results = convert_indices_to_result_obj(backup_result_indices)
112
+
113
+
114
+ return {'labeled_results': labeled_results, 'backup_results': backup_results}
115
+