|
import pandas as pd |
|
from openai import OpenAI |
|
import os |
|
from google.cloud import bigquery |
|
import numpy as np |
|
import gradio as gr |
|
|
|
# Location of the BigQuery table the app queries. Each value is taken from the
# process environment and is None when the variable is not set.
project_id = os.environ.get('project_id')
dataset_id = os.environ.get('dataset_id')
table_id = os.environ.get('table_id')
|
|
|
# Module-wide OpenAI client used by get_sql_query. It is constructed with no
# arguments, so credentials and endpoint come from the library's defaults
# (presumably the OPENAI_API_KEY environment variable -- TODO confirm).
openai_client = OpenAI()
|
|
|
def fetch_table_schema(project_id, dataset_id, table_id):
    """Return a ``{column name: field type}`` mapping for a BigQuery table.

    Args:
        project_id: Project that owns the table; also used as the client's project.
        dataset_id: Dataset containing the table.
        table_id: Name of the table.

    Returns:
        A dict with one entry per element of ``table.schema``, keyed by the
        field's ``name`` and holding its ``field_type``.
    """
    client = bigquery.Client(project=project_id)
    # Fully qualified "project.dataset.table" reference.
    table = client.get_table(f"{project_id}.{dataset_id}.{table_id}")
    return {field.name: field.field_type for field in table.schema}
|
|
|
def _extract_sql(content):
    """Return the SQL text contained in a chat-completion reply.

    A reply fenced as ```sql ... ``` (any letter case for the tag) or as a bare
    ``` ... ``` block yields the text between the fences. A reply with no fence
    at all is returned as-is after stripping surrounding whitespace, because the
    prompt explicitly asks the model to return only the query.
    """
    text = (content or '').strip()
    fence = '```'
    if fence not in text:
        return text
    # Text after the first fence, up to the next fence (or the end if unclosed).
    body = text.split(fence, 1)[1].split(fence, 1)[0]
    # Drop the optional language tag that directly follows the opening fence.
    if body[:3].lower() == 'sql':
        body = body[3:]
    return body


def get_sql_query(description):
    """Ask the OpenAI chat model to translate a natural-language task into SQL.

    Reads the module-level ``dataset_id``, ``table_id`` and ``schema`` to build
    the prompt and sends it through the module-level ``openai_client`` using
    the ``gpt-4o`` model.

    Args:
        description: The task to express as a SQL query, in natural language.

    Returns:
        The SQL text extracted from the model's reply.

    Raises:
        RuntimeError: If the chat-completion request fails. The original
            exception is attached as ``__cause__``.
    """
    prompt = f'''
    Generate the SQL query for the following task:\n{description}.\n
    The database you need is called {dataset_id} and the table is called {table_id}.
    Use the format {dataset_id}.{table_id} as the table name in the queries.
    Enclose column names in backticks(`) not quotation marks.
    Do not assign aliases to the columns.
    Do not calculate new columns, unless specifically called to.
    Return only the SQL query, nothing else.
    Do not use WITHIN GROUP clause.
    \nThe list of all the columns is as follows: {schema} \n
    '''

    try:
        completion = openai_client.chat.completions.create(
            model='gpt-4o',
            messages=[
                {"role": "system", "content": "You are an expert Data Scientist with in-depth knowledge of SQL, working on Network Telemetry Data."},
                {"role": "user", "content": prompt},
            ],
        )
    except Exception as e:
        # Without a completion there is nothing to parse; fail with the real
        # cause instead of tripping over an unbound variable further down.
        raise RuntimeError(
            f'Could not generate SQL: the OpenAI request failed: {e}'
        ) from e

    return _extract_sql(completion.choices[0].message.content)
|
|
|
# Column-name -> field-type mapping of the target table, fetched once when the
# module is loaded. get_sql_query interpolates it into the prompt it sends.
schema = fetch_table_schema(project_id, dataset_id, table_id)
|
|
|
def execute_sql_query(query):
    """Run a SQL string on BigQuery and report the outcome.

    NOTE(security): the query is executed exactly as given, with no validation.
    In this file it comes from model output (see ``echo``), so the credentials
    in use should be limited to what the app is allowed to do.

    Args:
        query: SQL text to execute.

    Returns:
        A ``(result, message)`` tuple. On success ``result`` is the query
        result as a pandas DataFrame; on any failure it is the string
        ``'No output returned'``. ``message`` describes what happened and
        includes the query text.
    """
    try:
        # Bind the client to the same project fetch_table_schema uses: the
        # generated SQL names the table as "dataset.table", which is resolved
        # against the client's project. When project_id is None this is
        # equivalent to bigquery.Client().
        client = bigquery.Client(project=project_id)
        result = client.query(query).to_dataframe()
    except Exception as e:
        # Client construction is covered too, so a credentials or project
        # problem is reported through the message rather than raised.
        return (
            'No output returned',
            f'The query : {query}\n could not be executed due to exception {e}\n',
        )

    return (
        result,
        f'The query : {query}\n was successfully executed and returned the above result.\n',
    )
|
|
|
def echo(text):
    """Turn a natural-language request into SQL, run it, and return the outcome.

    Args:
        text: The user's request in natural language.

    Returns:
        The ``(result, message)`` pair produced by ``execute_sql_query`` for
        the SQL that ``get_sql_query`` generated from ``text``.
    """
    outcome, status = execute_sql_query(get_sql_query(text))
    return outcome, status
|
|
|
def gradio_interface(text):
    """Run ``echo`` and wrap a DataFrame result in a Gradio component.

    Args:
        text: The user's request in natural language.

    Returns:
        A ``(display, message)`` tuple. ``display`` is a ``gr.Dataframe`` built
        from the result when ``echo`` returned a pandas DataFrame; otherwise it
        is whatever ``echo`` returned, unchanged.
    """
    outcome, status = echo(text)
    display = gr.Dataframe(value=outcome) if isinstance(outcome, pd.DataFrame) else outcome
    return display, status
|
|
|
# Top-level Blocks container for the app; its components are attached in the
# `with demo:` section below.
demo = gr.Blocks(title="Text-to-SQL", theme='remilia/ghostly')
|
|
|
with demo:
    # Page header and a short usage blurb.
    gr.Markdown(
        '''
        # <p style="text-align: center;">Text to SQL Query Engine</p>

        <p style="text-align: center;">
        Welcome to our Text2SQL Engine.
        <br>
        Enter your query in natural language and we'll convert it to SQL and return the result to you.
        </p>
        '''
    )

    with gr.Row():
        # Narrow column: request box, submit button and clickable sample requests.
        with gr.Column(scale=1):
            text_input = gr.Textbox(label="Enter your query")
            button = gr.Button("Submit")
            gr.Examples(
                [
                    'Find the correlation between RTT and Jitter for each Market',
                    'Find the variance in Jitter for each 5G_Reliability_Category',
                    'Find the count of records per 5G_Reliability_Category where 5G_Reliability_Value is below the average for the category',
                    'Calculate the standard deviation of 5G_Reliability_Score for each Network_Engineer',
                    'Determine the Sector with the highest variance in 5G Reliability Value and its corresponding average Context Drop Percent',
                ],
                inputs=[text_input],
            )
        # Wide column: status message and the result table.
        with gr.Column(scale=3):
            output_text = gr.Textbox(label="Output", interactive=False)
            output_df = gr.Dataframe(interactive=False)

    def update_output(text):
        """Handle a Submit click.

        Returns exactly two values, matching ``outputs=[output_df, output_text]``
        on the click handler: an update for the table, then the status message.
        The table is filled and shown when the query produced a DataFrame and
        hidden otherwise.
        """
        # Call echo directly: gradio_interface wraps a successful result in a
        # gr.Dataframe component, which would never satisfy the isinstance
        # check below and so would always hide the table.
        result, message = echo(text)
        if isinstance(result, pd.DataFrame):
            return gr.update(value=result, visible=True), message
        return gr.update(visible=False), message

    button.click(update_output, inputs=text_input, outputs=[output_df, output_text])
|
|
|
# NOTE(security): the login used to be hard-coded here. It can now be supplied
# through the TEXT2SQL_AUTH_USERNAME / TEXT2SQL_AUTH_PASSWORD environment
# variables; the previous values remain only as a fallback so existing
# deployments keep working. Set both variables for any shared deployment.
demo.launch(
    debug=True,
    auth=(
        os.getenv('TEXT2SQL_AUTH_USERNAME', 'admin'),
        os.getenv('TEXT2SQL_AUTH_PASSWORD', 'Text2SQL'),
    ),
)