Mikhil-jivus committed on
Commit
b17ecc2
1 Parent(s): ae4333a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -15
app.py CHANGED
@@ -18,6 +18,16 @@ model = AutoModelForCausalLM.from_pretrained(
18
  device_map="auto" # Automatically use available GPU/CPU efficiently
19
  )
20
 
 
 
 
 
 
 
 
 
 
 
21
  def respond(
22
  message,
23
  history: list[tuple[str, str]],
@@ -26,18 +36,17 @@ def respond(
26
  temperature,
27
  top_p,
28
  ):
29
- messages = [{"role": "system", "content": system_message}]
30
-
31
- for val in history:
32
- if val[0]:
33
- messages.append({"role": "user", "content": val[0]})
34
- if val[1]:
35
- messages.append({"role": "assistant", "content": val[1]})
36
 
37
- messages.append({"role": "user", "content": message})
 
 
38
 
39
  # Tokenize the input messages
40
- input_text = system_message + " ".join([f"{msg['role']}: {msg['content']}" for msg in messages])
41
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
42
 
43
  # Move input_ids to the GPU
@@ -60,15 +69,19 @@ def respond(
60
  # Decode the response
61
  response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
62
 
63
- yield response
 
 
 
 
 
 
64
 
65
- """
66
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
67
- """
68
  demo = gr.ChatInterface(
69
  respond,
70
  additional_inputs=[
71
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
72
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
73
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
74
  gr.Slider(
@@ -82,4 +95,4 @@ demo = gr.ChatInterface(
82
  )
83
 
84
  if __name__ == "__main__":
85
- demo.launch(share=True)
 
18
  device_map="auto" # Automatically use available GPU/CPU efficiently
19
  )
20
 
21
def clean_response(response, history):
    """Strip any echo of the previous bot reply out of a newly generated response.

    Causal-LM decoding on a concatenated chat transcript sometimes repeats the
    previous assistant turn verbatim at the start of the new output; this
    removes that repetition.

    Parameters
    ----------
    response : str
        The freshly decoded model output.
    history : list[tuple[str, str]]
        Chat history as (user_message, bot_response) pairs; may be empty.

    Returns
    -------
    str
        The response with any occurrence of the last bot reply removed and
        surrounding whitespace stripped; unchanged when there is no history
        or no repetition.
    """
    # Guard clause: nothing to deduplicate on the first turn.
    if not history:
        return response

    # Only the most recent bot turn is checked — earlier turns are assumed
    # not to be echoed (matches the original design).
    _last_user_message, last_bot_response = history[-1]
    if last_bot_response in response:
        # NOTE(review): str.replace removes ALL occurrences, not just a
        # leading echo — aggressive but intentional per the original code.
        response = response.replace(last_bot_response, "").strip()

    return response
30
+
31
  def respond(
32
  message,
33
  history: list[tuple[str, str]],
 
36
  temperature,
37
  top_p,
38
  ):
39
+ # Add system prompt only once at the beginning of the conversation
40
+ if len(history) == 0:
41
+ input_text = f"system: {system_message}\nuser: {message}\n"
42
+ else:
43
+ input_text = f"user: {message}\n"
 
 
44
 
45
+ # Append previous conversation history to the input text
46
+ for user_msg, bot_msg in history:
47
+ input_text += f"user: {user_msg}\nassistant: {bot_msg}\n"
48
 
49
  # Tokenize the input messages
 
50
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
51
 
52
  # Move input_ids to the GPU
 
69
  # Decode the response
70
  response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
71
 
72
+ # Clean the response to remove any repeated or unnecessary text
73
+ response = clean_response(response, history)
74
+
75
+ # Update history with the new user message and bot response
76
+ history.append((message, response))
77
+
78
+ return response
79
 
80
+ # Set up the Gradio app interface
 
 
81
  demo = gr.ChatInterface(
82
  respond,
83
  additional_inputs=[
84
+ gr.Textbox(value="You are a helpful and friendly assistant.", label="System message"),
85
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
86
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
87
  gr.Slider(
 
95
  )
96
 
97
  if __name__ == "__main__":
98
+ demo.launch(share=True)