Update fn.py

fn.py CHANGED
@@ -168,16 +168,20 @@ def chat(message, history = [], instruction = None, args = {}):
 
     model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
 
-    streamer = TextIteratorStreamer(
-        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
-    )
+    if 'fastapi' not in args or 'stream' in args and args['stream']:
+        streamer = TextIteratorStreamer(
+            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
+        )
 
     generate_kwargs = dict(
         model_inputs,
-        streamer=streamer,
         do_sample=True,
-        num_beams=1,
     )
+
+    if 'fastapi' not in args or 'stream' in args and args['stream']:
+        generate_kwargs['streamer'] = streamer
+        generate_kwargs['num_beams'] = 1
+
     for k in [
         'max_new_tokens',
         'temperature',
@@ -188,43 +192,21 @@ def chat(message, history = [], instruction = None, args = {}):
         if cfg[k]:
             generate_kwargs[k] = cfg[k]
 
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    model_output = ""
-    for new_text in streamer:
-        model_output += new_text
-        if 'fastapi' in args:
-            # fastapi wants only the incremental text
-            yield new_text
-        else:
-            # gradio always wants the full text so far
-            yield model_output
-
-    return model_output
-
-def infer(args: dict):
-    global cfg
-
-    if 'model_name' in args:
-        load_model(args['model_name'], args['qtype'], args['dtype'])
+    if 'fastapi' not in args or 'stream' in args and args['stream']:
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
+        t.start()
 
-
-
-
-
-
-
-
-
-
-    ]:
-        cfg[k] = args[k]
+        model_output = ""
+        for new_text in streamer:
+            model_output += new_text
+            if 'fastapi' in args:
+                # fastapi wants only the incremental text
+                yield new_text
+            else:
+                # gradio always wants the full text so far
+                yield model_output
 
-
-        return chat(args['input'], args['messages'])
-    if 'instruction' in args:
-        return instruct(args['instruction'], args['input'])
+    return model.generate(**generate_kwargs)
 
 def apply_template(messages):
     global tokenizer, cfg
@@ -235,6 +217,6 @@ def apply_template(messages):
     if type(messages) is str:
         if cfg['inst_template']:
            return cfg['inst_template'].format(instruction=cfg['instruction'], input=messages)
-        return cfg['instruction']
+        return cfg['instruction'].format(input=messages)
     if type(messages) is list:
         return tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
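The guard that now appears three times in chat() leans on Python operator precedence: `and` binds tighter than `or`, so it reads as "stream unless the call comes from fastapi and streaming was not requested". A small sketch with a hypothetical wants_stream() helper (not part of fn.py) makes the three cases explicit:

from typing import Any, Dict

def wants_stream(args: Dict[str, Any]) -> bool:
    # Mirrors the guard `'fastapi' not in args or 'stream' in args and args['stream']`.
    # With precedence made explicit: (not a fastapi call) OR (a fastapi call that asked to stream).
    return 'fastapi' not in args or ('stream' in args and args['stream'])

assert wants_stream({}) is True                                   # gradio path: always stream
assert wants_stream({'fastapi': True}) is False                   # fastapi default: plain generate()
assert wants_stream({'fastapi': True, 'stream': True}) is True    # fastapi caller opts into streaming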
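For reference, the streaming branch follows the usual transformers TextIteratorStreamer pattern: model.generate() blocks until decoding finishes, so it runs in a worker thread while the caller iterates the streamer. A minimal, self-contained sketch, assuming a placeholder model ("gpt2"), prompt, and max_new_tokens that are not taken from this Space:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model for illustration only; the Space loads its own model in load_model().
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

model_inputs = tokenizer(["Hello"], return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(
    tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
)

generate_kwargs = dict(
    model_inputs,
    streamer=streamer,
    do_sample=True,
    num_beams=1,
    max_new_tokens=32,
)

# generate() blocks, so it runs in a background thread while this thread
# consumes the streamer chunk by chunk.
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()

model_output = ""
for new_text in streamer:
    model_output += new_text              # full text so far (what the gradio path yields)
    print(new_text, end="", flush=True)   # incremental chunk (what the fastapi path yields)

The non-streaming branch skips the streamer and the thread entirely and simply returns model.generate(**generate_kwargs), which is why the new code only sets streamer and num_beams when streaming is wanted.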