from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, StoppingCriteria, StoppingCriteriaList

# Legacy handler kept for reference: calls model.generate directly and stops
# at the first period via StopAtPeriodCriteria (defined at the bottom of this file).
# class EndpointHandler():
#     def __init__(self, path=""):
#         tokenizer = AutoTokenizer.from_pretrained(path)
#         tokenizer.pad_token = tokenizer.eos_token
#         self.model = AutoModelForCausalLM.from_pretrained(path).to('cuda')
#         self.tokenizer = tokenizer
#         self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
#
#     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
#         """
#         data args:
#             inputs (:obj:`str`)
#             kwargs
#         Return:
#             A :obj:`list` | `dict`: will be serialized and returned
#         """
#         inputs = data.pop("inputs", data)
#         additional_bad_words_ids = data.pop("additional_bad_words_ids", [])
#         # [3070], [10456], [313, 334] correspond to "(*", and we do not want to output a comment
#         # [13] is a newline character
#         # [1976, 441, 29889] and [4920, 441, 29889] are "Abort."; [4920, 18054, 29889] is "Aborted."
#         # [2087, 29885, 4430, 29889] is "Admitted."
#         bad_words_ids = [[3070], [313, 334], [10456], [13], [1976, 441, 29889], [2087, 29885, 4430, 29889], [4920, 441], [4920, 441, 29889], [4920, 18054, 29889]]
#         bad_words_ids.extend(additional_bad_words_ids)
#         input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to('cuda')
#         max_generation_length = 75  # desired number of new tokens to generate
#         # max_input_length = 4092 - max_generation_length  # cap input to leave room for generation
#         # # Truncate input_ids to the most recent tokens that fit within max_input_length
#         # if input_ids.shape[1] > max_input_length:
#         #     input_ids = input_ids[:, -max_input_length:]
#         max_length = input_ids.shape[1] + max_generation_length
#         generated_ids = self.model.generate(
#             input_ids,
#             max_length=max_length,  # prompt length + 75 new tokens
#             bad_words_ids=bad_words_ids,
#             temperature=1,
#             top_k=40,
#             do_sample=True,
#             stopping_criteria=self.stopping_criteria,
#         )
#         generated_text = self.tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
#         prediction = [{"generated_text": generated_text, "generated_ids": generated_ids[0][input_ids.shape[1]:].tolist()}]
#         return prediction

class EndpointHandler():
    def __init__(self, path=""):
        self.model_path = path
        tokenizer = AutoTokenizer.from_pretrained(path)
        tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer
        # Initialize the pipeline for text generation
        self.text_generation_pipeline = pipeline(
            "text-generation", model=path, tokenizer=self.tokenizer, device=0
        )  # device=0 selects the first CUDA GPU

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`)
            kwargs
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        inputs = data.pop("inputs", data)
        additional_bad_words_ids = data.pop("additional_bad_words_ids", [])
        # Token id sequences to ban from the output (tokenizer-specific):
        # [3070], [10456], [313, 334] correspond to "(*", so no comments are emitted;
        # [13] is a newline; [1976, 441, 29889] and [4920, 441, 29889] are "Abort.";
        # [4920, 18054, 29889] is "Aborted."; [2087, 29885, 4430, 29889] is "Admitted."
        bad_words_ids = [[3070], [313, 334], [10456], [13], [1976, 441, 29889], [2087, 29885, 4430, 29889], [4920, 441], [4920, 441, 29889], [4920, 18054, 29889]]
        bad_words_ids.extend(additional_bad_words_ids)
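        # Hedged sketch (not executed here): the hard-coded ids above could be
        # re-derived for any other banned string with this handler's own tokenizer:
        #     extra_ids = self.tokenizer("Abort.", add_special_tokens=False).input_ids
        #     bad_words_ids.append(extra_ids)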
        # Generate text using the pipeline
        generation_kwargs = {
            "max_new_tokens": 75,
            "do_sample": True,  # required; temperature/top_k are ignored under greedy decoding
            "temperature": 0.7,
            "top_k": 40,
            "bad_words_ids": bad_words_ids,
            "pad_token_id": self.tokenizer.eos_token_id,
            "return_full_text": False,  # only return the newly generated tokens
        }
        generated_outputs = self.text_generation_pipeline(inputs, **generation_kwargs)
        # Format the output
        predictions = [{"generated_text": output["generated_text"]} for output in generated_outputs]
        return predictions
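
# Illustrative request payload for this handler (the Coq snippet and the extra
# ids below are made-up examples, not taken from a real request):
#
#   {"inputs": "Lemma add_0_r : forall n, n + 0 = n.\nProof.",
#    "additional_bad_words_ids": [[29871]]}
#
# The response is serialized as a list like [{"generated_text": "..."}].
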
class StopAtPeriodCriteria(StoppingCriteria):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the most recently generated token
        last_token_text = self.tokenizer.decode(input_ids[:, -1], skip_special_tokens=True)
        # Stop as soon as the decoded token contains a period
        return '.' in last_token_text
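

if __name__ == "__main__":
    # Minimal local smoke test, a sketch only: "./model" is a hypothetical
    # checkpoint path and a CUDA device is assumed, since the pipeline is built
    # with device=0. The Inference Endpoints runtime imports this file and
    # constructs EndpointHandler itself, so this block never runs there.
    handler = EndpointHandler(path="./model")
    print(handler({"inputs": "Lemma add_0_r : forall n, n + 0 = n.\nProof."}))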