cryscan committed
Commit 99d39c4
Parent(s): 74ee93b

Add alternative and derail protection.

Files changed (1):
  1. app.py (+37 -4)
app.py CHANGED

@@ -1,5 +1,5 @@
 import gradio as gr
-import os, gc, torch
+import os, copy, gc, torch
 from datetime import datetime
 from huggingface_hub import hf_hub_download
 from pynvml import *
@@ -123,6 +123,15 @@ def user(message, chatbot):
     print(f"User: {message}")
     return "", chatbot + [[message, None]]

+def alternative(chatbot, history):
+    if not chatbot or not history:
+        return chatbot, history
+
+    chatbot[-1][1] = None
+    history[0] = copy.deepcopy(history[1])
+
+    return chatbot, history
+
 def chat(
     prompt,
     user,
@@ -139,6 +148,9 @@ def chat(
         alpha_presence=float(presence_penalty),
         token_ban=[], # ban the generation of some tokens
         token_stop=[]) # stop generation whenever you see any token here
+
+    if not chatbot:
+        return chatbot, history

     message = chatbot[-1][0]
     message = message.strip().replace('\r\n','\n').replace('\n\n','\n')
@@ -154,11 +166,14 @@ def chat(
         prompt = f"\n{prompt}\n\n"

         out, state = model.forward(pipeline.encode(prompt), None)
-        history = [state, []]
+        history = [state, None, []] # [state, state_pre, tokens]
         print("History reloaded.")

-    [state, all_tokens] = history
+    [state, _, all_tokens] = history
+    state_pre_0 = copy.deepcopy(state)
+
     out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:], state)
+    state_pre_1 = copy.deepcopy(state) # For recovery

     print("Bot: ", end='')

@@ -208,11 +223,27 @@ def chat(
         if '\n\n' in out_str:
             break

+        # State recovery
+        if f'{user}:' in out_str or f'{bot}:' in out_str:
+            idx_user = out_str.find(f'{user}:')
+            idx_user = len(out_str) if idx_user == -1 else idx_user
+            idx_bot = out_str.find(f'{bot}:')
+            idx_bot = len(out_str) if idx_bot == -1 else idx_bot
+            idx = min(idx_user, idx_bot)
+
+            if idx < len(out_str):
+                out_str = f" {out_str[:idx].strip()}\n\n"
+                tokens = pipeline.encode(out_str)
+
+                all_tokens = all_tokens[:begin] + tokens
+                out, state = model.forward(tokens, state_pre_1)
+                break
+
     gc.collect()
     torch.cuda.empty_cache()

     chatbot[-1][1] = out_str.strip()
-    history = [state, all_tokens]
+    history = [state, state_pre_0, all_tokens]
     yield chatbot, history

 with gr.Blocks(title=title) as demo:
@@ -245,6 +276,7 @@ with gr.Blocks(title=title) as demo:
             message = gr.Textbox(label="Message")
             with gr.Row():
                 send = gr.Button("Send", variant="primary")
+                alt = gr.Button("Alternative", variant="secondary")
                 clear = gr.Button("Clear", variant="secondary")
     with gr.Column():
         with gr.Row():
@@ -269,6 +301,7 @@ with gr.Blocks(title=title) as demo:
     chat_outputs = [chatbot, state]
     message.submit(user, [message, chatbot], [message, chatbot], queue=False).then(chat, chat_inputs, chat_outputs)
     send.click(user, [message, chatbot], [message, chatbot], queue=False).then(chat, chat_inputs, chat_outputs)
+    alt.click(alternative, [chatbot, state], [chatbot, state], queue=False).then(chat, chat_inputs, chat_outputs)
     clear.click(lambda: ([], None, ""), [], [chatbot, state, message], queue=False)

 demo.queue(max_size=10)
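
Reviewer's note on the new history layout: slot 1 (state_pre_0) holds a deep copy of the model state taken before the latest exchange, so alternative() only has to copy it back into slot 0 and blank the last reply for the next chat() call to regenerate it. A minimal, runnable sketch of that snapshot-and-rewind pattern, with a toy forward() standing in for the RWKV model (every name below is illustrative, not the app's API):

    import copy

    class ToyModel:
        # Stands in for model.forward: consumes tokens, returns (logits, new state).
        def forward(self, tokens, state):
            state = (state or []) + list(tokens)
            return sum(state), state  # fake logits, fake state

    def chat_once(model, history, user_tokens, reply_tokens):
        state, _, all_tokens = history
        state_pre_0 = copy.deepcopy(state)             # snapshot before this exchange
        _, state = model.forward(user_tokens, state)
        _, state = model.forward(reply_tokens, state)  # pretend these were sampled
        return [state, state_pre_0, all_tokens + user_tokens + reply_tokens]

    def alternative(history):
        # Rewind: the pre-exchange snapshot becomes the current state again,
        # so the next generation re-rolls the reply from the same point.
        history[0] = copy.deepcopy(history[1])
        return history

    history = [None, None, []]
    history = chat_once(ToyModel(), history, [1, 2], [7])  # first exchange
    history = chat_once(ToyModel(), history, [3], [8])     # second exchange
    history = alternative(history)
    assert history[0] == [1, 2, 7]  # back to just before the second exchange

The deep copies guard against aliasing: if slot 0 and slot 1 referenced the same object, advancing the state would silently advance the snapshot too, and rewinding would be a no-op.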
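
The derail-protection hunk itself reduces to: find the earliest point where the model starts emitting either speaker's tag, truncate the reply there, and rebuild the state from the state_pre_1 snapshot by re-feeding only the kept text. The truncation half, isolated as a plain function (a sketch; the function name and defaults are mine):

    def truncate_at_derail(out_str, user='User', bot='Bot'):
        # Cut at the first occurrence of either speaker tag; a missing tag
        # counts as "end of string" so min() always picks the real one.
        idx_user = out_str.find(f'{user}:')
        idx_user = len(out_str) if idx_user == -1 else idx_user
        idx_bot = out_str.find(f'{bot}:')
        idx_bot = len(out_str) if idx_bot == -1 else idx_bot
        idx = min(idx_user, idx_bot)
        if idx < len(out_str):
            # Same normalization as the diff: strip, then restore the leading
            # space and closing blank line the prompt format expects.
            return f" {out_str[:idx].strip()}\n\n", True
        return out_str, False

    clean, derailed = truncate_at_derail("Sure thing.\nUser: now answer as me")
    assert derailed and 'User:' not in clean

In the app, the derailed branch also rewrites the tail of the token cache (all_tokens[:begin] + tokens) and re-forwards from state_pre_1, so the cached state matches the text the user actually sees.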
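
Finally, the Alternative button reuses the two-stage event pattern already used by Send: an unqueued handler mutates the UI instantly, then .then() chains into the queued streaming chat generator. A stripped-down sketch of that wiring, assuming the tuple-style gr.Chatbot the app uses (handlers here are stubs):

    import gradio as gr

    def user_fn(message, chatbot):
        # Unqueued: echo the user turn into the chat immediately.
        return "", chatbot + [[message, None]]

    def chat_fn(chatbot, history):
        # Queued generator: fill in the pending reply, streaming as it goes.
        if chatbot:
            chatbot[-1][1] = "(reply streams in here)"
        yield chatbot, history

    def alternative_fn(chatbot, history):
        # Clear the last reply so chat_fn regenerates it.
        if chatbot:
            chatbot[-1][1] = None
        return chatbot, history

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        state = gr.State(None)
        message = gr.Textbox(label="Message")
        send = gr.Button("Send", variant="primary")
        alt = gr.Button("Alternative", variant="secondary")

        send.click(user_fn, [message, chatbot], [message, chatbot], queue=False).then(
            chat_fn, [chatbot, state], [chatbot, state])
        alt.click(alternative_fn, [chatbot, state], [chatbot, state], queue=False).then(
            chat_fn, [chatbot, state], [chatbot, state])

    demo.queue(max_size=10)
    # demo.launch()

Keeping the first stage unqueued is what makes the UI feel responsive under load: only the expensive generation step waits in the queue.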