from typing import Dict, List, Any
from transformers import AutoModel, AutoTokenizer
import torch


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the fine-tuned model from the endpoint's repository and switch
        # it to eval mode for inference.
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True)
        self.model.eval()
        # Load the LED base tokenizer used by the model.
        self.tokenizer = AutoTokenizer.from_pretrained('allenai/led-base-16384')


    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
        """
        Args:
            data (:obj:`dict`):
                The request payload. ``data["inputs"]`` should contain:
                - "text": the string to run extraction over.
                - "label_tolerance" (optional, default 0): logit threshold for labeled results.
                - "backup_tolerance" (optional): logit threshold for backup results; backup
                  extraction is skipped when it is not provided.
        Return:
            A :obj:`dict` with two keys, "labeled_results" and "backup_results". Each maps to
            a list of objects like {"text": "span text", "indices": [start, end]}, where
            "indices" are character offsets into the input text.
        """
        text = data['inputs'].pop("text", "")
        label_tolerance = data['inputs'].pop("label_tolerance", 0)
        backup_tolerance = data['inputs'].pop("backup_tolerance", None)

        # Tokenize the input and run the model without tracking gradients.
        inputs = self.preprocess_text(text)
        with torch.no_grad():
            outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

        # Group the token-level logits into labeled and backup spans.
        predictions = self.extract_results(
            input_ids=inputs['input_ids'][0].tolist(),
            offset_mapping=inputs['offset_mapping'],
            logits=outputs['logits'],
            label_tolerance=label_tolerance,
            backup_tolerance=backup_tolerance,
        )

        return predictions
    
    def preprocess_text(self, text):
        # Tokenize with offset mapping so predicted spans can be mapped back to
        # character positions in the original text.
        inputs = self.tokenizer(text, return_offsets_mapping=True)
        input_ids = torch.tensor([inputs["input_ids"]])
        attention_mask = torch.tensor([inputs["attention_mask"]])

        return {"input_ids": input_ids, "attention_mask": attention_mask, "offset_mapping": inputs["offset_mapping"]}

    def extract_results(self, input_ids, offset_mapping, logits, label_tolerance=0, backup_tolerance=None):

        def convert_indices_to_result_obj(indices_array):
            # Turn each group of token indices into {"text": ..., "indices": [start, end]},
            # using the offset mapping to recover character positions. The last index in a
            # group is inclusive, so slice one past it and take its token's end offset.
            result_array = []
            for result_indices in indices_array:
                text = self.tokenizer.decode(input_ids[result_indices[0]:result_indices[-1] + 1]).strip()
                indices = [offset_mapping[result_indices[0]][0], offset_mapping[result_indices[-1]][1]]
                if text:
                    result_array.append({'text': text, 'indices': indices})
            return result_array

        # Extract labeled results first. Logit 1 acts as a span-start score and
        # logit 2 as a span-continuation score: a span opens when logit 1 clears
        # the tolerance and stays open while logit 2 does.
        labeled_result_indices = []
        result_indices = []
        for index, token_logits in enumerate(logits.tolist()[0]):
            if result_indices:
                if token_logits[2] > label_tolerance:
                    result_indices.append(index)
                else:
                    labeled_result_indices.append(result_indices)
                    result_indices = []
            elif token_logits[1] > label_tolerance:
                result_indices.append(index)

        # Flush a span that runs to the end of the sequence.
        if result_indices:
            labeled_result_indices.append(result_indices)
        
        # Extract backup results using the separate backup tolerance, skipping
        # any span that overlaps a labeled result.
        backup_result_indices = []
        result_indices = []
        if backup_tolerance:
            labeled_index_set = {i for group in labeled_result_indices for i in group}
            for index, token_logits in enumerate(logits.tolist()[0]):
                if result_indices:
                    if token_logits[2] > backup_tolerance:
                        result_indices.append(index)
                    else:
                        # Keep the span only if none of its tokens belong to a labeled result.
                        if not labeled_index_set.intersection(result_indices):
                            backup_result_indices.append(result_indices)
                        result_indices = []
                elif token_logits[1] > backup_tolerance:
                    result_indices.append(index)

            # Flush a trailing span, applying the same overlap check.
            if result_indices and not labeled_index_set.intersection(result_indices):
                backup_result_indices.append(result_indices)

        # Convert both labeled and backup index groups to result objects.
        labeled_results = convert_indices_to_result_obj(labeled_result_indices)
        backup_results = convert_indices_to_result_obj(backup_result_indices)

        return {'labeled_results': labeled_results, 'backup_results': backup_results}
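
# A minimal local smoke test, assuming a checkpoint directory at "./model"; the
# path and the sample payload below are illustrative only and are not part of
# the endpoint contract.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # hypothetical local checkpoint path
    payload = {
        "inputs": {
            "text": "Example document to run span extraction over.",
            "label_tolerance": 0,
            "backup_tolerance": -1.0,  # looser threshold to surface backup spans
        }
    }
    result = handler(payload)
    print(result["labeled_results"])
    print(result["backup_results"])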