Spaces:
Runtime error
Runtime error
import pandas as pd | |
import pdfplumber | |
import docx | |
import openai | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
import gradio as gr | |
import os

# SECURITY: the API key used to be committed here as a string literal.
# Any key that has appeared in source control must be treated as leaked and
# revoked. The key is now read from the environment (for example a
# deployment secret) so that it never lives in the repository.
# If OPENAI_API_KEY is unset this assigns None.
openai.api_key = os.environ.get("OPENAI_API_KEY")
def load_file(file):
    """Load an uploaded file into a pandas DataFrame.

    Args:
        file: Either a filesystem path given as ``str`` or an object that
            exposes the path through a ``name`` attribute.

    Returns:
        The result of ``pd.read_csv`` / ``pd.read_excel`` for ``csv``,
        ``xls`` and ``xlsx`` files, or whatever ``load_pdf`` / ``load_doc``
        return for ``pdf``, ``doc`` and ``docx`` files.

    Raises:
        ValueError: If ``file`` is ``None`` or the extension is unsupported.
    """
    if file is None:
        raise ValueError("No file provided")

    # The UI declares its upload widget with type="filepath"; a plain string
    # has no ``.name`` attribute, so both shapes are accepted here.
    path = file if isinstance(file, str) else file.name

    # Compare the extension case-insensitively so "DATA.CSV" is accepted too.
    file_type = path.rsplit(".", 1)[-1].lower()

    if file_type == "csv":
        return pd.read_csv(path)
    if file_type in ("xls", "xlsx"):
        return pd.read_excel(path)
    # The document loaders receive the original argument unchanged so they
    # keep working regardless of which input shape they expect.
    if file_type == "pdf":
        return load_pdf(file)
    if file_type in ("doc", "docx"):
        return load_doc(file)
    raise ValueError("Unsupported file type")
def load_pdf(file):
    """Extract the text of a PDF into a single-row DataFrame.

    Args:
        file: Either a filesystem path given as ``str`` or an object that
            exposes the path through a ``name`` attribute.

    Returns:
        A DataFrame with one column, ``"text"``, and one row holding the
        text of all pages joined with newlines.
    """
    # Accept a plain path as well as an object carrying ``.name``.
    path = file if isinstance(file, str) else file.name
    with pdfplumber.open(path) as pdf:
        # NOTE(review): extract_text() may return None for a page without
        # extractable text in some pdfplumber versions (confirm against the
        # installed version). ``or ""`` keeps the join from raising TypeError.
        pages = [page.extract_text() or "" for page in pdf.pages]
    return pd.DataFrame({"text": ["\n".join(pages)]})
def load_doc(file):
    """Extract the paragraph text of a Word document into a single-row DataFrame.

    Args:
        file: Either a filesystem path given as ``str`` or an object that
            exposes the path through a ``name`` attribute.

    Returns:
        A DataFrame with one column, ``"text"``, and one row holding every
        paragraph's text joined with newlines.
    """
    # Accept a plain path as well as an object carrying ``.name``.
    path = file if isinstance(file, str) else file.name
    document = docx.Document(path)
    text = "\n".join(para.text for para in document.paragraphs)
    return pd.DataFrame({"text": [text]})
def generate_query(prompt):
    """Send *prompt* to ``openai.Completion.create`` and return the reply text.

    The request uses the ``text-davinci-003`` engine with at most 150 tokens;
    the first choice's text is returned with surrounding whitespace removed.

    NOTE: this file defines ``generate_query`` a second time further down.
    The later definition rebinds the name, so this one is not the version
    that runs once the module has finished loading.
    """
    completion = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        max_tokens=150,
    )
    first_choice = completion.choices[0]
    return first_choice.text.strip()
def handle_query(query, df):
    """Answer *query* against *df*.

    Two phrases are recognised case-insensitively: "number of columns" and
    "number of rows". Anything else is handed to ``df.query`` and the
    filtered frame is returned as an HTML table. If that evaluation fails,
    the exception's message is returned instead.

    NOTE: this file defines ``handle_query`` a second time further down; the
    later definition rebinds the name.
    """
    lowered = query.lower()
    if "number of columns" in lowered:
        return f"The number of columns is {df.shape[1]}"
    if "number of rows" in lowered:
        return f"The number of rows is {df.shape[0]}"

    try:
        # Treat the text as a pandas query expression.
        return df.query(query).to_html()
    except Exception as err:
        return str(err)
def draw_chart(query, df):
    """Filter *df* with *query* and save a scatter plot of its first two columns.

    Returns the path of the written image, ``'/content/chart.png'``, on
    success. On any failure the exception's message is returned as a string
    instead.

    NOTE: this file defines ``draw_chart`` a second time further down; the
    later definition rebinds the name.
    """
    chart_file = '/content/chart.png'
    try:
        filtered = df.query(query)
        columns = filtered.columns
        x_name, y_name = columns[0], columns[1]
        sns.scatterplot(data=filtered, x=x_name, y=y_name)
        plt.title("Generated Chart")
        plt.xlabel(x_name)
        plt.ylabel(y_name)
        plt.savefig(chart_file)
        plt.close()
    except Exception as err:
        return str(err)
    return chart_file
def generate_query(prompt):
    """Send *prompt* to ``openai.ChatCompletion.create`` and return the reply.

    The conversation consists of a fixed system message followed by *prompt*
    as the user message, sent to the ``gpt-3.5-turbo`` model. The content of
    the first choice is returned with surrounding whitespace removed.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    reply = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=conversation,
    )
    content = reply['choices'][0]['message']['content']
    return content.strip()
def handle_query(query, df):
    """Produce a textual or HTML answer for *query* evaluated against *df*.

    * If the query mentions "number of columns" (any casing), report
      ``df.shape[1]``.
    * Otherwise, if it mentions "number of rows", report ``df.shape[0]``.
    * Otherwise evaluate it with ``df.query`` and return the result rendered
      by ``to_html()``; when evaluation raises, return the error message.

    NOTE(review): ``df.query`` evaluates the expression it is given, and in
    this file that expression comes from ``generate_query``. Keep that in
    mind if the data or deployment is sensitive.
    """
    normalized = query.lower()

    wants_columns = "number of columns" in normalized
    if wants_columns:
        return f"The number of columns is {df.shape[1]}"

    wants_rows = "number of rows" in normalized
    if wants_rows:
        return f"The number of rows is {df.shape[0]}"

    try:
        matches = df.query(query)
        html_table = matches.to_html()
    except Exception as exc:
        return str(exc)
    return html_table
def draw_chart(query, df, output_path=None):
    """Filter *df* with *query* and save a scatter plot of its first two columns.

    Args:
        query: Expression passed to ``df.query``.
        df: DataFrame to filter and plot.
        output_path: Where to write the PNG. When omitted, a fresh temporary
            ``.png`` file is created, so the function does not depend on a
            particular directory existing and separate calls do not
            overwrite each other's image.

    Returns:
        The path of the written image on success. On any failure the
        exception's message is returned as a string instead (callers must
        check that the returned value is an existing file).
    """
    import tempfile  # local import: only this function needs it

    fig = None
    try:
        result_df = df.query(query)
        if result_df.shape[1] < 2:
            # Clearer than the IndexError that columns[1] would raise.
            raise ValueError("A scatter plot needs at least two columns")
        x_col, y_col = result_df.columns[0], result_df.columns[1]

        # Draw on an explicit figure rather than pyplot's implicit current
        # figure, so exactly this figure can be closed afterwards.
        fig, ax = plt.subplots()
        sns.scatterplot(data=result_df, x=x_col, y=y_col, ax=ax)
        ax.set_title("Generated Chart")
        ax.set_xlabel(x_col)
        ax.set_ylabel(y_col)

        if output_path is None:
            # delete=False keeps the file on disk after the handle closes.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                output_path = tmp.name
        fig.savefig(output_path)
        return output_path
    except Exception as e:
        return str(e)
    finally:
        # Release the figure even when plotting or saving failed.
        if fig is not None:
            plt.close(fig)
def chatbot(file, input_text):
    """Answer a question about an uploaded file, optionally with a chart.

    Args:
        file: The uploaded file, as accepted by ``load_file``.
        input_text: The user's request in free text.

    Returns:
        A ``(chart_path, response)`` tuple. ``chart_path`` is the path of a
        generated image, or ``None`` when no chart was requested or it could
        not be produced. ``response`` is the text/HTML answer, or an error
        message if any step raised.
    """
    import html  # local imports: only needed for the chart-failure path
    import os

    # Fail fast with a readable message instead of an opaque exception or a
    # wasted API call.
    if file is None:
        return None, "Please upload a file."
    if not input_text or not input_text.strip():
        return None, "Please enter a query."

    try:
        # Load the file into a DataFrame
        df = load_file(file)
        # Generate a query from the input text
        query = generate_query(input_text)
        # Handle the query and generate a response
        response = handle_query(query, df)

        lowered = query.lower()
        if "chart" not in lowered and "graph" not in lowered:
            return None, response

        chart_path = draw_chart(query, df)
        # draw_chart reports failures by returning the error text in place of
        # a path; only hand a real file to the image output.
        if isinstance(chart_path, str) and os.path.isfile(chart_path):
            return chart_path, response
        detail = html.escape(str(chart_path))
        return None, f"{response}<p>Chart could not be generated: {detail}</p>"
    except Exception as e:
        return None, str(e)
# Assemble the Gradio UI: a file upload plus a two-line text box feed
# `chatbot`, whose two return values populate an image and an HTML panel.
_upload_input = gr.File(type="filepath", label="Upload File")
_query_input = gr.Textbox(lines=2, placeholder="Enter your query here...")

iface = gr.Interface(
    fn=chatbot,
    inputs=[_upload_input, _query_input],
    outputs=["image", "html"],
    title="Data Analyst Chatbot",
    description="Upload a file and enter a query to get responses based on the data.",
)

# Start serving the interface.
iface.launch()