Create chat.py
chat.py
ADDED
@@ -0,0 +1,299 @@
from collections import deque
import itertools as it
import random
from threading import Thread
from typing import Literal, Optional, TypedDict

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)


###########
## Types ##
###########

class ChatMLMessage(TypedDict):
    # may be None; TypedDict fields cannot take default values
    name: Optional[str]
    role: Literal["assistant", "system", "user"]
    content: str

ChatML = list[ChatMLMessage]

# Example:
# [
#     {"role": "system", "name": "situation", "content": "I am talking to Diwank"},
#     {"role": "assistant", "name": "Samantha", "content": "Hey Diwank"},
#     {"role": "user", "name": "Diwank", "content": "Hey!"},
# ]

############
## Consts ##
############

AGENT_NAME: str = "Samantha"


###########
## Model ##
###########

# assistant_model_id = "julep-ai/samantha-7b-ds-03"
# assistant_model = AutoModelForCausalLM.from_pretrained(assistant_model_id, torch_dtype=torch.bfloat16, device_map="auto")

# Load model and tokenizer
model_id = "julep-ai/samantha-33b"
tokenizer_id = "julep-ai/samantha-33b"

print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, use_fast=False)

# warmup
model.generate(**tokenizer("Hello", return_tensors="pt").to(0), max_new_tokens=2)

print("Model loaded")


##############
## Generate ##
##############

class StopSequenceCriteria(StoppingCriteria):
    def __init__(
        self,
        tokenizer,
        stop: list[str],
        input_length,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.stop = stop
        self.tokenizer = tokenizer
        self.input_length = input_length

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:

        input_ids = input_ids.long().tolist()
        new_input_ids = [i[self.input_length:] for i in input_ids]

        for text in self.stop:
            generated_so_far = ""

            for input_id in new_input_ids:
                decoded = self.tokenizer.decode(input_id, skip_special_tokens=False)
                generated_so_far += decoded

                if text in generated_so_far:
                    return True

        return False
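
# Used further down in generate(): wrapped in a StoppingCriteriaList so that
# generation halts as soon as any of the stop strings shows up in the text
# decoded from the newly generated tokens.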


def message_role_to_prefix(message: ChatMLMessage) -> str:
    match message:
        case {"role": "system", "name": name, **rest}:
            return name
        case {"role": "user", "name": name, **rest}:
            return f"person ({name})" if name else "person"
        case {"role": "assistant", "name": name, **rest}:
            return f"me ({name})" if name else "me"
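
# For instance (mirroring the cases above):
#   {"role": "user", "name": "Diwank", ...}        -> "person (Diwank)"
#   {"role": "user", "name": None, ...}            -> "person"
#   {"role": "assistant", "name": "Samantha", ...} -> "me (Samantha)"
#   {"role": "system", "name": "situation", ...}   -> "situation"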


def to_prompt(
    messages: ChatML,
    bos: str = "<|section|>",
    eos: str = "<|endsection|>",
    suffix: str = f"\n<|section|>me ({AGENT_NAME})\n",
) -> str:
    # Input format:
    # [
    #     {"role": "system", "name": "situation", "content": "I am talking to Diwank"},
    #     {"role": "assistant", "name": "Samantha", "content": "Hey Diwank"},
    #     {"role": "user", "name": "Diwank", "content": "Hey!"},
    # ]

    # Output format:
    #
    # <|section|>situation
    # I am talking to Diwank<|endsection|>
    # <|section|>me (Samantha)
    # Hey Diwank<|endsection|>
    # <|section|>person (Diwank)
    # Hey<|endsection|>
    # <|section|>me (Samantha)\n

    prompt = "\n".join([
        f"{bos}{message_role_to_prefix(message)}\n{message['content']}{eos}"
        for message in messages
    ])

    return prompt + suffix
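
# For a single user message, e.g.
#
#   to_prompt([{"role": "user", "name": "Diwank", "content": "Hey!"}])
#
# this returns:
#
#   <|section|>person (Diwank)
#   Hey!<|endsection|>
#   <|section|>me (Samantha)
#
# i.e. the transcript is left open at an assistant turn, so the model keeps
# speaking as Samantha (which is why __main__ below stops generation on
# "<|" and "\n\n").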


def groupwise(iterable, n):
    """Like itertools.pairwise but for n elements"""

    accum = deque((), n)
    count = 0

    for element in iterable:
        accum.append(element)
        count += 1

        if len(accum) == n:
            yield tuple(accum)

    if count < n:
        yield tuple(accum)
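
# e.g. list(groupwise("abcd", 3)) -> [("a", "b", "c"), ("b", "c", "d")]
# and, for inputs shorter than n, list(groupwise("ab", 3)) -> [("a", "b")]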


def wrap_iterator(iterator):
    for item in iterator:
        yield item

# TODO: Turn this into accepting regular expressions instead
def remove_stops(iterator, tokenizer, stop: list[str] = []):

    # If there's nothing to check yield everything as is
    if not stop:
        yield from iterator
        return

    # We need to look ahead n number of tokens so that,
    # we can check if a stop sequence is coming up
    # and not yield starting part of the stop sequence.

    # Look ahead by len of largest stop sequence
    look_ahead = max([
        len(tokenizer.encode(s, add_special_tokens=False))
        for s in stop
    ])

    # Group tokens into look_ahead groups
    last_items: tuple = ()

    for items in groupwise(iterator, look_ahead):

        # groupwise yields a single empty window for an empty stream
        if not items:
            continue

        # Check if group has a stop sequence
        joined = "".join(items).strip()
        has_stop_sequence = {s: joined.endswith(s) for s in stop}

        # If so, yield tokens minus stop sequence and return
        if any(has_stop_sequence.values()):
            # which stop sequence was found?
            offending_sequence = next(s for s, is_bad in has_stop_sequence.items() if is_bad)

            # remove that bit, yield and exit
            yield joined.split(offending_sequence)[0]
            return

        # Otherwise, keep yielding the first item in the group
        first, *_ = items
        last_items = items

        if first.strip():
            yield first

    # No stop sequence was hit: flush the rest of the final window, since only
    # the first element of each window is yielded above and the trailing
    # tokens would otherwise be dropped.
    for item in last_items[1:]:
        if item.strip():
            yield item
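
# Rough illustration (exact chunk boundaries depend on the tokenizer): with
# stop=["<|endsection|>"], a stream of decoded pieces such as
#   "Hey", " Diwank", "!", "<|endsection|>"
# should yield the text up to the marker and stop without ever emitting the
# marker itself.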


def generate(
    messages: ChatML,
    stop: list[str] = [],
    timeout: int = 15,
    stream: bool = False,
    **kwargs
) -> TextIteratorStreamer | str:

    # Prepare input
    prompt = to_prompt(messages)
    inputs = tokenizer(prompt, return_tensors="pt").to(0)
    input_length = len(inputs["input_ids"].squeeze().tolist())

    # Stopping criteria
    stopping_criteria = (
        StoppingCriteriaList([StopSequenceCriteria(
            tokenizer=tokenizer,
            stop=stop,
            input_length=input_length,
        )])
        if stop else None
    )

    # Generation parameters
    generation_kwargs = {
        # defaults
        "max_new_tokens": 40,
        "repetition_penalty": 1.02,
        "no_repeat_ngram_size": 4,
        "renormalize_logits": True,
        "temperature": 1.1,
        #
        # overrides
        **kwargs,
        #
        # required params
        "stopping_criteria": stopping_criteria,
        # "assistant_model": assistant_model,
        #
        # add inputs
        **inputs,
    }

    # If not streaming, run directly and return result
    if not stream:
        [output] = model.generate(**generation_kwargs)
        result = tokenizer.decode(output[input_length:], skip_special_tokens=False)

        # Remove the stop sequence at the end (needed)
        for s in stop:
            result = result.split(s)[0].strip()

        return result

    # If streaming, prepare streamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=timeout, skip_special_tokens=False)
    generation_kwargs["streamer"] = streamer

    # and start generating in new thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # stop sequence filter
    return remove_stops(streamer, tokenizer, stop)
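
# For reference, a non-streaming call looks something like this (illustrative
# values; it mirrors the streaming example in __main__ below):
#
#   reply = generate(
#       [ChatMLMessage(role="user", name="Diwank", content="Hey Samantha!")],
#       stop=["<|", "\n\n"],
#       max_new_tokens=80,
#   )
#   print(reply)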


if __name__ == "__main__":
    user_name: str = input("Enter your name: ")
    message: str = input("Enter your message: ")
    chatml = [
        ChatMLMessage(role="user", name=user_name, content=message),
    ]

    prompt = to_prompt(chatml)

    # LLM settings
    llm_settings = dict(
        max_new_tokens=80,
        stop=["<|", "\n\n"],
        temperature=1.2,
    )

    # Generate streaming response
    response_stream = generate(
        chatml,
        stream=True,
        **llm_settings,
    )

    # Print chunks as they arrive; the decoded pieces already carry their own
    # spacing, so no extra separator is added between them.
    for m in response_stream:
        print(m, end="", flush=True)