llemma_7b / handler.py
from typing import Any, Dict, List

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessor,
    LogitsProcessorList,
    StoppingCriteria,
    StoppingCriteriaList,
    pipeline,
)

class EndpointHandler():
    def __init__(self, path=""):
        # Preload all the elements you are going to need at inference.
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(path)
        tokenizer.pad_token = tokenizer.eos_token
        self.pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer)
        # Stop generating once a period appears in the decoded output.
        self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
        # Ban token id 3070 ("(*") so the model cannot open a comment.
        self.logits_processor = LogitsProcessorList([BanSpecificTokensLogitsProcessor(tokenizer, [3070])])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`): the prompt to complete
            kwargs: ignored by this handler
        Return:
            A :obj:`list` | :obj:`dict`: will be serialized and returned
        """
        inputs = data.pop("inputs", data)
        # Bad word: id 3070 corresponds to "(*", and we do not want to
        # output a comment.
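        # (Illustrative aside, not from the original file: the id for a given
        # string can be checked with something like
        # self.pipeline.tokenizer.encode("(*", add_special_tokens=False).)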
        prediction = self.pipeline(
            inputs,
            stopping_criteria=self.stopping_criteria,
            max_new_tokens=50,
            return_full_text=False,
            # bad_words_ids=[[3070], [313, 334]],
            logits_processor=self.logits_processor,
            temperature=1,
            top_k=40,
        )
        # The text-generation pipeline returns a list of dicts, e.g.
        # [{"generated_text": "..."}].
        return prediction

class StopAtPeriodCriteria(StoppingCriteria):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Decode the most recently generated token to text. Note: with a
        # batch size above 1, input_ids[:, -1] decodes every sequence's last
        # token joined together, so this criterion assumes a single sequence.
        last_token_text = self.tokenizer.decode(input_ids[:, -1], skip_special_tokens=True)
        # Stop as soon as the decoded token contains a period (this covers
        # tokens like "." or ".\n", not only those ending with one).
        return '.' in last_token_text

class BanSpecificTokensLogitsProcessor(LogitsProcessor):
    """
    Logits processor that sets the logits of specific tokens to -inf,
    effectively banning them from being generated.
    """

    def __init__(self, tokenizer, banned_tokens_ids):
        self.tokenizer = tokenizer
        self.banned_tokens_ids = banned_tokens_ids

    def __call__(self, input_ids, scores):
        # Set logits of banned tokens to -inf
        for token_id in self.banned_tokens_ids:
            scores[:, token_id] = float('-inf')
        return scores
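

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the deployed handler): a minimal
# local invocation, assuming a checkpoint directory "./llemma_7b" and a
# hypothetical Lean-style prompt. An Inference Endpoint calls the handler the
# same way, passing the request payload as `data`.
#
#   handler = EndpointHandler(path="./llemma_7b")
#   output = handler({"inputs": "theorem add_comm (a b : nat) : a + b = b + a :="})
#   print(output)  # e.g. [{"generated_text": "..."}]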