aka7774 committed on
Commit e525f97
1 Parent(s): 58190b4

Update fn.py

Files changed (1)
fn.py +23 -41
fn.py CHANGED
@@ -168,16 +168,20 @@ def chat(message, history = [], instruction = None, args = {}):
 
     model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
 
-    streamer = TextIteratorStreamer(
-        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
-    )
+    if 'fastapi' not in args or 'stream' in args and args['stream']:
+        streamer = TextIteratorStreamer(
+            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
+        )
 
     generate_kwargs = dict(
         model_inputs,
-        streamer=streamer,
         do_sample=True,
-        num_beams=1,
     )
+
+    if 'fastapi' not in args or 'stream' in args and args['stream']:
+        generate_kwargs['streamer'] = streamer
+        generate_kwargs['num_beams'] = 1
+
     for k in [
         'max_new_tokens',
         'temperature',
@@ -188,43 +192,21 @@ def chat(message, history = [], instruction = None, args = {}):
         if cfg[k]:
             generate_kwargs[k] = cfg[k]
 
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    model_output = ""
-    for new_text in streamer:
-        model_output += new_text
-        if 'fastapi' in args:
-            # fastapi wants only the newly generated delta
-            yield new_text
-        else:
-            # gradio wants the full text every time
-            yield model_output
-
-    return model_output
-
-def infer(args: dict):
-    global cfg
-
-    if 'model_name' in args:
-        load_model(args['model_name'], args['qtype'], args['dtype'])
+    if 'fastapi' not in args or 'stream' in args and args['stream']:
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
+        t.start()
 
-    for k in [
-        'instruction',
-        'inst_template',
-        'chat_template',
-        'max_new_tokens',
-        'temperature',
-        'top_p',
-        'top_k',
-        'repetition_penalty'
-    ]:
-        cfg[k] = args[k]
+        model_output = ""
+        for new_text in streamer:
+            model_output += new_text
+            if 'fastapi' in args:
+                # fastapi wants only the newly generated delta
+                yield new_text
+            else:
+                # gradio wants the full text every time
+                yield model_output
 
-    if 'messages' in args:
-        return chat(args['input'], args['messages'])
-    if 'instruction' in args:
-        return instruct(args['instruction'], args['input'])
+    return model.generate(**generate_kwargs)
 
 def apply_template(messages):
     global tokenizer, cfg
@@ -235,6 +217,6 @@ def apply_template(messages):
     if type(messages) is str:
         if cfg['inst_template']:
            return cfg['inst_template'].format(instruction=cfg['instruction'], input=messages)
-        return cfg['instruction']
+        return cfg['instruction'].format(input=messages)
     if type(messages) is list:
         return tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
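For reference, the streaming path this commit gates behind `'fastapi' not in args or 'stream' in args and args['stream']` is the standard transformers recipe: a `TextIteratorStreamer` is handed to `model.generate()`, which runs on a background thread and pushes decoded text into the streamer as it samples. A minimal self-contained sketch of the two paths, with `gpt2` as a placeholder for whatever model this Space actually loads:

```python
# Sketch of the streaming-vs-blocking split introduced by this commit.
# "gpt2" is a placeholder model name, not the one fn.py loads.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def generate(prompt: str, stream: bool = True):
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    generate_kwargs = dict(model_inputs, do_sample=True, max_new_tokens=32)

    if stream:
        # generate() pushes decoded text into the streamer, so it has to run
        # on a background thread while we iterate the streamer here.
        streamer = TextIteratorStreamer(
            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
        )
        generate_kwargs["streamer"] = streamer
        generate_kwargs["num_beams"] = 1  # streamers don't support beam search
        Thread(target=model.generate, kwargs=generate_kwargs).start()
        for new_text in streamer:
            yield new_text
    else:
        # Blocking path: generate everything, decode once at the end.
        output_ids = model.generate(**generate_kwargs)
        yield tokenizer.decode(output_ids[0], skip_special_tokens=True)
```

Two things worth noting about the committed version. The guard parses as `('fastapi' not in args) or ('stream' in args and args['stream'])`, since `and` binds tighter than `or`: the gradio path always streams, while fastapi streams only on request. And because `chat()` contains `yield`, it is a generator, so the new `return model.generate(**generate_kwargs)` on the non-streaming path surfaces as `StopIteration.value` rather than as an ordinary return value; the caller has to exhaust the generator to retrieve it.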
 
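The last hunk also changes the fallback in `apply_template`: when no `inst_template` is configured, the instruction is now treated as a format string with an `{input}` placeholder instead of being returned verbatim, which used to drop the user message entirely. A small illustration with a made-up instruction string:

```python
# Hypothetical cfg mirroring the fallback branch of apply_template().
cfg = {
    'instruction': "Answer briefly: {input}",  # made-up instruction template
    'inst_template': None,
}

def render_prompt(messages: str) -> str:
    if cfg['inst_template']:
        return cfg['inst_template'].format(instruction=cfg['instruction'], input=messages)
    # New behaviour: substitute the user input into the instruction.
    # The old code returned cfg['instruction'] unchanged, losing the input.
    return cfg['instruction'].format(input=messages)

print(render_prompt("What is a tokenizer?"))
# -> Answer briefly: What is a tokenizer?
```

If an instruction contains no `{input}` placeholder, `str.format` leaves it unchanged; literal braces in an instruction, however, would now need to be escaped as `{{` and `}}`.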