File size: 5,610 Bytes
f7a2749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fed03d4
 
 
f7a2749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd09245
3535ad1
4c87f22
 
 
 
 
 
 
4983fe7
f7a2749
 
 
 
 
 
 
 
 
 
 
 
3535ad1
f7a2749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3535ad1
f7a2749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf8eaa2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from typing import  Dict, List, Any
from transformers import AutoModel, AutoTokenizer
import torch


class EndpointHandler:
    """Inference handler that extracts text spans from per-token model logits.

    Each token's logits are read positionally: the score at index 1 is used as
    the "a span begins here" test and the score at index 2 as the "the span
    continues" test (NOTE(review): confirm this matches the model's label map).
    """

    def __init__(self, path="", tokenizer_name="allenai/led-base-16384"):
        """Load the model and the tokenizer used for encoding and decoding.

        Args:
            path: Model directory or hub id passed to ``AutoModel.from_pretrained``.
            tokenizer_name: Tokenizer to load. The default is the tokenizer this
                handler has always used.
        """
        # trust_remote_code=True lets transformers run model code bundled with the
        # checkpoint (presumably a custom architecture — confirm).
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True)
        self.model.eval()  # inference only (e.g. disables dropout)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def __call__(self, data: Any) -> Dict[str, List[Dict[str, Any]]]:
        """Run span extraction for one request.

        Args:
            data: Request payload. ``data["inputs"]`` is either the raw text or a
                dict with the optional keys:
                  - ``"text"`` (default ``""``): text to analyse.
                  - ``"label_tolerance"`` (default ``0``): threshold for primary spans.
                  - ``"backup_tolerance"`` (default ``None``): threshold for backup
                    spans; ``None`` disables backup extraction.

        Returns:
            ``{"labeled_results": [...], "backup_results": [...]}``. Each item is
            ``{"text": str, "indices": [start, end]}``, where the indices are
            character offsets derived from the tokenizer's offset mapping.
        """
        request = data["inputs"]
        if isinstance(request, str):
            # Also accept the plain {"inputs": "<text>"} payload shape.
            request = {"text": request}
        # .get instead of .pop so the caller's payload is not mutated.
        text = request.get("text", "")
        label_tolerance = request.get("label_tolerance", 0)
        backup_tolerance = request.get("backup_tolerance", None)

        inputs = self.preprocess_text(text)
        # Gradients are never needed here; skip building the autograd graph.
        with torch.no_grad():
            outputs = self.model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )

        return self.extract_results(
            input_ids=inputs["input_ids"][0].tolist(),
            offset_mapping=inputs["offset_mapping"],
            logits=outputs["logits"],
            label_tolerance=label_tolerance,
            backup_tolerance=backup_tolerance,
        )

    def preprocess_text(self, text):
        """Tokenize ``text`` as a batch of one.

        Returns:
            dict with:
              - ``"input_ids"`` and ``"attention_mask"``: ``(1, seq_len)`` tensors.
              - ``"offset_mapping"``: the tokenizer's per-token ``(start, end)``
                character offsets, as an unbatched list.
        """
        encoded = self.tokenizer(text, return_offsets_mapping=True)
        return {
            "input_ids": torch.tensor([encoded["input_ids"]]),
            "attention_mask": torch.tensor([encoded["attention_mask"]]),
            "offset_mapping": encoded["offset_mapping"],
        }

    def extract_results(self, input_ids, offset_mapping, logits, label_tolerance=0, backup_tolerance=None):
        """Group token scores into spans and convert them to text results.

        Args:
            input_ids: Token ids of the single sequence (list of int).
            offset_mapping: Per-token ``(start, end)`` character offsets.
            logits: Scores indexed as ``logits[0][token][label]``.
            label_tolerance: Threshold a score must exceed for primary spans.
            backup_tolerance: Threshold for backup spans, or ``None`` to skip them.
                ``0`` is a valid threshold and is honoured.

        Returns:
            ``{"labeled_results": [...], "backup_results": [...]}``. A backup span
            that shares any token index with a labeled span is dropped.
        """
        token_scores = logits[0].tolist()

        labeled_groups = self._collect_spans(token_scores, label_tolerance)

        backup_groups = []
        if backup_tolerance is not None:
            # Every token index owned by a labeled span, including its end marker.
            labeled_tokens = {index for group in labeled_groups for index in group}
            backup_groups = [
                group
                for group in self._collect_spans(token_scores, backup_tolerance)
                if labeled_tokens.isdisjoint(group)
            ]

        return {
            "labeled_results": self._groups_to_results(labeled_groups, input_ids, offset_mapping),
            "backup_results": self._groups_to_results(backup_groups, input_ids, offset_mapping),
        }

    @staticmethod
    def _collect_spans(token_scores, tolerance):
        """Return one list of token indices per detected span.

        Opening and extending a span:
          - A span opens at a token whose score at index 1 exceeds ``tolerance``.
          - It extends over following tokens whose score at index 2 exceeds it.

        Closing a span:
          - The first token that fails the continuation test is appended as an
            exclusive end marker, which closes the span.
          - That same token may then open the next span, so adjacent spans are
            kept.
          - A span still open when the sequence ends is kept as-is; its last
            index then serves as the end marker.
        """
        groups = []
        current = []
        for index, scores in enumerate(token_scores):
            if current:
                current.append(index)
                if scores[2] > tolerance:
                    continue
                groups.append(current)
                current = []
            if scores[1] > tolerance:
                current = [index]
        if current:
            groups.append(current)
        return groups

    def _groups_to_results(self, groups, input_ids, offset_mapping):
        """Convert index groups into ``{"text": ..., "indices": [start, end]}`` dicts.

        Group layout:
          - The last index of each group is an exclusive end marker, so the
            span's tokens are ``group[0]`` through ``group[-2]``.
          - Groups with fewer than two indices carry no tokens and are skipped.

        Decoded texts that are empty or whitespace-only are skipped as well.
        """
        results = []
        for group in groups:
            if len(group) < 2:
                continue
            first, last = group[0], group[-1]
            text = self.tokenizer.decode(input_ids[first:last])
            if text == "" or text.isspace():
                continue
            # Start at the end of the preceding token (0 when the span starts the
            # sequence). Leading spaces in the decoded text are then stripped and
            # the start is advanced past them. This assumes those spaces sit
            # between the two tokens in the source text — TODO confirm for
            # non-BPE tokenizers.
            start = offset_mapping[first - 1][1] if first > 0 else 0
            end = offset_mapping[group[-2]][1]
            stripped = text.lstrip(" ")
            start += len(text) - len(stripped)
            results.append({"text": stripped, "indices": [start, end]})
        return results