File size: 5,292 Bytes
45c21ee
8ac21ea
 
 
 
 
45c21ee
8fe3e35
 
8ac21ea
8fe3e35
 
 
8ac21ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b4173f
8ac21ea
 
 
b216009
8ac21ea
 
8fe3e35
8ac21ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fe3e35
f5e3f06
8ac21ea
 
029cd9d
8ac21ea
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os
import shutil
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import gradio as gr

# Custom stylesheet handed to gr.Blocks: caps the app container width and
# centers top-level headings.
css = (
    "\n"
    ".gradio-container{max-width: 900px !important}\n"
    "h1{text-align:center}\n"
)

def _safe_filename_part(name):
    """Return ``name`` as text with path-hostile characters replaced by ``_``."""
    hostile = '\\/:*?"<>|\0'
    return "".join("_" if ch in hostile else ch for ch in str(name))


def _render_png(path, draw, title, xlabel=None, ylabel=None):
    """Draw onto a fresh single-axes figure, save it to ``path`` and close it.

    ``draw`` is called with the new axes. The figure is closed even when
    drawing or saving raises, so failed plots do not accumulate in pyplot.
    """
    fig, ax = plt.subplots()
    try:
        draw(ax)
        ax.set_title(title)
        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if ylabel is not None:
            ax.set_ylabel(ylabel)
        fig.savefig(path)
    finally:
        plt.close(fig)
    return path


def create_visualizations(data):
    """Render a standard set of charts for ``data`` and return their paths.

    The ``./figures`` directory is deleted and recreated on every call, then
    filled with static ``.png`` charts (matplotlib/seaborn) and interactive
    ``.html`` charts (plotly).

    Args:
        data: A pandas DataFrame. It is not modified; all work happens on a
            copy.

    Returns:
        A list of file paths in generation order: histograms, box plots,
        scatter matrix, correlation heatmap, bar charts, line chart, plotly
        scatter, plotly pie, plotly heatmap, violin plots. Charts whose
        prerequisites are missing are skipped. An empty (or ``None``) frame
        yields an empty list.

    Raises:
        Whatever ``pd.to_datetime`` raises if a ``date`` column holds values
        it cannot parse.
    """
    plots = []

    # Start from a clean output directory so stale charts are never returned.
    figures_dir = "./figures"
    shutil.rmtree(figures_dir, ignore_errors=True)
    os.makedirs(figures_dir, exist_ok=True)

    # Nothing to draw for a missing frame or one with no rows/columns.
    if data is None or data.empty:
        return plots

    # Work on a copy: the date conversion below must not leak to the caller.
    frame = data.copy()

    def out_path(filename):
        return os.path.join(figures_dir, filename)

    numeric_cols = frame.select_dtypes(include=['number']).columns
    categorical_cols = frame.select_dtypes(include=['object']).columns
    has_numeric_pair = len(numeric_cols) > 1
    # Computed once and shared by the seaborn and plotly heatmaps.
    corr = frame[numeric_cols].corr() if has_numeric_pair else None

    # Histograms for numeric columns
    for col in numeric_cols:
        plots.append(_render_png(
            out_path(f'histogram_{_safe_filename_part(col)}.png'),
            lambda ax, series=frame[col]: sns.histplot(series, kde=True, ax=ax),
            f'Histogram of {col}',
            xlabel=col,
            ylabel='Frequency',
        ))

    # Box plots for numeric columns
    for col in numeric_cols:
        plots.append(_render_png(
            out_path(f'boxplot_{_safe_filename_part(col)}.png'),
            lambda ax, series=frame[col]: sns.boxplot(x=series, ax=ax),
            f'Box Plot of {col}',
        ))

    # Scatter plot matrix. pairplot is figure-level: it builds its own figure,
    # so it is titled, saved and closed through the returned grid rather than
    # through a pre-created pyplot figure (which would be left open).
    if has_numeric_pair:
        grid = sns.pairplot(frame[numeric_cols])
        try:
            grid.figure.suptitle('Scatter Plot Matrix', y=1.02)
            scatter_matrix_path = out_path('scatter_matrix.png')
            grid.savefig(scatter_matrix_path)
        finally:
            plt.close(grid.figure)
        plots.append(scatter_matrix_path)

    # Correlation heatmap
    if has_numeric_pair:
        plots.append(_render_png(
            out_path('correlation_heatmap.png'),
            lambda ax: sns.heatmap(corr, annot=True, cmap='coolwarm', ax=ax),
            'Correlation Heatmap',
        ))

    # Bar charts for categorical columns
    for col in categorical_cols:
        counts = frame[col].value_counts()
        if counts.empty:
            # All values missing: there is nothing to count or draw.
            continue
        plots.append(_render_png(
            out_path(f'bar_chart_{_safe_filename_part(col)}.png'),
            lambda ax, counts=counts: counts.plot(kind='bar', ax=ax),
            f'Bar Chart of {col}',
            xlabel=col,
            ylabel='Count',
        ))

    # Line charts (if a 'date' column is present)
    if 'date' in frame.columns:
        frame['date'] = pd.to_datetime(frame['date'])
        indexed = frame.set_index('date')
        # DataFrame.plot raises when there is no numeric data; skip instead.
        if indexed.select_dtypes(include=['number']).shape[1] > 0:
            plots.append(_render_png(
                out_path('line_chart.png'),
                # Passing ax keeps pandas from opening a second figure.
                lambda ax: indexed.plot(ax=ax),
                'Line Chart of Date Series',
            ))

    # Scatter plot using Plotly
    if len(numeric_cols) >= 2:
        fig = px.scatter(frame, x=numeric_cols[0], y=numeric_cols[1], title='Scatter Plot')
        scatter_plot_path = out_path('scatter_plot.html')
        fig.write_html(scatter_plot_path)
        plots.append(scatter_plot_path)

    # Pie chart for categorical columns (only the first categorical column)
    if not categorical_cols.empty:
        first_cat = categorical_cols[0]
        fig = px.pie(frame, names=first_cat, title=f'Pie Chart of {first_cat}')
        pie_chart_path = out_path('pie_chart.html')
        fig.write_html(pie_chart_path)
        plots.append(pie_chart_path)

    # Interactive heatmap of the same correlation matrix
    if has_numeric_pair:
        fig = px.imshow(corr, text_auto=True, title='Heatmap of Numeric Variables')
        heatmap_plot_path = out_path('heatmap_plot.html')
        fig.write_html(heatmap_plot_path)
        plots.append(heatmap_plot_path)

    # Violin plots for numeric columns
    for col in numeric_cols:
        plots.append(_render_png(
            out_path(f'violin_plot_{_safe_filename_part(col)}.png'),
            lambda ax, series=frame[col]: sns.violinplot(x=series, ax=ax),
            f'Violin Plot of {col}',
        ))

    return plots

def analyze_data(file_input):
    """Load the uploaded CSV and return the paths of its generated charts.

    Args:
        file_input: Value delivered by the file-upload component. It may be
            ``None`` (nothing uploaded), a filesystem path, or an object
            exposing the path through a ``name`` attribute.

    Returns:
        The list of chart paths produced by ``create_visualizations``, or an
        empty list when no file was provided.
    """
    # Clicking the button without an upload must not crash on ``None.name``.
    if file_input is None:
        return []

    # Accept both a plain path and a file wrapper carrying the path in .name.
    if isinstance(file_input, (str, os.PathLike)):
        csv_path = file_input
    else:
        csv_path = file_input.name

    data = pd.read_csv(csv_path)
    return create_visualizations(data)

# Example CSV offered in the UI (path is relative to the working directory)
example_file_path = "./example/example.csv"

with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue=gr.themes.colors.orange, secondary_hue=gr.themes.colors.blue)) as demo:
    gr.Markdown("# DATA BOARD📊\nUpload a `.csv` file to generate various visualizations and interactive plots.")
    
    file_input = gr.File(label="Upload your `.csv` file")
    submit = gr.Button("Generate Dashboards")
    
    # Display the generated charts in a gallery.
    # NOTE(review): analyze_data also returns plotly ``.html`` files; confirm
    # that the gallery component can render them alongside the ``.png`` images.
    gallery = gr.Gallery(label="Visualizations")
    
    # Cached example. Caching needs the function to run, so ``fn`` must be
    # supplied together with ``outputs``; without it Examples raises an error.
    examples = gr.Examples(
        examples=[[example_file_path]],
        inputs=file_input,
        outputs=gallery,
        fn=analyze_data,
        cache_examples=True  # Enable caching
    )
    
    submit.click(analyze_data, file_input, gallery)

if __name__ == "__main__":
    demo.launch(share=True)