File size: 4,638 Bytes
bf66e5a
366e62e
bf66e5a
 
a36be93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf66e5a
 
a36be93
bf66e5a
355a0ec
000ad8b
a36be93
 
bf66e5a
216cf30
bf66e5a
a36be93
 
 
 
 
bf66e5a
 
4c4f932
bf66e5a
a36be93
4c4f932
66e62c6
000ad8b
a36be93
 
813fd4a
355a0ec
a36be93
 
813fd4a
 
a36be93
 
6d8b690
a36be93
 
 
bf66e5a
 
 
 
 
 
 
 
 
 
366e62e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, StoppingCriteria, StoppingCriteriaList


# class EndpointHandler():
#     def __init__(self, path=""):
#         tokenizer = AutoTokenizer.from_pretrained(path)
#         tokenizer.pad_token = tokenizer.eos_token
#         self.model = AutoModelForCausalLM.from_pretrained(path).to('cuda')
#         self.tokenizer = tokenizer
#         self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])

#     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
#         """
#        data args:
#             inputs (:obj: `str`)
#             kwargs
#       Return:
#             A :obj:`list` | `dict`: will be serialized and returned
#         """
#         inputs = data.pop("inputs", data)
#         additional_bad_words_ids = data.pop("additional_bad_words_ids", [])


#         # 3070, 10456, [313, 334] corresponds to "(*", and we do not want to output a comment
#         # 13 is a newline character
#         # [1976, 441, 29889], [4920, 441, 29889] is "Abort." [4920, 18054, 29889] is "Aborted."
#         # [2087, 29885, 4430, 29889] is "Admitted."
#         bad_words_ids = [[3070], [313, 334], [10456], [13], [1976, 441, 29889], [2087, 29885, 4430, 29889], [4920, 441], [4920, 441, 29889], [4920, 18054, 29889]]
#         bad_words_ids.extend(additional_bad_words_ids)

#         input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to('cuda')
#         max_generation_length = 75  # Desired number of tokens to generate
#         # max_input_length = 4092 - max_generation_length  # Maximum input length to allow space for generation

#         # # Truncate input_ids to the most recent tokens that fit within the max_input_length
#         # if input_ids.shape[1] > max_input_length:
#         #     input_ids = input_ids[:, -max_input_length:]

#         max_length = input_ids.shape[1] + max_generation_length
        
#         generated_ids = self.model.generate(
#             input_ids,
#             max_length=max_length,  # 50 new tokens
#             bad_words_ids=bad_words_ids,
#             temperature=1,
#             top_k=40,
#             do_sample=True,
#             stopping_criteria=self.stopping_criteria,
#         )

#         generated_text = self.tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
#         prediction = [{"generated_text": generated_text, "generated_ids": generated_ids[0][input_ids.shape[1]:].tolist()}]
#         return prediction

class EndpointHandler:
    """Inference handler that serves a causal LM through a ``transformers``
    text-generation pipeline.

    Each call samples a short continuation of the prompt while banning a fixed
    set of token-id sequences, plus any sequences supplied by the caller.
    """

    # Token-id sequences that must never be generated. Per the notes kept with
    # the earlier (commented-out) implementation of this handler:
    #   3070, 10456 and [313, 334] correspond to "(*" (we do not want to emit a
    #   comment); 13 is a newline; [1976, 441, 29889] and [4920, 441, 29889] are
    #   "Abort."; [4920, 18054, 29889] is "Aborted."; [2087, 29885, 4430, 29889]
    #   is "Admitted.".
    # NOTE(review): these ids are tied to the tokenizer they were derived from;
    # confirm they still match the model loaded from `path`.
    DEFAULT_BAD_WORDS_IDS = (
        (3070,),
        (313, 334),
        (10456,),
        (13,),
        (1976, 441, 29889),
        (2087, 29885, 4430, 29889),
        (4920, 441),
        (4920, 441, 29889),
        (4920, 18054, 29889),
    )

    # Number of new tokens generated per request.
    MAX_NEW_TOKENS = 75

    def __init__(self, path=""):
        self.model_path = path
        tokenizer = AutoTokenizer.from_pretrained(path)
        # Reuse EOS as the padding token.
        tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer
        # device=0 places the pipeline's model on the first CUDA device.
        self.text_generation_pipeline = pipeline(
            "text-generation", model=path, tokenizer=self.tokenizer, device=0
        )

    @classmethod
    def _build_generation_kwargs(cls, pad_token_id, additional_bad_words_ids=None):
        """Return the keyword arguments forwarded to the pipeline call.

        Args:
            pad_token_id: Token id used for padding during generation.
            additional_bad_words_ids: Extra banned token-id sequences, appended
                after ``DEFAULT_BAD_WORDS_IDS``. ``None`` is treated as empty.

        Returns:
            A fresh dict on every call, so callers may mutate it safely.
        """
        bad_words_ids = [list(ids) for ids in cls.DEFAULT_BAD_WORDS_IDS]
        bad_words_ids.extend(additional_bad_words_ids or [])
        return {
            "max_new_tokens": cls.MAX_NEW_TOKENS,
            # temperature/top_k only take effect in sampling mode. Without
            # do_sample=True they are ignored unless the model's generation
            # config happens to enable sampling.
            "do_sample": True,
            "temperature": 0.7,
            "top_k": 40,
            "bad_words_ids": bad_words_ids,
            "pad_token_id": pad_token_id,
            "return_full_text": False,  # only return the newly generated text
        }

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a continuation for the request's prompt.

        Args:
            data: Request payload.
                ``"inputs"`` holds the prompt; if it is absent, the whole
                payload is used as the input.
                Optional ``"additional_bad_words_ids"`` is a list of token-id
                lists to ban on top of the defaults.
                Both keys are popped, so ``data`` is mutated.

        Returns:
            A list with one ``{"generated_text": ...}`` dict per element
            returned by the pipeline.
        """
        inputs = data.pop("inputs", data)
        additional_bad_words_ids = data.pop("additional_bad_words_ids", [])

        generation_kwargs = self._build_generation_kwargs(
            self.tokenizer.eos_token_id, additional_bad_words_ids
        )
        generated_outputs = self.text_generation_pipeline(inputs, **generation_kwargs)

        return [
            {"generated_text": output["generated_text"]}
            for output in generated_outputs
        ]


class StopAtPeriodCriteria(StoppingCriteria):
    """Stopping criterion that ends a sequence once its newest token contains
    a period.

    Intended to be wrapped in a ``StoppingCriteriaList`` and passed to
    ``generate``.
    """

    def __init__(self, tokenizer):
        # Tokenizer used to turn the newest token id back into text.
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        """Return one stop flag per sequence in the batch.

        Args:
            input_ids: Token ids generated so far. The code indexes it as
                ``input_ids[:, -1]``, so it is assumed to be
                (batch_size, seq_len).
            scores: Unused.

        Returns:
            A boolean tensor of shape (batch_size,). An entry is True when that
            row's last token decodes (special tokens skipped) to text
            containing ``"."``.
        """
        import torch

        # Decode each row's last token on its own. Decoding input_ids[:, -1]
        # as a single sequence concatenates tokens from different rows, so one
        # row's period would stop the whole batch.
        last_tokens = input_ids[:, -1].tolist()
        stop_flags = [
            "." in self.tokenizer.decode([token], skip_special_tokens=True)
            for token in last_tokens
        ]
        return torch.tensor(stop_flags, dtype=torch.bool, device=input_ids.device)