teaevo committed on
Commit
e030ac0
·
1 Parent(s): 32680f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -16
app.py CHANGED
@@ -1,28 +1,56 @@
1
import gradio as gr
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Load the TAPAS model and tokenizer once at import time so the Gradio
# callback does not reload them on every request.
# NOTE(review): google/tapas-large-finetuned-wtq is a *table* QA checkpoint;
# confirm AutoModelForQuestionAnswering actually resolves a model class for it.
model_name = "google/tapas-large-finetuned-wtq"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
def execute_sql(user_query):
    """Answer *user_query* by decoding the model's highest-scoring token span.

    NOTE(review): TAPAS expects a table alongside the question, but none is
    tokenized here — confirm this ever produced meaningful answers.
    """
    encoded = tokenizer(user_query, return_tensors="pt")
    result = model(**encoded)
    # Span boundaries: argmax of the start/end logits (end index is inclusive,
    # hence the +1 for slicing).
    span_start = result["start_logits"].argmax()
    span_end = result["end_logits"].argmax() + 1
    span_ids = encoded["input_ids"][0][span_start:span_end]
    return tokenizer.decode(span_ids)
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
# Define the chatbot interface using Gradio.
iface = gr.Interface(
    fn=execute_sql,
    # Fix: `prompt` is not a gr.Textbox argument; `label` is the supported
    # way to caption the input box.
    inputs=gr.Textbox(label="Enter your question:"),
    outputs=gr.Textbox(),
    live=True,
    # NOTE(review): `capture_session` was removed in Gradio 3.x — drop this
    # argument if the pinned gradio version is >= 3.
    capture_session=True,
    title="Database Question Answering Chatbot",
    description="Type your questions about the database in the box above, and the chatbot will provide answers.",
)

# Launch the Gradio interface only when executed as a script.
if __name__ == "__main__":
    iface.launch()
 
1
import gradio as gr
# Fix: BART checkpoints are encoder-decoder models — AutoModelForCausalLM
# cannot load them. AutoModelForSeq2SeqLM is the matching auto-class and
# still exposes .generate() for the callbacks below.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the chatbot model.
# NOTE(review): "facebook/bart-large-mnli" is fine-tuned for NLI
# *classification*, not open-ended chat — confirm this checkpoint choice.
chatbot_model_name = "facebook/bart-large-mnli"
chatbot_tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
chatbot_model = AutoModelForSeq2SeqLM.from_pretrained(chatbot_model_name)

# Load the SQL model.
# TODO: replace the placeholder with a real text-to-SQL checkpoint; typical
# text-to-SQL models (e.g. T5-based) are seq2seq, matching the class below.
sql_model_name = "your_sql_model_name"  # Replace with the name of the SQL model you want to use
sql_tokenizer = AutoTokenizer.from_pretrained(sql_model_name)
sql_model = AutoModelForSeq2SeqLM.from_pretrained(sql_model_name)
14
def chatbot_response(user_message):
    """Generate a single chatbot reply to *user_message*.

    The message is prefixed with "User: ", encoded, and run through the
    chatbot model's generate() with a 100-token cap.
    """
    prompt_ids = chatbot_tokenizer.encode("User: " + user_message, return_tensors="pt")
    generated = chatbot_model.generate(prompt_ids, max_length=100, num_return_sequences=1)
    return chatbot_tokenizer.decode(generated[0], skip_special_tokens=True)
21
 
22
def execute_sql(user_query):
    """Run *user_query* through the SQL model and decode the generated text.

    NOTE(review): despite the name, nothing is executed against a database —
    the "result" is whatever text the model generates.
    """
    encoded = sql_tokenizer(user_query, return_tensors="pt")
    generated = sql_model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=1000,
    )
    return sql_tokenizer.decode(generated[0], skip_special_tokens=True)
29
+
30
# Define the chatbot and SQL execution interfaces using Gradio.
chatbot_interface = gr.Interface(
    fn=chatbot_response,
    # Fix: `prompt` is not a gr.Textbox argument; `label` captions the box.
    inputs=gr.Textbox(label="You:"),
    outputs=gr.Textbox(),
    live=True,
    # NOTE(review): `capture_session` was removed in Gradio 3.x — drop it if
    # the pinned gradio version is >= 3.
    capture_session=True,
    title="Chatbot",
    description="Type your message in the box above, and the chatbot will respond.",
)

sql_execution_interface = gr.Interface(
    fn=execute_sql,
    inputs=gr.Textbox(label="Enter your SQL query:"),
    outputs=gr.Textbox(),
    live=True,
    capture_session=True,
    title="SQL Execution",
    description="Type your SQL query in the box above, and the chatbot will execute it.",
)

# Fix: gr.Interface's first positional argument is a callable, not a list of
# interfaces, and it has no `layout` keyword. TabbedInterface is the
# supported way to combine several Interfaces into one app.
combined_interface = gr.TabbedInterface(
    [chatbot_interface, sql_execution_interface],
    tab_names=["Chatbot", "SQL Execution"],
)

# Launch the combined Gradio interface when run as a script.
if __name__ == "__main__":
    combined_interface.launch()