# samantha-33b / chat.py
from collections import deque
import itertools as it
import random
from threading import Thread
from typing import Iterator, Literal, Optional, TypedDict
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer,
)
###########
## Types ##
###########
class ChatMLMessage(TypedDict):
    name: Optional[str]  # no default: TypedDict fields can't take default values; None means "unnamed"
role: Literal["assistant", "system", "user"]
content: str
ChatML = list[ChatMLMessage]
# Example:
# [
# {"role": "system", "name": "situation", "content": "I am talking to Diwank"},
# {"role": "assistant", "name": "Samantha", "content": "Hey Diwank"},
# {"role": "user", "name": "Diwank", "content": "Hey!"},
# ]
############
## Consts ##
############
AGENT_NAME: str = "Samantha"
###########
## Model ##
###########
# assistant_model_id = "julep-ai/samantha-7b-ds-03"
# assistant_model = AutoModelForCausalLM.from_pretrained(assistant_model_id, torch_dtype=torch.bfloat16, device_map="auto")
# Load model and tokenizer
model_id = "julep-ai/samantha-33b"
tokenizer_id = "julep-ai/samantha-33b"
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, use_fast=False, clean_up_tokenization_spaces=True)
# warmup
model.generate(**tokenizer("Hello", return_tensors="pt").to(0), max_new_tokens=2)
print("Model loaded")
##############
## Generate ##
##############
class StopSequenceCriteria(StoppingCriteria):
def __init__(
self,
tokenizer,
stop: list[str],
input_length,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.stop = stop
self.tokenizer = tokenizer
self.input_length = input_length
def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
**kwargs,
) -> bool:
input_ids = input_ids.long().tolist()
new_input_ids = [i[self.input_length:] for i in input_ids]
for text in self.stop:
generated_so_far = ""
for input_id in new_input_ids:
decoded = self.tokenizer.decode(input_id, skip_special_tokens=False)
generated_so_far += decoded
if text in generated_so_far:
return True
return False
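
# Illustrative sketch (comments only, not executed): the criteria above decodes
# only the newly generated tokens and reports True as soon as any stop string
# shows up in them, e.g.
#
#   criteria = StopSequenceCriteria(tokenizer=tokenizer, stop=["<|endsection|>"], input_length=input_length)
#   stopping = StoppingCriteriaList([criteria])  # passed as stopping_criteria to model.generate(...)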
def message_role_to_prefix(message: ChatMLMessage) -> str:
match message:
case {"role": "system", "name": name, **rest}:
return name
case {"role": "user", "name": name, **rest}:
return f"person ({name})" if name else "person"
case {"role": "assistant", "name": name, **rest}:
return f"me ({name})" if name else "me"
def to_prompt(
messages: ChatML,
bos: str = "<|section|>",
eos: str = "<|endsection|>",
suffix: str = f"\n<|section|>me ({AGENT_NAME})\n",
) -> str:
# Input format:
# [
# {"role": "system", "name": "situation", "content": "I am talking to Diwank"},
# {"role": "assistant", "name": "Samantha", "content": "Hey Diwank"},
# {"role": "user", "name": "Diwank", "content": "Hey!"},
# ]
# Output format:
#
# <|section|>situation
# I am talking to Diwank<|endsection|>
# <|section|>me (Samantha)
# Hey Diwank<|endsection|>
# <|section|>person (Diwank)
# Hey<|endsection|>
# <|section|>me (Samantha)\n
prompt = "\n".join([
f"{bos}{message_role_to_prefix(message)}\n{message['content']}{eos}"
for message in messages
])
return prompt + suffix
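
# Illustrative example: a single user message renders as
#
#   to_prompt([{"role": "user", "name": "Diwank", "content": "Hey!"}])
#   == "<|section|>person (Diwank)\nHey!<|endsection|>\n<|section|>me (Samantha)\n"
#
# i.e. the suffix leaves an open "me (Samantha)" section for the model to complete.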
def groupwise(iterable, n):
"""Like itertools.pairwise but for n elements"""
accum = deque((), n)
count = 0
for element in iterable:
accum.append(element)
count += 1
if len(accum) == n:
yield tuple(accum)
if count < n:
yield tuple(accum)
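
# Illustrative examples of the windowing behaviour:
#   list(groupwise("ABCDE", 3)) == [("A", "B", "C"), ("B", "C", "D"), ("C", "D", "E")]
#   list(groupwise("AB", 3))    == [("A", "B")]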
def wrap_iterator(iterator):
for item in iterator:
yield item
# TODO: Accept regular expressions instead of plain strings
def remove_stops(iterator, tokenizer, stop: list[str] = []):
# If there's nothing to check yield everything as is
if not stop:
yield from iterator
return
    # We need to look ahead a few tokens so that we can tell when a stop
    # sequence is coming up and avoid yielding the start of it.
    # Look ahead by the length (in tokens) of the largest stop sequence.
look_ahead = max([
len(tokenizer.encode(s, add_special_tokens=False))
for s in stop
])
    # Group tokens into overlapping windows of size look_ahead
    rest: list[str] = []
    for items in groupwise(iterator, look_ahead):
# Check if group has a stop sequence
joined = "".join(items).strip()
has_stop_sequence = {s: joined.endswith(s) for s in stop}
# If so, yield tokens minus stop sequence and return
if any(has_stop_sequence.values()):
# which stop sequence was found?
offending_sequence = next(s for s, is_bad in has_stop_sequence.items() if is_bad)
# remove that bit, yield and exit
yield joined.split(offending_sequence)[0]
return
        # Otherwise, keep yielding the first item in the group
        first, *rest = items
        if first.strip():
            yield first

    # No stop sequence was found: flush the tokens still held in the
    # look-ahead window so the tail of the reply is not dropped.
    for item in rest:
        if item.strip():
            yield item
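
# Illustrative sketch (actual chunking depends on the tokenizer): with
# stop=["<|endsection|>"] and streamer chunks like
#   ["Hey", " Diwank", "!", "<|end", "section|>"]
# the look-ahead window lets remove_stops drop the trailing "<|end" + "section|>"
# pieces instead of passing half of a stop marker through to the caller.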
def generate(
messages: ChatML,
stop: list[str] = [],
timeout: int = 15,
stream: bool = False,
**kwargs
) -> Iterator[str] | str:
# Prepare input
prompt = to_prompt(messages)
inputs = tokenizer(prompt, return_tensors="pt").to(0)
input_length = len(inputs["input_ids"].squeeze().tolist())
# Stopping criteria
stopping_criteria = (
StoppingCriteriaList([StopSequenceCriteria(
tokenizer=tokenizer,
stop=stop,
input_length=input_length,
)])
if stop else None
)
# Generation parameters
generation_kwargs = {
# defaults
"max_new_tokens": 40,
"repetition_penalty": 1.02,
"no_repeat_ngram_size": 4,
"renormalize_logits": True,
"temperature": 1.1,
#
# overrides
**kwargs,
#
# required params
"stopping_criteria": stopping_criteria,
# "assistant_model": assistant_model,
#
# add inputs
**inputs,
}
# If not streaming, run directly and return result
if not stream:
[output] = model.generate(**generation_kwargs)
result = tokenizer.decode(output[input_length:], skip_special_tokens=False)
    # Trim anything at or after the first stop sequence (the stopping criteria can overshoot by a few tokens)
for s in stop:
result = result.split(s)[0].strip()
return result
# If streaming, prepare streamer
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=timeout, skip_special_tokens=False)
generation_kwargs["streamer"] = streamer
# and start generating in new thread
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# stop sequence filter
return remove_stops(streamer, tokenizer, stop)
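
# Illustrative usage (not executed here; the values are examples only):
#
#   reply = generate(
#       [
#           {"role": "system", "name": "situation", "content": "I am talking to Diwank"},
#           {"role": "user", "name": "Diwank", "content": "Hey!"},
#       ],
#       stop=["<|", "\n\n"],
#       max_new_tokens=80,
#       do_sample=True,   # sampling knobs such as temperature only take effect when sampling is enabled
#       temperature=1.1,
#   )
#
# With stream=True the call instead returns an iterator of text chunks (see __main__ below).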
if __name__ == "__main__":
    user_name: str = input("Enter your name: ")
    message: str = input("Enter your message: ")
chatml = [
ChatMLMessage(role="user", name=user_name, content=message),
]
# LLM settings
llm_settings = dict(
max_new_tokens=80,
stop=["<|", "\n\n"],
temperature=1.2,
)
# Generate streaming response
response_stream = generate(
chatml,
stream=True,
**llm_settings,
)
    for m in response_stream:
        print(m, end="", flush=True)
    print()
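
    # Illustrative sketch of a multi-turn loop reusing the pieces above
    # (assumes the same llm_settings; names and prompts are examples only):
    #
    #   while True:
    #       reply = generate(chatml, **llm_settings)  # non-streaming call returns a str
    #       print(f"{AGENT_NAME}: {reply}")
    #       chatml.append(ChatMLMessage(role="assistant", name=AGENT_NAME, content=reply))
    #       chatml.append(ChatMLMessage(role="user", name=user_name, content=input(f"{user_name}: ")))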