teaevo committed
Commit c8d5ecf · Parent: 5d0a9ea

Update app.py

Files changed (1):
  app.py +26 -21
app.py CHANGED
@@ -26,8 +26,8 @@ chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
 #model_name = "microsoft/tapex-base-finetuned-wikisql"
 model_name = "microsoft/tapex-large-finetuned-wtq"
 #model_name = "microsoft/tapex-base-finetuned-wtq"
-tokenizer = TapexTokenizer.from_pretrained(model_name)
-model = BartForConditionalGeneration.from_pretrained(model_name)
+sql_tokenizer = TapexTokenizer.from_pretrained(model_name)
+sql_model = BartForConditionalGeneration.from_pretrained(model_name)
 
 data = {
     "year": [1896, 1900, 1904, 2004, 2008, 2012],
@@ -37,37 +37,42 @@ table = pd.DataFrame.from_dict(data)
 
 bot_input_ids = None
 
+tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
+model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
+
+
 def chatbot_response(user_message):
     # Generate chatbot response using the chatbot model
     #inputs = chatbot_tokenizer.encode("User: " + user_message, return_tensors="pt")
     #outputs = chatbot_model.generate(inputs, max_length=100, num_return_sequences=1)
     #response = chatbot_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    response = None
 
-    global bot_input_ids
-    # encode the new user input, add the eos_token and return a tensor in PyTorch
-    new_user_input_ids = chatbot_tokenizer.encode(user_message + chatbot_tokenizer.eos_token, return_tensors='pt')
-
-    # append the new user input tokens to the chat history
-    if bot_input_ids is None:
-        bot_input_ids = new_user_input_ids
-    else:
-        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
-
-    # generate a response while limiting the total chat history to 1000 tokens
-    chat_history_ids = chatbot_model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
-
-    # pretty-print the last output tokens from the bot
-    response = chatbot_tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
+    # run a single chat turn
+    for step in range(1):
+        # encode the new user input, add the eos_token and return a tensor in PyTorch
+        new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt')
+
+        # append the new user input tokens to the chat history
+        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
+
+        # generate a response while limiting the total chat history to 1000 tokens
+        chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
+
+        # pretty-print the last output tokens from the bot
+        #print("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))
+
+        response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
 
     return response
 
 def sql_response(user_query):
 
     #inputs = tokenizer.encode("User: " + user_query, return_tensors="pt")
     inputs = user_query
-    encoding = tokenizer(table=table, query=inputs, return_tensors="pt")
-    outputs = model.generate(**encoding)
-    response = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    encoding = sql_tokenizer(table=table, query=inputs, return_tensors="pt")
+    outputs = sql_model.generate(**encoding)
+    response = sql_tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
     return response
 
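Relatedly, the commit drops the old global bot_input_ids bookkeeping, and because range(1) means step is always 0, chat_history_ids is never read, so no conversation state survives between calls. If multi-turn memory is wanted back, one option in the spirit of the deleted code (assuming torch is imported at the top of app.py, as the existing torch.cat call implies):

chat_history_ids = None  # module-level conversation state

def chatbot_response(user_message):
    global chat_history_ids

    # Encode the new user input and append the end-of-sequence token.
    new_user_input_ids = tokenizer.encode(
        user_message + tokenizer.eos_token, return_tensors="pt"
    )

    # Append to the running history if one exists, otherwise start fresh.
    if chat_history_ids is None:
        bot_input_ids = new_user_input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

    # Same 1000-token cap as the committed code.
    chat_history_ids = model.generate(
        bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
    )

    # Decode only the tokens generated after this turn's prompt.
    return tokenizer.decode(
        chat_history_ids[:, bot_input_ids.shape[-1]:][0],
        skip_special_tokens=True,
    )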
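The rename itself looks like the point of the commit: previously the TAPEX pair was bound to the generic names tokenizer/model, which the chatbot path also touched (the deleted generate call pads with tokenizer.eos_token_id, i.e. the TAPEX tokenizer). With sql_tokenizer/sql_model separated out, a quick smoke test of the table-QA path; the question text is illustrative and the answer depends on which columns the module-level table actually contains:

# TAPEX encodes the DataFrame together with the natural-language question.
encoding = sql_tokenizer(
    table=table, query="in which year was the first game held?", return_tensors="pt"
)
outputs = sql_model.generate(**encoding)
print(sql_tokenizer.batch_decode(outputs, skip_special_tokens=True))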