vilarin committed on
Commit
84e1807
·
verified ·
1 Parent(s): d62aad8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -15
app.py CHANGED
@@ -9,7 +9,7 @@ import os
9
  import time
10
  import spaces
11
  import torch
12
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
13
  import gradio as gr
14
  from threading import Thread
15
 
@@ -83,28 +83,25 @@ def stream_chat(
83
  conversation.append({"role": "user", "content": message})
84
 
85
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(device)
86
-
87
- streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
88
-
89
- generate_kwargs = dict(
90
  input_ids=input_ids,
91
  max_new_tokens = max_new_tokens,
92
  do_sample = False if temperature == 0 else True,
93
  top_p = top_p,
94
  temperature = temperature,
95
- streamer=streamer,
96
  )
97
 
98
- with torch.no_grad():
99
- thread = Thread(target=model.multi_byte_generate, kwargs=generate_kwargs)
100
- thread.start()
101
-
102
- buffer = ""
103
- for new_text in streamer:
104
- buffer += new_text
105
- yield buffer
106
 
107
-
 
 
 
108
  chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
109
 
110
  with gr.Blocks(css=CSS, theme="soft") as demo:
 
9
  import time
10
  import spaces
11
  import torch
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
13
  import gradio as gr
14
  from threading import Thread
15
 
 
83
  conversation.append({"role": "user", "content": message})
84
 
85
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(device)
86
+
87
+ gen_out = model.multi_byte_generate(
 
 
88
  input_ids=input_ids,
89
  max_new_tokens = max_new_tokens,
90
  do_sample = False if temperature == 0 else True,
91
  top_p = top_p,
92
  temperature = temperature,
 
93
  )
94
 
95
+ response = tokenizer.decode(
96
+ gen_out[0][input_ids.shape[1]:],
97
+ skip_special_tokens=False,
98
+ clean_up_tokenization_spaces=False
99
+ )
 
 
 
100
 
101
+ for i in range(len(response)):
102
+ time.sleep(0.05)
103
+ yield response[: i + 1]
104
+
105
  chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
106
 
107
  with gr.Blocks(css=CSS, theme="soft") as demo: