nileshhanotia commited on
Commit
9e11341
1 Parent(s): b4d7a19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -117
app.py CHANGED
@@ -1,120 +1,13 @@
1
- import mysql.connector
2
- from mysql.connector import Error
3
- import requests
4
- import json
5
- import os
6
 
7
- def generate_sql_query(natural_language_query, schema_info, space_url):
8
- """Generate SQL query using Hugging Face Space API."""
9
-
10
- # Construct a more structured prompt
11
- prompt = f"""Given this SQL table schema:
12
- {schema_info}
13
 
14
- Write a SQL query that will:
15
- {natural_language_query}
 
16
 
17
- The query should be valid MySQL syntax and include only the SELECT statement."""
18
-
19
- # Make API request to the Hugging Face Space
20
- payload = {
21
- "inputs": prompt,
22
- "options": {
23
- "use_cache": False
24
- }
25
- }
26
-
27
- try:
28
- response = requests.post(space_url, json=payload)
29
- if response.status_code == 200:
30
- return response.json().get('generated_text', '').strip()
31
- else:
32
- raise Exception(f"API request failed: {response.text}")
33
- except Exception as e:
34
- print(f"API Error: {str(e)}")
35
- return None
36
-
37
- def main():
38
- try:
39
- # Define the Hugging Face Space URL
40
- space_url = "https://huggingface.co/spaces/nileshhanotia/sql"
41
-
42
- # Define your schema information
43
- schema_info = """
44
- CREATE TABLE sales (
45
- pizza_id DECIMAL(8,2) PRIMARY KEY,
46
- order_id DECIMAL(8,2),
47
- pizza_name_id VARCHAR(14),
48
- quantity DECIMAL(4,2),
49
- order_date DATE,
50
- order_time VARCHAR(8),
51
- unit_price DECIMAL(5,2),
52
- total_price DECIMAL(5,2),
53
- pizza_size VARCHAR(3),
54
- pizza_category VARCHAR(7),
55
- pizza_ingredients VARCHAR(97),
56
- pizza_name VARCHAR(42)
57
- );
58
- """
59
-
60
- # Establish connection to the database
61
- connection = mysql.connector.connect(
62
- host="localhost",
63
- database="pizza",
64
- user="root",
65
- password="root",
66
- port=8889
67
- )
68
-
69
- if connection.is_connected():
70
- cursor = connection.cursor()
71
- print("Database connected successfully!")
72
-
73
- while True:
74
- try:
75
- # Get user input
76
- print("\nEnter your question (or 'exit' to quit):")
77
- natural_language_query = input("> ").strip()
78
-
79
- if natural_language_query.lower() == 'exit':
80
- break
81
-
82
- # Generate and execute query
83
- sql_query = generate_sql_query(natural_language_query, schema_info, space_url)
84
-
85
- if sql_query:
86
- print(f"\nExecuting SQL Query:\n{sql_query}")
87
- cursor.execute(sql_query)
88
- records = cursor.fetchall()
89
-
90
- # Print results
91
- if records:
92
- print("\nResults:")
93
- # Get column names
94
- columns = [desc[0] for desc in cursor.description]
95
- print(" | ".join(columns))
96
- print("-" * (len(" | ".join(columns)) + 10))
97
- for row in records:
98
- print(" | ".join(str(val) for val in row))
99
- else:
100
- print("\nNo results found.")
101
-
102
- except KeyboardInterrupt:
103
- print("\nOperation cancelled by user.")
104
- continue
105
- except Exception as e:
106
- print(f"\nError: {str(e)}")
107
- continue
108
-
109
- except Error as e:
110
- print(f"\nDatabase error: {str(e)}")
111
- except Exception as e:
112
- print(f"\nApplication error: {str(e)}")
113
- finally:
114
- if 'connection' in locals() and connection.is_connected():
115
- cursor.close()
116
- connection.close()
117
- print("\nMySQL connection closed.")
118
-
119
- if __name__ == "__main__":
120
- main()
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
3
 
4
+ model_name = "defog/sqlcoder-7b-2"
5
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
6
+ model = AutoModelForCausalLM.from_pretrained(model_name)
 
 
 
7
 
8
+ def generate_sql(natural_language_query):
9
+ # Define your SQL generation logic here
10
+ return sql_query
11
 
12
+ iface = gr.Interface(fn=generate_sql, inputs="text", outputs="text")
13
+ iface.launch()