File size: 4,210 Bytes
1faa9b0
 
 
 
 
 
 
2afc588
 
 
 
1faa9b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f26755
1faa9b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f26755
1faa9b0
3e716d4
9f26755
 
 
 
 
 
 
 
 
 
1faa9b0
 
 
 
 
 
 
 
 
 
3e716d4
1faa9b0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import pandas as pd
from openai import OpenAI
import os
from google.cloud import bigquery
import numpy as np
import gradio as gr

# BigQuery location of the table this app queries. Each value is read from the
# environment variable of the same name (os.getenv yields None when unset).
project_id, dataset_id, table_id = (
    os.getenv(var_name) for var_name in ('project_id', 'dataset_id', 'table_id')
)

# Single chat-completions client shared by every request in get_sql_query.
openai_client = OpenAI()

def fetch_table_schema(project_id, dataset_id, table_id):
    """Return the column types of ``project_id.dataset_id.table_id``.

    Args:
        project_id: GCP project that owns the table (also used as the client's project).
        dataset_id: BigQuery dataset containing the table.
        table_id: Name of the table.

    Returns:
        dict mapping each column name to its BigQuery ``field_type`` string,
        in the order the columns appear in the table schema.
    """
    table = bigquery.Client(project=project_id).get_table(
        f"{project_id}.{dataset_id}.{table_id}"
    )
    return {field.name: field.field_type for field in table.schema}

def _extract_sql(content):
    """Return the SQL statement contained in a chat-model reply.

    The model may wrap its answer in a Markdown fence (```sql ... ```, a bare
    ``` ... ```, or one with a different language tag), or return the bare
    query as the prompt requests. If a fence is present, its contents are
    returned, tolerating a missing closing fence. Otherwise the whole
    (stripped) reply is used. A ``None`` reply yields an empty string.
    """
    import re

    text = (content or '').strip()
    # Optional language tag is consumed only when followed by a newline, so
    # "```SELECT 1```" keeps its first keyword.
    match = re.search(r"```(?:[\w+-]*[ \t]*\n)?(.*?)(?:```|\Z)", text, re.DOTALL)
    return match.group(1).strip() if match else text


def get_sql_query(description):
    """Ask the chat model to translate a natural-language request into SQL.

    The prompt embeds the module-level ``dataset_id``, ``table_id`` and
    ``schema`` (the column -> type mapping from ``fetch_table_schema``), so
    ``schema`` must already be assigned when this is called.

    Args:
        description: The user's request in plain language.

    Returns:
        The SQL text from the model's reply, with any Markdown code fence removed.

    Raises:
        Exception: any error raised by the OpenAI client is printed and then
            re-raised. Previously execution fell through and failed with an
            unrelated ``UnboundLocalError``.
    """
    prompt = f'''
    Generate the SQL query for the following task:\n{description}.\n
    The database you need is called {dataset_id} and the table is called {table_id}.
    Use the format {dataset_id}.{table_id} as the table name in the queries.
    Enclose column names in backticks(`) not quotation marks.
    Do not assign aliases to the columns.
    Do not calculate new columns, unless specifically called to.
    Return only the SQL query, nothing else.
    Do not use WITHIN GROUP clause.
    \nThe list of all the columns is as follows: {schema} \n
    '''
    try:
        completion = openai_client.chat.completions.create(
            model='gpt-4o',
            messages=[
                {"role": "system", "content": "You are an expert Data Scientist with in-depth knowledge of SQL, working on Network Telemetry Data."},
                {"role": "user", "content": prompt},
            ],
        )
    except Exception as e:
        print(f'The following error ocurred: {e}\n')
        # Without a completion there is nothing to parse; surface the real error.
        raise

    return _extract_sql(completion.choices[0].message.content)

# Column -> type mapping of the configured table, fetched once at import time.
# get_sql_query reads this global when building each prompt.
schema = fetch_table_schema(project_id=project_id, dataset_id=dataset_id, table_id=table_id)

def execute_sql_query(query, project=None):
    """Run *query* on BigQuery and return its result together with a status message.

    Args:
        query: SQL text to execute.
        project: GCP project in which to run the query job. Defaults to the
            module-level ``project_id``, the same project ``fetch_table_schema``
            uses. That way the unqualified ``dataset.table`` names the prompt
            asks for resolve against that project rather than whatever project
            the ambient credentials default to. If ``project_id`` is itself
            ``None``, the client falls back to its default, as before.

    Returns:
        ``(DataFrame, message)`` on success, or
        ``('No output returned', message)`` if the query fails.
    """
    client = bigquery.Client(project=project_id if project is None else project)

    try:
        result = client.query(query).to_dataframe()
    except Exception as e:
        # UI boundary: report any query failure to the user rather than crash.
        result = 'No output returned'
        message = f'The query : {query}\n could not be executed due to exception {e}\n'
    else:
        message = f'The query : {query}\n was successfully executed and returned the above result.\n'

    return result, message

def echo(text):
    """Translate *text* to SQL, run it, and return ``(result, message)``.

    ``result`` is whatever ``execute_sql_query`` produced (a DataFrame or a
    fallback string); ``message`` describes the outcome.
    """
    return execute_sql_query(get_sql_query(text))

def gradio_interface(text):
    """Run the text-to-SQL pipeline and package the result for Gradio.

    Returns:
        ``(component_or_value, message)``. A DataFrame result is wrapped in a
        ``gr.Dataframe`` component; any other result is passed through unchanged.
    """
    result, message = echo(text)
    if not isinstance(result, pd.DataFrame):
        return result, message
    return gr.Dataframe(value=result), message

# Root Blocks container; layout and event wiring are added under `with demo:`.
demo = gr.Blocks(title="Text-to-SQL", theme='remilia/ghostly')

with demo:
    gr.Markdown(
        '''
    # <p style="text-align: center;">Text to SQL Query Engine</p>

    <p style="text-align: center;">
    Welcome to our Text2SQL Engine.
    <br>
    Enter your query in natural language and we'll convert it to SQL and return the result to you.
    </p>
    '''
    )

    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(label="Enter your query")
            button = gr.Button("Submit")
            gr.Examples(
                [
                    'Find the correlation between RTT and Jitter for each Market',
                    'Find the variance in Jitter for each 5G_Reliability_Category',
                    'Find the count of records per 5G_Reliability_Category where 5G_Reliability_Value is below the average for the category',
                    'Calculate the standard deviation of 5G_Reliability_Score for each Network_Engineer',
                    'Determine the Sector with the highest variance in 5G Reliability Value and its corresponding average Context Drop Percent',
                ],
                inputs=[text_input],
            )
        with gr.Column(scale=3):
            output_text = gr.Textbox(label="Output", interactive=False)
            output_df = gr.Dataframe(interactive=False)

    def update_output(text):
        """Run the text-to-SQL pipeline and map its outcome onto the outputs.

        Returns exactly one value per component in ``outputs=[output_df,
        output_text]``: an update for the table (filled and shown on success,
        hidden otherwise) and the status message.
        """
        try:
            # Call echo directly: gradio_interface wraps the frame in a
            # gr.Dataframe component, which would defeat the isinstance check.
            result, message = echo(text)
        except Exception as e:
            # e.g. the OpenAI request failed; report it instead of crashing the handler.
            return gr.update(visible=False), f'Could not generate a SQL query: {e}\n'

        if isinstance(result, pd.DataFrame):
            return gr.update(value=result, visible=True), message
        # Query failed: result is a placeholder string, so hide the table.
        return gr.update(visible=False), message

    button.click(update_output, inputs=text_input, outputs=[output_df, output_text])

# Login credentials for the Gradio app. They can be overridden with the
# `auth_username` / `auth_password` environment variables so deployments need
# not rely on the values committed here; the defaults preserve the old login.
demo.launch(
    debug=True,
    auth=(
        os.getenv('auth_username', 'admin'),
        os.getenv('auth_password', 'Text2SQL'),
    ),
)