teaevo committed
Commit f65b03e
1 Parent(s): 52fb501

Update app.py

Files changed (1)
  1. app.py +21 -91
app.py CHANGED
@@ -1,105 +1,35 @@
-import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers import TapexTokenizer, BartForConditionalGeneration
-import pandas as pd
 import torch
-
-import numpy as np
-import time
-import os
-#import pkg_resources
-
-'''
-# Get a list of installed packages and their versions
-installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}
-
-# Print the list of packages
-for package, version in installed_packages.items():
-    print(f"{package}=={version}")
-'''
-
-# Load the chatbot model
-chatbot_model_name = "microsoft/DialoGPT-medium" #"gpt2"
-chatbot_tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
-chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
-
-
-# Load the SQL Model
-#wikisql take longer to process
-#model_name = "microsoft/tapex-large-finetuned-wikisql" # You can change this to any other model from the list above
-#model_name = "microsoft/tapex-base-finetuned-wikisql"
-#model_name = "microsoft/tapex-base-finetuned-wtq"
-model_name = "microsoft/tapex-large-finetuned-wtq"
-#model_name = "google/tapas-base-finetuned-wtq"
-sql_tokenizer = TapexTokenizer.from_pretrained(model_name)
-sql_model = BartForConditionalGeneration.from_pretrained(model_name)
-
-data = {
-    "year": [1896, 1900, 1904, 2004, 2008, 2012],
-    "city": ["athens", "paris", "st. louis", "athens", "beijing", "london"]
-}
-table = pd.DataFrame.from_dict(data)
-
-def chatbot_response(user_message, history=[]):
-
-    # Check if the user input is a question
-    is_question = "?" in user_message
-
-    if is_question:
-        # If the user input is a question, use TAPEx for question-answering
-        #inputs = user_query
-        encoding = sql_tokenizer(table=table, query=user_message, return_tensors="pt")
-        #outputs = sql_model.generate(**encoding)
-        #response = sql_tokenizer.batch_decode(outputs, skip_special_tokens=True)
-
-        # append the new user input tokens to the chat history
-        bot_input_ids = torch.cat([torch.LongTensor(history), encoding], dim=-1)
-
-        # generate a response
-        history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
-
-        # convert the tokens to text, and then split the responses into the right format
-        response = sql_tokenizer.decode(history[0]).split("<|endoftext|>")
-        response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]  # convert to tuples of list
-
-
-    else:
-        # Generate chatbot response using the chatbot model
-        '''
-        inputs = chatbot_tokenizer.encode("User: " + user_message, return_tensors="pt")
-        outputs = chatbot_model.generate(inputs, max_length=100, num_return_sequences=1)
-        response = chatbot_tokenizer.decode(outputs[0], skip_special_tokens=True)
-        '''
-        # tokenize the new input sentence
-        new_user_input_ids = chatbot_tokenizer.encode(user_message + chatbot_tokenizer.eos_token, return_tensors='pt')
-
-        # append the new user input tokens to the chat history
-        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
-
-        # generate a response
-        history = chatbot_model.generate(bot_input_ids, max_length=1000, pad_token_id=chatbot_tokenizer.eos_token_id).tolist()
-
-        # convert the tokens to text, and then split the responses into the right format
-        response = chatbot_tokenizer.decode(history[0]).split("<|endoftext|>")
-        response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]  # convert to tuples of list
-
-    return response, history
-
-# Define the chatbot and SQL execution interfaces using Gradio
-chatbot_interface = gr.Interface(
-    fn=chatbot_response,
-    #inputs=gr.Textbox(prompt="You:"),
-    #outputs=gr.Textbox(),
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
+model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
+
+
+def predict(input, history=[]):
+    # tokenize the new input sentence
+    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
+
+    # append the new user input tokens to the chat history
+    bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
+
+    # generate a response
+    history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
+
+    # convert the tokens to text, and then split the responses into the right format
+    response = tokenizer.decode(history[0]).split("<|endoftext|>")
+    response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]  # convert to tuples of list
+    return response, history
+
+
+import gradio as gr
+
+interface = gr.Interface(
+    fn=predict,
+    theme="default",
+    css=".footer {display:none !important}",
     inputs=["text", "state"],
     outputs=["chatbot", "state"],
-    live=True,
-    capture_session=True,
-    title="ST Chatbot",
-    description="Type your message in the box above, and the chatbot will respond.",
 )
 
-# Launch the Gradio interface
-if __name__ == "__main__":
-    chatbot_interface.launch()
-
-
+if __name__ == '__main__':
+    interface.launch()
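
For reviewers who want to sanity-check the new predict() path without launching the Gradio UI, here is a minimal smoke-test sketch. It is not part of the commit: the module name "app" and the example prompts are assumptions, and importing the module will download microsoft/DialoGPT-medium on first run. Calling predict with an empty list mirrors the initial "state" value Gradio passes to the function.

# Minimal smoke test for predict() -- a sketch, not part of this commit.
# Assumes app.py is importable as "app"; the prompts are illustrative.
from app import predict

history = []  # token-id chat state, as predict() expects from Gradio's "state" input
for message in ["Hello!", "How are you today?"]:
    response, history = predict(message, history)
    print(response[-1])  # latest (user, bot) message pair

Each call appends the new turn to the token history and re-decodes the whole conversation, which is why predict() returns both the paired messages for the chatbot widget and the raw history for the state.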