llemma_7b / handler.py
Pierce Maloney
back to normal, removed additional bad words argument
216cf30
raw
history blame
2.02 kB
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, StoppingCriteria, StoppingCriteriaList
class EndpointHandler():
    """Custom inference handler that continues a text prompt with a causal LM.

    The tokenizer and model are loaded once at construction time. Each call
    encodes the prompt, generates at most ``MAX_NEW_TOKENS`` additional
    tokens, and returns only the newly generated text (the prompt is
    stripped from the output).
    """

    # Upper bound on the number of tokens generated per request.
    MAX_NEW_TOKENS = 50

    # Token-id sequences that generation must never produce. Per the original
    # author's note, ids 3070 and 10456 correspond to "(*"; they are banned so
    # the model does not output a comment.
    # NOTE(review): [313, 334] is presumably the two-token spelling of "(*"
    # -- confirm against the tokenizer vocabulary.
    BAD_WORDS_IDS = ((3070,), (313, 334), (10456,))

    def __init__(self, path=""):
        """Load the tokenizer and model found at ``path``.

        Args:
            path: Directory (or identifier) handed to ``from_pretrained``
                for both the tokenizer and the model.
        """
        # Preload everything needed at inference time.
        tokenizer = AutoTokenizer.from_pretrained(path)
        # Reuse the end-of-sequence token as the padding token.
        tokenizer.pad_token = tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.tokenizer = tokenizer
        self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a continuation for the prompt carried in ``data``.

        Args:
            data: Request payload. ``data["inputs"]`` must be the prompt
                string; the key is popped from the dict.

        Returns:
            A single-element list ``[{"generated_text": <continuation>}]``
            holding only the newly generated text.

        Raises:
            ValueError: If the prompt is missing or is not a string.
        """
        inputs = data.pop("inputs", data)
        # When "inputs" is absent the fallback is the payload dict itself;
        # fail with a clear message instead of an opaque tokenizer error.
        if not isinstance(inputs, str):
            raise ValueError(
                f'"inputs" must be a string prompt, got {type(inputs).__name__}'
            )

        input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
        prompt_length = input_ids.shape[1]

        generated_ids = self.model.generate(
            input_ids,
            # Same budget as max_length = prompt length + 50.
            max_new_tokens=self.MAX_NEW_TOKENS,
            # Setting tokenizer.pad_token does not reach generate(); pass the
            # id explicitly.
            pad_token_id=self.tokenizer.pad_token_id,
            bad_words_ids=[list(ids) for ids in self.BAD_WORDS_IDS],
            # NOTE(review): temperature / top_k only influence the output
            # when sampling is enabled (do_sample=True here or in the model's
            # generation config) -- confirm the intended decoding mode.
            temperature=1,
            top_k=40,
            stopping_criteria=self.stopping_criteria,
        )

        # Keep only the tokens produced after the prompt.
        new_token_ids = generated_ids[0, prompt_length:]
        generated_text = self.tokenizer.decode(new_token_ids, skip_special_tokens=True)
        return [{"generated_text": generated_text}]
class StopAtPeriodCriteria(StoppingCriteria):
    """Stopping criterion that ends generation once a period has been produced."""

    def __init__(self, tokenizer):
        # Kept so the most recent token ids can be turned back into text.
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when the newest decoded token text contains a '.'.

        Only the last column of ``input_ids`` is inspected; ``scores`` and
        any extra keyword arguments are ignored.
        """
        # The last id of every row is handed to a single decode call.
        # NOTE(review): with more than one row this presumably means a period
        # in any row stops generation -- confirm if batching is ever used.
        newest_ids = input_ids[:, -1]
        newest_text = self.tokenizer.decode(newest_ids, skip_special_tokens=True)
        # Containment test: the period may sit anywhere in the decoded text.
        contains_period = "." in newest_text
        return contains_period