qgyd2021 committed
Commit 92dd16e
1 Parent(s): 9afc761

[update]edit main

Files changed (1):
    main.py  +49 -6

main.py CHANGED
@@ -2,11 +2,12 @@
 # -*- coding: utf-8 -*-
 import argparse
 import os
+from threading import Thread
 
 import gradio as gr
 from transformers import AutoModel, AutoTokenizer
 from transformers.models.auto import AutoModelForCausalLM, AutoTokenizer
-# from transformers.utils.quantization_config import BitsAndBytesConfig
+from transformers.generation.streamers import TextIteratorStreamer
 import torch
 
 from project_settings import project_path
 
@@ -72,7 +73,7 @@ def main():
     )
     model = model.bfloat16().eval()
 
-    def fn(text: str):
+    def fn_non_stream(text: str):
         input_ids = tokenizer(
             text,
             return_tensors="pt",
 
@@ -84,8 +85,12 @@ def main():
 
         with torch.no_grad():
             outputs = model.generate(
-                input_ids=input_ids, max_new_tokens=args.max_new_tokens, do_sample=True,
-                top_p=args.top_p, temperature=args.temperature, repetition_penalty=args.repetition_penalty,
+                input_ids=input_ids,
+                max_new_tokens=args.max_new_tokens,
+                do_sample=True,
+                top_p=args.top_p,
+                temperature=args.temperature,
+                repetition_penalty=args.repetition_penalty,
                 eos_token_id=tokenizer.eos_token_id
             )
         outputs = outputs.tolist()[0][len(input_ids[0]):]
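The non-streaming path only appears in fragments in the hunks above. Below is a minimal consolidated sketch of what fn_non_stream amounts to after this change; tokenizer, model, and args come from the surrounding main(), and the final decode into response is an assumption, since that part lies outside the diff context.

def fn_non_stream(text: str):
    # Tokenize the prompt and move it to the configured device (assumed to
    # mirror the handling elsewhere in main.py).
    input_ids = tokenizer(
        text,
        return_tensors="pt",
    ).input_ids.to(args.device)

    # Blocking generation with the sampling parameters taken from the CLI args.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            top_p=args.top_p,
            temperature=args.temperature,
            repetition_penalty=args.repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Keep only the newly generated tokens and decode them (decode step assumed).
    outputs = outputs.tolist()[0][len(input_ids[0]):]
    response = tokenizer.decode(outputs, skip_special_tokens=True)
    return [(text, response)]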
 
@@ -94,6 +99,44 @@ def main():
 
         return [(text, response)]
 
+    def fn_stream(text: str):
+        input_ids = tokenizer(
+            text,
+            return_tensors="pt",
+            add_special_tokens=False,
+        ).input_ids.to(args.device)
+        bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(args.device)
+        eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(args.device)
+        input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)
+
+        streamer = TextIteratorStreamer(tokenizer=tokenizer)
+
+        generation_kwargs = dict(
+            inputs=input_ids,
+            max_new_tokens=args.max_new_tokens,
+            do_sample=True,
+            top_p=args.top_p,
+            temperature=args.temperature,
+            repetition_penalty=args.repetition_penalty,
+            eos_token_id=tokenizer.eos_token_id,
+            pad_token_id=tokenizer.pad_token_id,
+            streamer=streamer,
+        )
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        output = ""
+        for output_ in streamer:
+            output_ = output_.tolist()[0][len(input_ids[0]):]
+            output_ = tokenizer.decode(output_)
+            output_ = output_.strip().replace(tokenizer.eos_token, "").strip()
+
+            output += output_
+
+            result = [(text, output)]
+            chatbot.value = result
+            yield result
+
     with gr.Blocks() as blocks:
         gr.Markdown(value=description)
 
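One note on the loop in this hunk: TextIteratorStreamer yields already-decoded text chunks (plain strings), not token-id tensors, and it accepts skip_prompt plus decode kwargs that take care of dropping the echoed prompt and the special tokens. A minimal sketch of the usual consumption pattern, assuming the same tokenizer, model, and args names as in the diff:

def fn_stream(text: str):
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(args.device)

    # skip_prompt drops the echoed prompt; skip_special_tokens (forwarded to
    # tokenizer.decode) drops BOS/EOS, so no manual slicing or replace() is needed.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        inputs=input_ids,
        max_new_tokens=args.max_new_tokens,
        do_sample=True,
        top_p=args.top_p,
        temperature=args.temperature,
        repetition_penalty=args.repetition_penalty,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        streamer=streamer,
    )
    # generate() blocks, so it runs in a background thread while the streamer
    # is drained in the foreground.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    output = ""
    for chunk in streamer:          # each chunk is already a decoded string
        output += chunk
        yield [(text, output)]      # each yield updates the bound output component
    thread.join()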
 
 
@@ -108,8 +151,8 @@ def main():
 
         gr.Examples(examples, text_box)
 
-        text_box.submit(fn, [text_box], [chatbot])
-        submit_button.click(fn, [text_box], [chatbot])
+        text_box.submit(fn_stream, [text_box], [chatbot])
+        submit_button.click(fn_stream, [text_box], [chatbot])
         clear_button.click(
             fn=lambda: ("", ""),
             outputs=[text_box, chatbot],
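Because fn_stream is a generator, Gradio treats every yielded value as an incremental update to the bound output component, so the handlers only need the function, its inputs, and its outputs; assigning chatbot.value inside the callback is not required. A stripped-down sketch of the wiring, with hypothetical component definitions standing in for the real ones in main.py (which also set up the description, examples, and so on); on Gradio 3.x the queue must be enabled for generator handlers to stream:

with gr.Blocks() as blocks:
    text_box = gr.Textbox(label="text")
    chatbot = gr.Chatbot(label="response")
    submit_button = gr.Button("submit")
    clear_button = gr.Button("clear")

    # A generator handler streams: each yielded [(user, bot)] list replaces the
    # Chatbot contents, so partial answers appear as they are produced.
    text_box.submit(fn_stream, [text_box], [chatbot])
    submit_button.click(fn_stream, [text_box], [chatbot])
    clear_button.click(
        fn=lambda: ("", ""),
        outputs=[text_box, chatbot],
    )

blocks.queue().launch()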