from collections import deque
from collections.abc import Iterator
import itertools as it
import random
from threading import Thread
from typing import Literal, Optional, TypedDict

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)


class ChatMLMessage(TypedDict):
    # TypedDict fields cannot declare default values, so "name" is typed as
    # Optional instead of being assigned a default here.
    name: Optional[str]
    role: Literal["assistant", "system", "user"]
    content: str


ChatML = list[ChatMLMessage]
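
# A conversation is just a list of these messages. Purely illustrative example
# (names and content are made up, not taken from the model card):
#
#     chat: ChatML = [
#         {"role": "system", "name": "situation", "content": "Talking to a friend."},
#         {"role": "user", "name": "Dave", "content": "Hey Samantha!"},
#     ]
#
# For system messages the "name" field becomes the section title in the prompt
# (see message_role_to_prefix below).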


AGENT_NAME: str = "Samantha"


model_id = "julep-ai/samantha-33b"
tokenizer_id = "julep-ai/samantha-33b"

print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, use_fast=False)

# Warm-up pass: run a tiny generation once so the first real request does not
# pay the one-time initialization cost.
model.generate(**tokenizer("Hello", return_tensors="pt").to(0), max_new_tokens=2)

print("Model loaded")


class StopSequenceCriteria(StoppingCriteria):
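    """Stop generation as soon as any of the configured stop strings appears in
    the text generated after the prompt; the first `input_length` tokens are
    the prompt and are ignored."""
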
    def __init__(
        self,
        tokenizer,
        stop: list[str],
        input_length,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.stop = stop
        self.tokenizer = tokenizer
        self.input_length = input_length

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:
        input_ids = input_ids.long().tolist()
        # Drop the prompt tokens; only inspect what the model has generated.
        new_input_ids = [i[self.input_length:] for i in input_ids]

        for text in self.stop:
            generated_so_far = ""

            for input_id in new_input_ids:
                decoded = self.tokenizer.decode(input_id, skip_special_tokens=False)
                generated_so_far += decoded

            if text in generated_so_far:
                return True

        return False


def message_role_to_prefix(message: ChatMLMessage) -> str:
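    """Map a ChatML message to its section title in the prompt: system messages
    use their `name` verbatim, user messages become "person (<name>)", and
    assistant messages become "me (<name>)"."""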
    match message:
        case {"role": "system", "name": name, **rest}:
            return name
        case {"role": "user", "name": name, **rest}:
            return f"person ({name})" if name else "person"
        case {"role": "assistant", "name": name, **rest}:
            return f"me ({name})" if name else "me"


def to_prompt(
    messages: ChatML,
    bos: str = "<|section|>",
    eos: str = "<|endsection|>",
    suffix: str = f"\n<|section|>me ({AGENT_NAME})\n",
) -> str:
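    """Render a ChatML conversation into the section-delimited prompt string,
    ending with an open "me" section for the model to complete.

    For example, a single user message from "Alice" saying "Hi" (illustrative
    values) renders as:

        <|section|>person (Alice)
        Hi<|endsection|>
        <|section|>me (Samantha)
    """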
    prompt = "\n".join([
        f"{bos}{message_role_to_prefix(message)}\n{message['content']}{eos}"
        for message in messages
    ])

    return prompt + suffix


def groupwise(iterable, n):
    """Like itertools.pairwise but for n elements"""
    accum = deque((), n)
    count = 0

    for element in iterable:
        accum.append(element)
        count += 1

        if len(accum) == n:
            yield tuple(accum)

    if count < n:
        yield tuple(accum)


def wrap_iterator(iterator):
    """Plain passthrough generator over the given iterator."""
    for item in iterator:
        yield item


def remove_stops(iterator, tokenizer, stop: list[str] = []):
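    """Yield text chunks from `iterator`, cutting the stream off once any stop
    string is seen. A sliding window of recent chunks is kept so a stop string
    split across chunk boundaries is still caught."""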
    if not stop:
        yield from iterator
        return

    # Window size: the longest stop sequence, measured in tokens.
    look_ahead = max([
        len(tokenizer.encode(s, add_special_tokens=False))
        for s in stop
    ])

    for items in groupwise(iterator, look_ahead):
        joined = "".join(items).strip()
        has_stop_sequence = {s: joined.endswith(s) for s in stop}

        if any(has_stop_sequence.values()):
            offending_sequence = next(s for s, is_bad in has_stop_sequence.items() if is_bad)

            # Emit whatever precedes the stop sequence, then end the stream.
            yield joined.split(offending_sequence)[0]
            return

        # No stop sequence yet: emit the oldest chunk in the window.
        first, *_ = items

        if first.strip():
            yield first


def generate(
    messages: ChatML,
    stop: list[str] = [],
    timeout: int = 15,
    stream: bool = False,
    **kwargs,
) -> Iterator[str] | str:
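    """Generate a reply for the given ChatML messages.

    With stream=False, block and return the decoded completion as a string.
    With stream=True, run generation on a background thread and return an
    iterator of text chunks with stop sequences stripped. Extra keyword
    arguments are forwarded to model.generate().
    """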
    prompt = to_prompt(messages)
    inputs = tokenizer(prompt, return_tensors="pt").to(0)
    input_length = len(inputs["input_ids"].squeeze().tolist())

    # Only attach stopping criteria when stop sequences were actually given.
    stopping_criteria = (
        StoppingCriteriaList([StopSequenceCriteria(
            tokenizer=tokenizer,
            stop=stop,
            input_length=input_length,
        )])
        if stop else None
    )

    generation_kwargs = {
        # Conservative defaults for short conversational replies.
        "max_new_tokens": 40,
        "repetition_penalty": 1.02,
        "no_repeat_ngram_size": 4,
        "renormalize_logits": True,
        "temperature": 1.1,
        # Caller-supplied settings override the defaults above.
        **kwargs,
        "stopping_criteria": stopping_criteria,
        **inputs,
    }

    if not stream:
        # Blocking path: generate, decode only the new tokens, and trim the
        # result at the first stop sequence.
        [output] = model.generate(**generation_kwargs)
        result = tokenizer.decode(output[input_length:], skip_special_tokens=False)

        for s in stop:
            result = result.split(s)[0].strip()

        return result

    # Streaming path: generate on a background thread and consume chunks from
    # the streamer, filtering out stop sequences as they arrive.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=timeout, skip_special_tokens=False)
    generation_kwargs["streamer"] = streamer

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    return remove_stops(streamer, tokenizer, stop)


if __name__ == "__main__":
    user_name: str = input("Enter your name: ")
    message: str = input("Enter your message: ")
    chatml = [
        ChatMLMessage(role="user", name=user_name, content=message),
    ]

    prompt = to_prompt(chatml)

    llm_settings = dict(
        max_new_tokens=80,
        stop=["<|", "\n\n"],
        temperature=1.2,
    )

    response_stream = generate(
        chatml,
        stream=True,
        **llm_settings,
    )

    for m in response_stream:
        print(m, end=" ", flush=True)