File size: 3,531 Bytes
0f1d0ae
 
 
4d34ac2
 
bcea34d
0f1d0ae
 
 
d8d8282
 
bcea34d
4d34ac2
bcea34d
0f1d0ae
 
bcea34d
 
 
0f1d0ae
 
 
 
bcea34d
0f1d0ae
 
bcea34d
0f1d0ae
 
 
bcea34d
0f1d0ae
 
 
bcea34d
0f1d0ae
 
 
bcea34d
0f1d0ae
 
 
 
bcea34d
0f1d0ae
 
 
 
bcea34d
0f1d0ae
 
bcea34d
0f1d0ae
bcea34d
0f1d0ae
bcea34d
0f1d0ae
bcea34d
0f1d0ae
bcea34d
0f1d0ae
bcea34d
0f1d0ae
bcea34d
0f1d0ae
 
bcea34d
0f1d0ae
bcea34d
0f1d0ae
36383e0
bcea34d
 
0f1d0ae
bcea34d
 
 
 
 
0f1d0ae
bcea34d
0f1d0ae
bcea34d
0f1d0ae
bcea34d
 
 
 
c7c7c28
 
 
 
0f1d0ae
bcea34d
 
36383e0
0f1d0ae
bcea34d
 
 
0f1d0ae
bcea34d
1fc9ac8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import gradio as gr

# Initial neuron numbers and colors
input_size = 3
hidden_size = 4
output_size = 2

# Colours are "#rrggbb" hex strings because they also seed the gr.ColorPicker
# widgets, and an HTML colour input only accepts hex values (a named colour
# such as "lightgreen" is not a valid value there). matplotlib accepts hex as
# well, and these are the exact equivalents of the original named colours.
input_color = "#90ee90"   # lightgreen
hidden_color = "#87ceeb"  # skyblue
output_color = "#fa8072"  # salmon

# Module-level directed graph; update_graph clears and repopulates it per call.
G = nx.DiGraph()

# Update neurons and create the graph
def _layer_y_positions(count):
    """Return ``count`` y-coordinates spaced evenly from 1 (top) down to 0 (bottom).

    A single neuron is centred at 0.5, because the evenly-spaced formula
    ``1 - i / (count - 1)`` would divide by zero. ``count <= 0`` yields ``[]``.
    """
    if count == 1:
        return [0.5]
    return [1 - i / (count - 1) for i in range(count)]


def update_graph(input_size, hidden_size, output_size, input_color, hidden_color, output_color):
    """Rebuild the module-level graph ``G`` as a 3-layer MLP and return a figure of it.

    Every input neuron is connected to every hidden neuron, and every hidden
    neuron to every output neuron. Each edge gets a fresh weight from
    ``np.random.rand()`` (uniform on [0, 1)), so edge labels change on every call.

    Args:
        input_size, hidden_size, output_size: Neuron counts per layer. They are
            passed through ``int()``, since UI widgets may deliver floats.
        input_color, hidden_color, output_color: Node fill colours in any
            format matplotlib accepts (hex such as ``"#90ee90"`` or a name).

    Returns:
        matplotlib.figure.Figure: The rendered network. It has been closed in
        pyplot's figure registry (so figures don't pile up across calls), but
        the caller can still render or save it.
    """
    # (node-name prefix, layer tag, neuron count, colour), ordered left to right;
    # the list index doubles as the layer's x-coordinate.
    layers = [
        ("I", "input", int(input_size), input_color),
        ("H", "hidden", int(hidden_size), hidden_color),
        ("O", "output", int(output_size), output_color),
    ]

    # Rebuild the shared graph from scratch and record each node's position.
    G.clear()
    pos = {}
    for x, (prefix, layer, count, _) in enumerate(layers):
        for i, y in enumerate(_layer_y_positions(count)):
            name = f"{prefix}{i}"
            G.add_node(name, layer=layer)
            pos[name] = (x, y)

    # Fully connect each layer to the next. Weights are drawn source-neuron
    # outer / target-neuron inner, input->hidden before hidden->output.
    for (src, _, n_src, _), (dst, _, n_dst, _) in zip(layers, layers[1:]):
        for i in range(n_src):
            for j in range(n_dst):
                G.add_edge(f"{src}{i}", f"{dst}{j}", weight=np.random.rand())

    # Colour nodes by their layer tag instead of relying on insertion order.
    color_by_layer = {layer: color for _, layer, _, color in layers}
    node_colors = [color_by_layer[data["layer"]] for _, data in G.nodes(data=True)]
    edge_labels = {(u, v): f'{d["weight"]:.2f}' for u, v, d in G.edges(data=True)}

    # Draw on an explicit figure/axes rather than pyplot's implicit "current
    # figure". When nx.draw creates its own axes it spans the whole figure,
    # which tight_layout cannot adjust, leaving no room for the title.
    fig, ax = plt.subplots(figsize=(10, 6))
    nx.draw(G, pos, ax=ax, with_labels=True, node_size=2000, node_color=node_colors,
            font_size=12, font_weight='bold', arrows=True)
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, ax=ax)
    ax.set_title("Visual MLP", fontsize=16)
    ax.axis('off')  # Turn off axes
    fig.tight_layout()

    # Remove the figure from pyplot's registry so repeated calls don't leak
    # figures; the Figure object itself stays usable by the caller.
    plt.close(fig)
    return fig

# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("### MLP Model Visualizer")

    # Neuron-count sliders. step=1 keeps the values whole numbers; without it
    # Gradio derives a fractional step from the range and update_graph's int()
    # would silently truncate the chosen value.
    input_slider = gr.Slider(minimum=2, maximum=10, step=1, value=input_size, label="Input Layer")
    hidden_slider = gr.Slider(minimum=2, maximum=10, step=1, value=hidden_size, label="Hidden Layer")
    output_slider = gr.Slider(minimum=2, maximum=10, step=1, value=output_size, label="Output Layer")

    with gr.Row():
        input_color_picker = gr.ColorPicker(value=input_color, label="Input Layer Colour")
        hidden_color_picker = gr.ColorPicker(value=hidden_color, label="Hidden Layer Colour")
        output_color_picker = gr.ColorPicker(value=output_color, label="Output Layer Colour")

    output_plot = gr.Plot(label="MLP Model Graph")

    update_button = gr.Button("Update")

    # Argument order must match update_graph's positional parameters.
    graph_inputs = [input_slider, hidden_slider, output_slider,
                    input_color_picker, hidden_color_picker, output_color_picker]

    update_button.click(fn=update_graph, inputs=graph_inputs, outputs=output_plot)

    # Render the default network when the page loads, so the plot is not
    # empty until the first click.
    demo.load(fn=update_graph, inputs=graph_inputs, outputs=output_plot)

# Run the application
demo.launch()