rishiraj committed on
Commit
02f39a0
•
1 Parent(s): 816f11c

Create app.py

Files changed (1)
  1. app.py +233 -0
app.py ADDED
@@ -0,0 +1,233 @@
+ # -*- coding:utf-8 -*-
+ import os
+ import logging
+ import gc
+
+ import gradio as gr
+ import torch
+
+ # Helpers such as load_tokenizer_and_model, generate_prompt_with_history,
+ # greedy_search, and the UI presets all come from the Space's app_modules package.
+ from app_modules.utils import *
+ from app_modules.presets import *
+ from app_modules.overwrites import *
+
+ logging.basicConfig(
+     level=logging.DEBUG,
+     format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
+ )
+
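+ # Load Zephyr-7B-beta once at startup. load_tokenizer_and_model (from
+ # app_modules.utils) is expected to return the tokenizer, the model, and the
+ # torch device the model was placed on.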
+ base_model = "HuggingFaceH4/zephyr-7b-beta"
+ adapter_model = None  # no LoRA adapter on top of the base model
+ tokenizer, model, device = load_tokenizer_and_model(base_model, adapter_model)
+
+ # Number of predict() calls served so far; used for periodic GPU logging.
+ total_count = 0
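+ # predict() is a generator: it yields (chatbot, history, status) triples so the
+ # Gradio UI can refresh incrementally while greedy_search() streams tokens.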
+ def predict(
+     text,
+     chatbot,
+     history,
+     top_p,
+     temperature,
+     max_length_tokens,
+     max_context_length_tokens,
+ ):
+     if text == "":
+         yield chatbot, history, "Empty context."
+         return
+     try:
+         model
+     except NameError:
+         # The global model never loaded; bail out with an error message.
+         yield [[text, "No Model Found"]], [], "No Model Found"
+         return
+
+     inputs = generate_prompt_with_history(
+         text, history, tokenizer, max_length=max_context_length_tokens
+     )
+     if inputs is None:
+         yield chatbot, history, "Input too long."
+         return
+     prompt, inputs = inputs
+     # Keep only the most recent context tokens.
+     input_ids = inputs["input_ids"][:, -max_context_length_tokens:].to(device)
+     torch.cuda.empty_cache()
+     global total_count
+     total_count += 1
+     print(total_count)
+     if total_count % 50 == 0:
+         os.system("nvidia-smi")
+     with torch.no_grad():
+         for x in greedy_search(
+             input_ids,
+             model,
+             tokenizer,
+             stop_words=["[|Human|]", "[|AI|]"],
+             max_length=max_length_tokens,
+             temperature=temperature,
+             top_p=top_p,
+         ):
+             if not is_stop_word_or_prefix(x, ["[|Human|]", "[|AI|]"]):
+                 # Trim anything emitted after a role marker.
+                 if "[|Human|]" in x:
+                     x = x[: x.index("[|Human|]")].strip()
+                 if "[|AI|]" in x:
+                     x = x[: x.index("[|AI|]")].strip()
+                 x = x.strip()
+                 a = [[y[0], convert_to_markdown(y[1])] for y in history] + [
+                     [text, convert_to_markdown(x)]
+                 ]
+                 b = history + [[text, x]]
+                 yield a, b, "Generating..."
+                 if shared_state.interrupted:
+                     shared_state.recover()
+                     try:
+                         yield a, b, "Stop: Success"
+                         return
+                     except:
+                         pass
+     del input_ids
+     gc.collect()
+     torch.cuda.empty_cache()
+     try:
+         yield a, b, "Generate: Success"
+     except:
+         # Bare except on purpose: a/b may be unbound if nothing was generated,
+         # and the consumer may already have closed the generator.
+         pass
+
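+ # retry() rolls back the last exchange and re-runs predict() on its user message.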
+ def retry(
+     text,
+     chatbot,
+     history,
+     top_p,
+     temperature,
+     max_length_tokens,
+     max_context_length_tokens,
+ ):
+     logging.info("Retry...")
+     if len(history) == 0:
+         yield chatbot, history, "Empty context"
+         return
+     # The current textbox value is ignored; regenerate from the last user turn.
+     chatbot.pop()
+     inputs = history.pop()[0]
+     yield from predict(
+         inputs,
+         chatbot,
+         history,
+         top_p,
+         temperature,
+         max_length_tokens,
+         max_context_length_tokens,
+     )
+
+
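+ # Monkey-patch gr.Chatbot's postprocessing with the override pulled in from the
+ # app_modules wildcard imports so streamed replies render correctly.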
+ gr.Chatbot.postprocess = postprocess
+
+ with open("assets/custom.css", "r", encoding="utf-8") as f:
+     customCSS = f.read()
+
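+ # UI layout: a chat column (history, textbox, Send/Stop) beside a parameter
+ # column of sampling sliders. The .style() calls target the Gradio 3.x API.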
+ with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
+     history = gr.State([])
+     user_question = gr.State("")
+     with gr.Row():
+         gr.HTML(title)
+         status_display = gr.Markdown("Success", elem_id="status_display")
+     gr.Markdown(description_top)
+     with gr.Row(scale=1).style(equal_height=True):
+         with gr.Column(scale=5):
+             with gr.Row(scale=1):
+                 chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="100%")
+             with gr.Row(scale=1):
+                 with gr.Column(scale=12):
+                     user_input = gr.Textbox(
+                         show_label=False, placeholder="Enter text"
+                     ).style(container=False)
+                 with gr.Column(min_width=70, scale=1):
+                     submitBtn = gr.Button("Send")
+                 with gr.Column(min_width=70, scale=1):
+                     cancelBtn = gr.Button("Stop")
+             with gr.Row(scale=1):
+                 emptyBtn = gr.Button("🧹 New Conversation")
+                 retryBtn = gr.Button("🔄 Regenerate")
+                 delLastBtn = gr.Button("🗑️ Remove Last Turn")
+         with gr.Column():
+             with gr.Column(min_width=50, scale=1):
+                 with gr.Tab(label="Parameter Setting"):
+                     gr.Markdown("# Parameters")
+                     top_p = gr.Slider(
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=0.95,
+                         step=0.05,
+                         interactive=True,
+                         label="Top-p",
+                     )
+                     temperature = gr.Slider(
+                         minimum=0.1,
+                         maximum=2.0,
+                         value=1.0,
+                         step=0.1,
+                         interactive=True,
+                         label="Temperature",
+                     )
+                     max_length_tokens = gr.Slider(
+                         minimum=0,
+                         maximum=512,
+                         value=512,
+                         step=8,
+                         interactive=True,
+                         label="Max Generation Tokens",
+                     )
+                     max_context_length_tokens = gr.Slider(
+                         minimum=0,
+                         maximum=4096,
+                         value=2048,
+                         step=128,
+                         interactive=True,
+                         label="Max History Tokens",
+                     )
+     gr.Markdown(description)
+
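+     # Event wiring: the handlers share the same inputs/outputs, so the argument
+     # dicts below are unpacked into the .submit()/.click() registrations.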
+     predict_args = dict(
+         fn=predict,
+         inputs=[
+             user_question,
+             chatbot,
+             history,
+             top_p,
+             temperature,
+             max_length_tokens,
+             max_context_length_tokens,
+         ],
+         outputs=[chatbot, history, status_display],
+         show_progress=True,
+     )
+     retry_args = dict(
+         fn=retry,
+         inputs=[
+             user_input,
+             chatbot,
+             history,
+             top_p,
+             temperature,
+             max_length_tokens,
+             max_context_length_tokens,
+         ],
+         outputs=[chatbot, history, status_display],
+         show_progress=True,
+     )
+
+     reset_args = dict(
+         fn=reset_textbox, inputs=[], outputs=[user_input, status_display]
+     )
+
+     # Move submitted text into user_question and clear the textbox before the
+     # streaming predict handler runs.
+     transfer_input_args = dict(
+         fn=transfer_input,
+         inputs=[user_input],
+         outputs=[user_question, user_input, submitBtn],
+         show_progress=True,
+     )
+
+     predict_event1 = user_input.submit(**transfer_input_args).then(**predict_args)
+     predict_event2 = submitBtn.click(**transfer_input_args).then(**predict_args)
+
+     emptyBtn.click(
+         reset_state,
+         outputs=[chatbot, history, status_display],
+         show_progress=True,
+     )
+     emptyBtn.click(**reset_args)
+
+     predict_event3 = retryBtn.click(**retry_args)
+
+     delLastBtn.click(
+         delete_last_conversation,
+         [chatbot, history],
+         [chatbot, history, status_display],
+         show_progress=True,
+     )
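+     # The Stop button interrupts whichever generation event is in flight.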
+     cancelBtn.click(
+         cancel_outputing,
+         [],
+         [status_display],
+         cancels=[predict_event1, predict_event2, predict_event3],
+     )
+
+ demo.title = "🌷 Brain"
+
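+ # A single-worker queue (Gradio 3.x API) serialises requests so at most one
+ # generation runs on the GPU at a time.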
+ demo.queue(concurrency_count=1).launch()