nileshhanotia commited on
Commit
cacc96f
1 Parent(s): 0d53dda

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -5
app.py CHANGED
@@ -1,13 +1,72 @@
1
  import gradio as gr
 
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
 
4
  model_name = "premai-io/prem-1B-SQL"
5
- tokenizer = AutoTokenizer.from_pretrained("premai-io/prem-1B-SQL")
6
- model = AutoModelForCausalLM.from_pretrained("premai-io/prem-1B-SQL")
7
 
8
  def generate_sql(natural_language_query):
9
- # Define your SQL generation logic here
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  return sql_query
11
 
12
- iface = gr.Interface(fn=generate_sql, inputs="text", outputs="text")
13
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import mysql.connector
3
+ from mysql.connector import Error
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
+ # Load the model and tokenizer
7
  model_name = "premai-io/prem-1B-SQL"
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForCausalLM.from_pretrained(model_name)
10
 
11
  def generate_sql(natural_language_query):
12
+ """Generate SQL query from natural language."""
13
+ # Define your schema information
14
+ schema_info = """
15
+ CREATE TABLE sales (
16
+ pizza_id DECIMAL(8,2) PRIMARY KEY,
17
+ order_id DECIMAL(8,2),
18
+ pizza_name_id VARCHAR(14),
19
+ quantity DECIMAL(4,2),
20
+ order_date DATE,
21
+ order_time VARCHAR(8),
22
+ unit_price DECIMAL(5,2),
23
+ total_price DECIMAL(5,2),
24
+ pizza_size VARCHAR(3),
25
+ pizza_category VARCHAR(7),
26
+ pizza_ingredients VARCHAR(97),
27
+ pizza_name VARCHAR(42)
28
+ );
29
+ """
30
+
31
+ # Construct the prompt
32
+ prompt = f"""### Task: Generate a SQL query to answer the following question.
33
+
34
+ ### Database Schema:
35
+ {schema_info}
36
+
37
+ ### Question: {natural_language_query}
38
+
39
+ ### SQL Query:"""
40
+
41
+ # Tokenize and generate
42
+ inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
43
+ outputs = model.generate(
44
+ inputs["input_ids"],
45
+ max_length=512,
46
+ temperature=0.1,
47
+ do_sample=True,
48
+ top_p=0.95,
49
+ num_return_sequences=1,
50
+ eos_token_id=tokenizer.eos_token_id,
51
+ pad_token_id=tokenizer.pad_token_id
52
+ )
53
+
54
+ # Decode and clean up the response
55
+ generated_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
56
+ sql_query = generated_query.split("### SQL Query:")[-1].strip()
57
+
58
  return sql_query
59
 
60
+ def main():
61
+ # Gradio interface setup
62
+ iface = gr.Interface(
63
+ fn=generate_sql,
64
+ inputs="text",
65
+ outputs="text",
66
+ title="Natural Language to SQL Query Generator",
67
+ description="Enter a natural language query to generate the corresponding SQL query."
68
+ )
69
+ iface.launch()
70
+
71
+ if __name__ == "__main__":
72
+ main()