teaevo committed on
Commit
830c2c9
·
1 Parent(s): 4f9338d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -23
app.py CHANGED
@@ -40,17 +40,11 @@ data = {
40
  }
41
  table = pd.DataFrame.from_dict(data)
42
 
43
- chat_history_ids = None
44
- #bot_input_ids = None
45
 
 
46
 
47
- def chatbot_response(user_message):
48
-
49
- global chat_history_ids
50
- #global bot_input_ids
51
-
52
- print(chat_history_ids is None)
53
- #print(bot_input_ids is None)
54
  # Check if the user input is a question
55
  is_question = "?" in user_message
56
 
@@ -67,27 +61,28 @@ def chatbot_response(user_message):
67
  outputs = chatbot_model.generate(inputs, max_length=100, num_return_sequences=1)
68
  response = chatbot_tokenizer.decode(outputs[0], skip_special_tokens=True)
69
  '''
70
- # encode the new user input, add the eos_token and return a tensor in Pytorch
71
- new_user_input_ids = chatbot_tokenizer.encode(user_message + chatbot_tokenizer.eos_token, return_tensors='pt')
72
 
73
  # append the new user input tokens to the chat history
74
- if chat_history_ids is not None:
75
- bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
76
- else:
77
- bot_input_ids = new_user_input_ids
78
-
79
- # generated a response while limiting the total chat history to 1000 tokens,
80
- chat_history_ids = chatbot_model.generate(bot_input_ids, max_length=1000, pad_token_id=chatbot_tokenizer.eos_token_id)
81
-
82
- response = chatbot_tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
83
 
84
- return response
 
 
 
 
85
 
86
  # Define the chatbot and SQL execution interfaces using Gradio
87
  chatbot_interface = gr.Interface(
88
  fn=chatbot_response,
89
- inputs=gr.Textbox(prompt="You:"),
90
- outputs=gr.Textbox(),
 
 
91
  live=True,
92
  capture_session=True,
93
  title="ST Chatbot",
 
40
  }
41
  table = pd.DataFrame.from_dict(data)
42
 
43
+ history = None
 
44
 
45
+ def chatbot_response(user_message, history=[]):
46
 
47
+ global history
 
 
 
 
 
 
48
  # Check if the user input is a question
49
  is_question = "?" in user_message
50
 
 
61
  outputs = chatbot_model.generate(inputs, max_length=100, num_return_sequences=1)
62
  response = chatbot_tokenizer.decode(outputs[0], skip_special_tokens=True)
63
  '''
64
+ # tokenize the new input sentence
65
+ new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
66
 
67
  # append the new user input tokens to the chat history
68
+ bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
69
+
70
+ # generate a response
71
+ history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
 
 
 
 
 
72
 
73
+ # convert the tokens to text, and then split the responses into the right format
74
+ response = tokenizer.decode(history[0]).split("<|endoftext|>")
75
+ response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)] # convert to tuples of list
76
+
77
+ return response, history
78
 
79
  # Define the chatbot and SQL execution interfaces using Gradio
80
  chatbot_interface = gr.Interface(
81
  fn=chatbot_response,
82
+ #inputs=gr.Textbox(prompt="You:"),
83
+ #outputs=gr.Textbox(),
84
+ inputs=["text", "state"],
85
+ outputs=["chatbot", "state"],
86
  live=True,
87
  capture_session=True,
88
  title="ST Chatbot",