teaevo committed on
Commit
c17ba77
·
1 Parent(s): ffc0ad6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -4
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  from transformers import TapexTokenizer, BartForConditionalGeneration
4
  import pandas as pd
5
- #import torch
6
  #import pkg_resources
7
 
8
  '''
@@ -15,7 +15,7 @@ for package, version in installed_packages.items():
15
  '''
16
 
17
  # Load the chatbot model
18
- chatbot_model_name = "gpt2" #"microsoft/DialoGPT-medium"
19
  chatbot_tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
20
  chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
21
 
@@ -24,8 +24,9 @@ chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
24
  #wikisql take longer to process
25
  #model_name = "microsoft/tapex-large-finetuned-wikisql" # You can change this to any other model from the list above
26
  #model_name = "microsoft/tapex-base-finetuned-wikisql"
27
- model_name = "microsoft/tapex-large-finetuned-wtq"
28
  #model_name = "microsoft/tapex-base-finetuned-wtq"
 
 
29
  sql_tokenizer = TapexTokenizer.from_pretrained(model_name)
30
  sql_model = BartForConditionalGeneration.from_pretrained(model_name)
31
 
@@ -35,7 +36,11 @@ data = {
35
  }
36
  table = pd.DataFrame.from_dict(data)
37
 
 
 
38
  def chatbot_response(user_message):
 
 
39
  # Check if the user input is a question
40
  is_question = "?" in user_message
41
 
@@ -47,10 +52,24 @@ def chatbot_response(user_message):
47
  response = sql_tokenizer.batch_decode(outputs, skip_special_tokens=True)
48
  else:
49
  # Generate chatbot response using the chatbot model
 
50
  inputs = chatbot_tokenizer.encode("User: " + user_message, return_tensors="pt")
51
  outputs = chatbot_model.generate(inputs, max_length=100, num_return_sequences=1)
52
  response = chatbot_tokenizer.decode(outputs[0], skip_special_tokens=True)
53
-
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  return response
55
 
56
  # Define the chatbot and SQL execution interfaces using Gradio
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  from transformers import TapexTokenizer, BartForConditionalGeneration
4
  import pandas as pd
5
+ import torch
6
  #import pkg_resources
7
 
8
  '''
 
15
  '''
16
 
17
  # Load the chatbot model
18
+ chatbot_model_name = "microsoft/DialoGPT-medium" #"gpt2"
19
  chatbot_tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
20
  chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
21
 
 
24
  #wikisql take longer to process
25
  #model_name = "microsoft/tapex-large-finetuned-wikisql" # You can change this to any other model from the list above
26
  #model_name = "microsoft/tapex-base-finetuned-wikisql"
 
27
  #model_name = "microsoft/tapex-base-finetuned-wtq"
28
+ #model_name = "microsoft/tapex-large-finetuned-wtq"
29
+ model_name = "google/tapas-base-finetuned-wtq"
30
  sql_tokenizer = TapexTokenizer.from_pretrained(model_name)
31
  sql_model = BartForConditionalGeneration.from_pretrained(model_name)
32
 
 
36
  }
37
  table = pd.DataFrame.from_dict(data)
38
 
39
+ new_chat = True
40
+
41
  def chatbot_response(user_message):
42
+
43
+ global new_chat
44
  # Check if the user input is a question
45
  is_question = "?" in user_message
46
 
 
52
  response = sql_tokenizer.batch_decode(outputs, skip_special_tokens=True)
53
  else:
54
  # Generate chatbot response using the chatbot model
55
+ '''
56
  inputs = chatbot_tokenizer.encode("User: " + user_message, return_tensors="pt")
57
  outputs = chatbot_model.generate(inputs, max_length=100, num_return_sequences=1)
58
  response = chatbot_tokenizer.decode(outputs[0], skip_special_tokens=True)
59
+ '''
60
+ # encode the new user input, add the eos_token and return a tensor in Pytorch
61
+ new_user_input_ids = chatbot_tokenizer.encode(input(">> User:") + chatbot_tokenizer.eos_token, return_tensors='pt')
62
+
63
+ # append the new user input tokens to the chat history
64
+ bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if new_chat is False else new_user_input_ids
65
+
66
+ # generated a response while limiting the total chat history to 1000 tokens,
67
+ chat_history_ids = chatbot_model.generate(bot_input_ids, max_length=1000, pad_token_id=chatbot_tokenizer.eos_token_id)
68
+
69
+ response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
70
+
71
+ new_chat = False
72
+
73
  return response
74
 
75
  # Define the chatbot and SQL execution interfaces using Gradio