File size: 3,178 Bytes
539de04
 
 
 
 
 
 
 
4213c4a
539de04
4213c4a
539de04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4213c4a
539de04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4213c4a
539de04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4213c4a
539de04
 
4213c4a
539de04
 
 
 
 
 
4213c4a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import pandas as pd
import pdfplumber
import docx
import openai
import seaborn as sns
import matplotlib.pyplot as plt
import gradio as gr

# Configure the OpenAI API key from the environment instead of a hard-coded
# literal. SECURITY: a real-looking secret key was previously committed on
# this line; any key published in version control is compromised and must be
# revoked/rotated in the OpenAI dashboard.
import os

openai.api_key = os.environ.get("OPENAI_API_KEY", "")

def load_file(file):
    """Dispatch an uploaded file to the appropriate loader by extension.

    Parameters
    ----------
    file : object
        Anything with a ``.name`` attribute holding a filesystem path
        (e.g. a Gradio file wrapper).

    Returns
    -------
    pandas.DataFrame
        Tabular data for csv/xls/xlsx, or a single-row ``text`` frame
        for pdf/doc/docx.

    Raises
    ------
    ValueError
        If the extension is not one of the supported types.
    """
    # Lower-case the extension so "DATA.CSV" is handled like "data.csv"
    # (the original comparison was case-sensitive and rejected it).
    file_type = file.name.split('.')[-1].lower()
    if file_type == 'csv':
        return pd.read_csv(file.name)
    elif file_type in ['xls', 'xlsx']:
        return pd.read_excel(file.name)
    elif file_type == 'pdf':
        return load_pdf(file)
    elif file_type in ['doc', 'docx']:
        return load_doc(file)
    else:
        raise ValueError("Unsupported file type")

def load_pdf(file):
    """Extract all text from a PDF into a single-row DataFrame.

    Returns a DataFrame with one ``text`` column whose single cell is the
    text of every page joined with newlines.
    """
    with pdfplumber.open(file.name) as pdf:
        # extract_text() returns None for pages with no extractable text
        # (e.g. scanned/image-only pages); substitute "" so the join below
        # does not raise TypeError.
        pages = [page.extract_text() or "" for page in pdf.pages]
    text = "\n".join(pages)
    return pd.DataFrame({"text": [text]})

def load_doc(file):
    """Read a Word document and return its paragraphs as one text blob.

    Returns a DataFrame with a single ``text`` column containing all
    paragraph text joined with newlines.
    """
    document = docx.Document(file.name)
    paragraph_texts = (paragraph.text for paragraph in document.paragraphs)
    combined = "\n".join(paragraph_texts)
    return pd.DataFrame({"text": [combined]})

def generate_query(prompt):
    """Send *prompt* to the OpenAI chat model and return its reply.

    Uses gpt-3.5-turbo with a generic system message; the stripped
    assistant reply is returned as the query string.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=conversation,
    )
    reply = response['choices'][0]['message']['content']
    return reply.strip()

def handle_query(query, df):
    """Answer a question about *df* or evaluate it as a pandas query.

    Recognizes row/column-count questions directly; anything else is
    passed to DataFrame.query() and rendered as an HTML table. On a
    failed query the exception message is returned as a string.
    """
    lowered = query.lower()
    if "number of columns" in lowered:
        return f"The number of columns is {df.shape[1]}"
    if "number of rows" in lowered:
        return f"The number of rows is {df.shape[0]}"
    try:
        return df.query(query).to_html()
    except Exception as exc:
        return str(exc)

def draw_chart(query, df, output_path='/content/chart.png'):
    """Filter *df* with a pandas query and save a scatter plot of its
    first two columns.

    Parameters
    ----------
    query : str
        Expression passed to DataFrame.query().
    df : pandas.DataFrame
        Source data; the filtered frame must have at least two columns.
    output_path : str, optional
        Destination for the PNG. The default keeps the original
        Colab-specific path for backward compatibility.

    Returns
    -------
    str
        The saved image path, or the exception message on failure.
    """
    try:
        result_df = df.query(query)
        # Start a fresh figure so repeated calls do not draw on top of the
        # previous chart's axes (plt state persists between calls).
        plt.figure()
        x_col, y_col = result_df.columns[0], result_df.columns[1]
        sns.scatterplot(data=result_df, x=x_col, y=y_col)
        plt.title("Generated Chart")
        plt.xlabel(x_col)
        plt.ylabel(y_col)
        plt.savefig(output_path)
        plt.close()
        return output_path
    except Exception as e:
        return str(e)

def chatbot(file, input_text):
    """Gradio handler: answer *input_text* against the uploaded *file*.

    Loads the file into a DataFrame, converts the user's text into a
    query via the LLM, and answers it. Returns a (chart_path_or_None,
    response) pair; a chart is produced only when the generated query
    mentions "chart" or "graph". Any failure is reported as
    (None, error_message).
    """
    try:
        df = load_file(file)
        query = generate_query(input_text)
        response = handle_query(query, df)

        lowered = query.lower()
        wants_chart = "chart" in lowered or "graph" in lowered
        if wants_chart:
            return draw_chart(query, df), response

        return None, response
    except Exception as err:
        return None, str(err)

# Create a Gradio interface wiring chatbot() to a file-upload + textbox
# input pair; outputs are the optional chart image and an HTML/text answer.
iface = gr.Interface(
    fn=chatbot,
    # NOTE(review): type="file" is the legacy Gradio 3.x File argument;
    # newer Gradio versions expect "filepath" or "binary" — confirm the
    # installed gradio version before upgrading.
    inputs=[gr.File(type="file", label="Upload File"), gr.Textbox(lines=2, placeholder="Enter your query here...")],
    outputs=["image", "html"],
    title="Data Analyst Chatbot",
    description="Upload a file and enter a query to get responses based on the data."
)

# Launch the interface (blocks and serves the app)
iface.launch()