from typing import Dict, List, Any

import torch
from transformers import AutoModel, AutoTokenizer


class EndpointHandler:
    def __init__(self, path=""):
        # Load the fine-tuned model from the repository path provided by the endpoint,
        # and the tokenizer from the 'allenai/led-base-16384' base checkpoint.
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained('allenai/led-base-16384')

    def __call__(self, data: Any) -> Dict[str, List[Dict[str, Any]]]:
        """
        Args:
            data (:obj:`dict`):
                Includes the input text and the inference parameters under ``data['inputs']``.

        Return:
            A :obj:`dict` with two keys:

            - "labeled_results": spans whose logits exceed ``label_tolerance``; each entry holds
              "text" (the decoded span) and "indices" (its character offsets in the input).
            - "backup_results": spans that only exceed ``backup_tolerance`` and do not overlap
              a labeled result, in the same format.
        """
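        # Illustrative return value (shape only; values are made up, not model output):
        #   {"labeled_results": [{"text": "some span", "indices": [10, 19]}], "backup_results": []}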
        text = data['inputs'].pop("text", "")
        label_tolerance = data['inputs'].pop("label_tolerance", 0)
        backup_tolerance = data['inputs'].pop("backup_tolerance", None)

        inputs = self.preprocess_text(text)
        # Run inference without tracking gradients.
        with torch.no_grad():
            outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

        predictions = self.extract_results(input_ids=inputs['input_ids'][0].tolist(),
                                           offset_mapping=inputs['offset_mapping'],
                                           logits=outputs['logits'],
                                           label_tolerance=label_tolerance,
                                           backup_tolerance=backup_tolerance)

        return predictions

    def preprocess_text(self, text):
        # Tokenize the input and keep the character offsets of each token so results
        # can be mapped back to positions in the original string.
        inputs = self.tokenizer(text, return_offsets_mapping=True)
        input_ids = torch.tensor([inputs["input_ids"]])
        attention_mask = torch.tensor([inputs["attention_mask"]])

        return {"input_ids": input_ids, "attention_mask": attention_mask, "offset_mapping": inputs["offset_mapping"]}

    def extract_results(self, input_ids, offset_mapping, logits, label_tolerance=0, backup_tolerance=None):

        def convert_indices_to_result_obj(indices_array):
            # Turn each group of token indices into a result dict containing the decoded
            # text and its character offsets in the original input string.
            result_array = []
            if indices_array:
                for result_indices in indices_array:
                    text = self.tokenizer.decode(input_ids[result_indices[0]:result_indices[-1]])
                    indices = [offset_mapping[result_indices[0] - 1][1], offset_mapping[result_indices[-2]][1]]
                    if text != "" and not text.isspace():
                        # Strip leading spaces introduced by the tokenizer and shift the start offset accordingly.
                        while text[0] == " ":
                            text = text[1:]
                            indices[0] += 1
                        result_array.append({'text': text, 'indices': indices})
            return result_array

        # First pass: group consecutive token indices whose logits clear the main label tolerance.
        # Index 1 of a token's logits is used as the span-start score and index 2 as the
        # span-continuation score; the first token that falls below the tolerance closes the span.
        labeled_result_indices = []
        result_indices = []
        for index, token_logits in enumerate(logits.tolist()[0]):

            if len(result_indices) > 0:
                if token_logits[2] > label_tolerance:
                    result_indices.append(index)
                else:
                    result_indices.append(index)
                    labeled_result_indices.append(result_indices)
                    result_indices = []

            elif token_logits[1] > label_tolerance:
                result_indices.append(index)

        if len(result_indices) > 0:
            labeled_result_indices.append(result_indices)

        # Second pass: with the looser backup tolerance, collect extra spans that do not
        # overlap any labeled result from the first pass.
        backup_result_indices = []
        result_indices = []
        if backup_tolerance:
            for index, token_logits in enumerate(logits.tolist()[0]):

                if len(result_indices) > 0:
                    if token_logits[2] > backup_tolerance:
                        result_indices.append(index)
                    else:
                        result_indices.append(index)
                        # Discard the candidate span if any of its tokens already belongs to a labeled result.
                        overlaps_labeled_result = False
                        if len(labeled_result_indices) > 0:
                            for candidate_index in result_indices:
                                for group in labeled_result_indices:
                                    for labeled_index in group:
                                        if candidate_index == labeled_index:
                                            overlaps_labeled_result = True
                        if not overlaps_labeled_result:
                            backup_result_indices.append(result_indices)

                        result_indices = []

                elif token_logits[1] > backup_tolerance:
                    result_indices.append(index)

        labeled_results = convert_indices_to_result_obj(labeled_result_indices)
        backup_results = convert_indices_to_result_obj(backup_result_indices)

        return {'labeled_results': labeled_results, 'backup_results': backup_results}
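

# --- Usage sketch (not part of the handler) ---
# A minimal local smoke test, assuming the current directory ("./") holds the fine-tuned
# model weights that the endpoint would normally pass in as `path`. The payload shape
# mirrors what __call__ pops from data['inputs']; the tolerance values are illustrative,
# not recommended defaults.
if __name__ == "__main__":
    handler = EndpointHandler(path="./")
    sample_payload = {
        "inputs": {
            "text": "Example document text to run span extraction over.",
            "label_tolerance": 0,
            "backup_tolerance": -1,
        }
    }
    results = handler(sample_payload)
    print(results["labeled_results"])
    print(results["backup_results"])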