teaevo commited on
Commit
5fb63f2
·
1 Parent(s): 7e20218

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -6
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  from transformers import TapexTokenizer, BartForConditionalGeneration
4
  import pandas as pd
 
5
  #import pkg_resources
6
 
7
  '''
@@ -14,7 +15,7 @@ for package, version in installed_packages.items():
14
  '''
15
 
16
  # Load the chatbot model
17
- chatbot_model_name = "gpt2"
18
  chatbot_tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
19
  chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
20
 
@@ -36,10 +37,22 @@ table = pd.DataFrame.from_dict(data)
36
 
37
  def chatbot_response(user_message):
38
  # Generate chatbot response using the chatbot model
39
- inputs = chatbot_tokenizer.encode("User: " + user_message, return_tensors="pt")
40
- outputs = chatbot_model.generate(inputs, max_length=100, num_return_sequences=1)
41
- response = chatbot_tokenizer.decode(outputs[0], skip_special_tokens=True)
42
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  return response
44
 
45
  def sql_response(user_query):
@@ -75,7 +88,8 @@ sql_interface = gr.Interface(
75
  )
76
 
77
  # Launch the Gradio interface
78
- #if __name__ == "__main__":
79
- chatbot_interface.launch()
 
80
  sql_interface.launch()
81
 
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  from transformers import TapexTokenizer, BartForConditionalGeneration
4
  import pandas as pd
5
+ import torch
6
  #import pkg_resources
7
 
8
  '''
 
15
  '''
16
 
17
  # Load the chatbot model
18
+ chatbot_model_name = "microsoft/DialoGPT-medium"
19
  chatbot_tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
20
  chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
21
 
 
37
 
38
  def chatbot_response(user_message):
39
  # Generate chatbot response using the chatbot model
40
+ #inputs = chatbot_tokenizer.encode("User: " + user_message, return_tensors="pt")
41
+ #outputs = chatbot_model.generate(inputs, max_length=100, num_return_sequences=1)
42
+ #response = chatbot_tokenizer.decode(outputs[0], skip_special_tokens=True)
43
 
44
+ # encode the new user input, add the eos_token and return a tensor in Pytorch
45
+ new_user_input_ids = chatbot_tokenizer.encode(user_message + tokenizer.eos_token, return_tensors='pt')
46
+
47
+ # append the new user input tokens to the chat history
48
+ bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
49
+
50
+ # generated a response while limiting the total chat history to 1000 tokens,
51
+ chat_history_ids = chatbot_model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
52
+
53
+ # pretty print last ouput tokens from bot
54
+ response = "DialoGPT: {}".format(chatbot_tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True))
55
+
56
  return response
57
 
58
  def sql_response(user_query):
 
88
  )
89
 
90
  # Launch the Gradio interface
91
+ if __name__ == "__main__":
92
+ chatbot_interface.launch()
93
+
94
  sql_interface.launch()
95