diwank committed
Commit 62e5c1e
1 Parent(s): 9334595

Create chat.py

Files changed (1):
  1. chat.py +299 -0

chat.py ADDED
@@ -0,0 +1,299 @@
from collections import deque
import itertools as it
import random
from threading import Thread
from typing import Iterator, Literal, Optional, TypedDict

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)


###########
## Types ##
###########

class ChatMLMessage(TypedDict):
    name: Optional[str]
    role: Literal["assistant", "system", "user"]
    content: str

ChatML = list[ChatMLMessage]

# Example:
# [
#     {"role": "system", "name": "situation", "content": "I am talking to Diwank"},
#     {"role": "assistant", "name": "Samantha", "content": "Hey Diwank"},
#     {"role": "user", "name": "Diwank", "content": "Hey!"},
# ]

############
## Consts ##
############

AGENT_NAME: str = "Samantha"


###########
## Model ##
###########

# assistant_model_id = "julep-ai/samantha-7b-ds-03"
# assistant_model = AutoModelForCausalLM.from_pretrained(assistant_model_id, torch_dtype=torch.bfloat16, device_map="auto")

# Load model and tokenizer
model_id = "julep-ai/samantha-33b"
tokenizer_id = "julep-ai/samantha-33b"

print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, use_fast=False)

# warmup
model.generate(**tokenizer("Hello", return_tensors="pt").to(0), max_new_tokens=2)

print("Model loaded")


##############
## Generate ##
##############

class StopSequenceCriteria(StoppingCriteria):
    """Stop generation as soon as any of the given stop strings appears in the newly generated text."""

    def __init__(
        self,
        tokenizer,
        stop: list[str],
        input_length,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.stop = stop
        self.tokenizer = tokenizer
        self.input_length = input_length

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs,
    ) -> bool:

        input_ids = input_ids.long().tolist()

        # Only inspect tokens generated after the prompt
        new_input_ids = [i[self.input_length:] for i in input_ids]

        for text in self.stop:
            generated_so_far = ""

            for input_id in new_input_ids:
                decoded = self.tokenizer.decode(input_id, skip_special_tokens=False)
                generated_so_far += decoded

            if text in generated_so_far:
                return True

        return False
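
# Illustrative sketch of how this criteria is wired into generation (generate() below does
# exactly this; `n_prompt_tokens` is just a placeholder for the tokenized prompt length):
#
#   criteria = StoppingCriteriaList([
#       StopSequenceCriteria(tokenizer=tokenizer, stop=["<|endsection|>"], input_length=n_prompt_tokens),
#   ])
#   model.generate(..., stopping_criteria=criteria)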


def message_role_to_prefix(message: ChatMLMessage) -> str:
    match message:
        case {"role": "system", "name": name, **rest}:
            return name
        case {"role": "user", "name": name, **rest}:
            return f"person ({name})" if name else "person"
        case {"role": "assistant", "name": name, **rest}:
            return f"me ({name})" if name else "me"


def to_prompt(
    messages: ChatML,
    bos: str = "<|section|>",
    eos: str = "<|endsection|>",
    suffix: str = f"\n<|section|>me ({AGENT_NAME})\n",
) -> str:
    # Input format:
    # [
    #     {"role": "system", "name": "situation", "content": "I am talking to Diwank"},
    #     {"role": "assistant", "name": "Samantha", "content": "Hey Diwank"},
    #     {"role": "user", "name": "Diwank", "content": "Hey!"},
    # ]

    # Output format:
    #
    # <|section|>situation
    # I am talking to Diwank<|endsection|>
    # <|section|>me (Samantha)
    # Hey Diwank<|endsection|>
    # <|section|>person (Diwank)
    # Hey!<|endsection|>
    # <|section|>me (Samantha)\n

    prompt = "\n".join([
        f"{bos}{message_role_to_prefix(message)}\n{message['content']}{eos}"
        for message in messages
    ])

    return prompt + suffix
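
# Quick illustration (single user turn, using the example name from above):
#   to_prompt([{"role": "user", "name": "Diwank", "content": "Hey!"}])
# returns:
#   '<|section|>person (Diwank)\nHey!<|endsection|>\n<|section|>me (Samantha)\n'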


def groupwise(iterable, n):
    """Like itertools.pairwise but for n elements"""

    accum = deque((), n)
    count = 0

    for element in iterable:
        accum.append(element)
        count += 1

        if len(accum) == n:
            yield tuple(accum)

    # If the iterable was shorter than n, yield whatever accumulated
    if count < n:
        yield tuple(accum)
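
# For example (sliding windows of size n, with a single short yield if the input is shorter):
#   list(groupwise("ABCDE", 3))  ->  [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]
#   list(groupwise("AB", 3))     ->  [('A', 'B')]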


def wrap_iterator(iterator):
    for item in iterator:
        yield item

# TODO: Turn this into accepting regular expressions instead
def remove_stops(iterator, tokenizer, stop: list[str] = []):

    # If there's nothing to check, yield everything as is
    if not stop:
        yield from iterator
        return

    # We need to look ahead n tokens so that
    # we can check if a stop sequence is coming up
    # and not yield the starting part of the stop sequence.

    # Look ahead by the length of the longest stop sequence
    look_ahead = max([
        len(tokenizer.encode(s, add_special_tokens=False))
        for s in stop
    ])

    # Group tokens into look_ahead-sized windows
    for items in groupwise(iterator, look_ahead):

        # Check if the group ends with a stop sequence
        joined = "".join(items).strip()
        has_stop_sequence = {s: joined.endswith(s) for s in stop}

        # If so, yield the tokens minus the stop sequence and return
        if any(has_stop_sequence.values()):
            # which stop sequence was found?
            offending_sequence = next(s for s, is_bad in has_stop_sequence.items() if is_bad)

            # remove that bit, yield and exit
            yield joined.split(offending_sequence)[0]
            return

        # Otherwise, keep yielding the first item in the group
        first, *_ = items

        if first.strip():
            yield first
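
# Intuition (hedged sketch): with stop=["<|endsection|>"], the filter keeps a small window of
# the most recent chunks; as soon as the joined window ends with the stop string, it yields the
# text before it and stops, so consumers never see the delimiter. Typical call, mirroring
# generate() below:
#   clean_stream = remove_stops(streamer, tokenizer, stop=["<|endsection|>"])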


def generate(
    messages: ChatML,
    stop: list[str] = [],
    timeout: int = 15,
    stream: bool = False,
    **kwargs
) -> Iterator[str] | str:

    # Prepare input
    prompt = to_prompt(messages)
    inputs = tokenizer(prompt, return_tensors="pt").to(0)
    input_length = len(inputs["input_ids"].squeeze().tolist())

    # Stopping criteria
    stopping_criteria = (
        StoppingCriteriaList([StopSequenceCriteria(
            tokenizer=tokenizer,
            stop=stop,
            input_length=input_length,
        )])
        if stop else None
    )

    # Generation parameters
    generation_kwargs = {
        # defaults
        "max_new_tokens": 40,
        "repetition_penalty": 1.02,
        "no_repeat_ngram_size": 4,
        "renormalize_logits": True,
        "temperature": 1.1,
        #
        # overrides
        **kwargs,
        #
        # required params
        "stopping_criteria": stopping_criteria,
        # "assistant_model": assistant_model,
        #
        # add inputs
        **inputs,
    }

    # If not streaming, run directly and return the result
    if not stream:
        [output] = model.generate(**generation_kwargs)
        result = tokenizer.decode(output[input_length:], skip_special_tokens=False)

        # Remove the stop sequence at the end (needed)
        for s in stop:
            result = result.split(s)[0].strip()

        return result

    # If streaming, prepare the streamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=timeout, skip_special_tokens=False)
    generation_kwargs["streamer"] = streamer

    # and start generating in a new thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # stop sequence filter
    return remove_stops(streamer, tokenizer, stop)
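
# Minimal non-streaming usage sketch (assumes the model loaded above; message values are
# only examples):
#
#   reply = generate(
#       [ChatMLMessage(role="user", name="Diwank", content="Hey!")],
#       stop=["<|endsection|>"],
#       max_new_tokens=60,
#   )
#   print(reply)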


if __name__ == "__main__":
    user_name: str = input("Enter your name: ")
    message: str = input("Enter your message: ")
    chatml = [
        ChatMLMessage(role="user", name=user_name, content=message),
    ]

    prompt = to_prompt(chatml)

    # LLM settings
    llm_settings = dict(
        max_new_tokens=80,
        stop=["<|", "\n\n"],
        temperature=1.2,
    )

    # Generate streaming response
    response_stream = generate(
        chatml,
        stream=True,
        **llm_settings,
    )

    for m in response_stream:
        print(m, end=" ")