File size: 3,531 Bytes
0f1d0ae 4d34ac2 bcea34d 0f1d0ae d8d8282 bcea34d 4d34ac2 bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae 36383e0 bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d 0f1d0ae bcea34d c7c7c28 0f1d0ae bcea34d 36383e0 0f1d0ae bcea34d 0f1d0ae bcea34d 1fc9ac8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import gradio as gr
# Default neuron count for each layer; these seed the three layer-size sliders.
input_size, hidden_size, output_size = 3, 4, 2
# Default node colour for each layer; these seed the three colour pickers.
input_color, hidden_color, output_color = "lightgreen", "skyblue", "salmon"
# Module-level directed graph holding the current network. update_graph()
# clears and repopulates this same object on every call instead of building
# a fresh graph.
# NOTE(review): this is shared state across all callers; if the Gradio app
# ever runs update_graph concurrently, calls would race on it — confirm the
# event handlers are serialized.
G = nx.DiGraph()
# Update neurons and create the graph
def update_graph(input_size, hidden_size, output_size, input_color, hidden_color, output_color):
    """Rebuild the shared MLP graph ``G`` and draw it as a matplotlib figure.

    Nodes are named ``I<i>``, ``H<j>`` and ``O<k>`` and carry a ``layer``
    attribute (``'input'``, ``'hidden'`` or ``'output'``). Consecutive layers
    are fully connected, and every edge gets a fresh random ``weight`` drawn
    with ``np.random.rand()``, so each call produces new weights.

    Args:
        input_size: Number of input neurons (converted with ``int()``).
        hidden_size: Number of hidden neurons (converted with ``int()``).
        output_size: Number of output neurons (converted with ``int()``).
        input_color: Matplotlib colour used for input-layer nodes.
        hidden_color: Matplotlib colour used for hidden-layer nodes.
        output_color: Matplotlib colour used for output-layer nodes.

    Returns:
        The matplotlib ``Figure`` containing the drawing. It has already been
        closed in pyplot, but the object itself remains usable by the caller.
    """
    # UI widgets may deliver numbers as floats; layer sizes must be integers.
    input_size = int(input_size)
    hidden_size = int(hidden_size)
    output_size = int(output_size)

    # Reuse the module-level graph: discard whatever the previous call built.
    G.clear()

    # Add each layer's neurons, tagged with their layer name.
    for i in range(input_size):
        G.add_node(f'I{i}', layer='input')
    for i in range(hidden_size):
        G.add_node(f'H{i}', layer='hidden')
    for i in range(output_size):
        G.add_node(f'O{i}', layer='output')

    # Fully connect input -> hidden and hidden -> output with random weights.
    for i in range(input_size):
        for j in range(hidden_size):
            G.add_edge(f'I{i}', f'H{j}', weight=np.random.rand())
    for j in range(hidden_size):
        for k in range(output_size):
            G.add_edge(f'H{j}', f'O{k}', weight=np.random.rand())

    # One column per layer (x = 0, 1, 2), neurons spread vertically in [0, 1].
    pos = {}
    pos.update(_layer_positions('I', input_size, 0))
    pos.update(_layer_positions('H', hidden_size, 1))
    pos.update(_layer_positions('O', output_size, 2))

    # Colour each node from its own 'layer' attribute rather than relying on
    # the graph's node order matching a concatenated colour list.
    layer_colors = {'input': input_color, 'hidden': hidden_color, 'output': output_color}
    node_colors = [layer_colors[layer] for _, layer in G.nodes(data='layer')]
    # Show each edge's weight with two decimals.
    edge_labels = {(u, v): f'{w:.2f}' for u, v, w in G.edges(data='weight')}

    fig = plt.figure(figsize=(10, 6))
    nx.draw(G, pos, with_labels=True, node_size=2000, node_color=node_colors,
            font_size=12, font_weight='bold', arrows=True)
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
    plt.title("Visual MLP", fontsize=16)
    plt.axis('off')
    plt.tight_layout()
    # Close the figure in pyplot so repeated calls don't keep accumulating
    # open figures; the returned Figure object is still valid.
    plt.close(fig)
    return fig


def _layer_positions(prefix, size, x):
    """Return ``{f'{prefix}{i}': (x, y)}`` for one layer, top to bottom.

    Neurons are spaced evenly from y = 1 (first) down to y = 0 (last). A layer
    with a single neuron is centred at y = 0.5, since the even-spacing formula
    would divide by zero. Sizes of 0 or less yield an empty mapping.
    """
    if size == 1:
        return {f'{prefix}0': (x, 0.5)}
    return {f'{prefix}{i}': (x, 1 - (i / (size - 1))) for i in range(size)}
# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("### MLP Model Visualizer")
    # Layer-size sliders, seeded with the module-level default sizes.
    input_slider = gr.Slider(minimum=2, maximum=10, value=input_size, label="Input Layer")
    hidden_slider = gr.Slider(minimum=2, maximum=10, value=hidden_size, label="Hidden Layer")
    output_slider = gr.Slider(minimum=2, maximum=10, value=output_size, label="Output Layer")
    # Per-layer colour pickers, laid out side by side.
    with gr.Row():
        input_color_picker = gr.ColorPicker(value=input_color, label="Input Layer Colour")
        hidden_color_picker = gr.ColorPicker(value=hidden_color, label="Hidden Layer Colour")
        output_color_picker = gr.ColorPicker(value=output_color, label="Output Colour")
    output_plot = gr.Plot(label="MLP Model Graph")
    update_button = gr.Button("Update")

    # Components in the exact positional order update_graph expects.
    graph_inputs = [input_slider, hidden_slider, output_slider,
                    input_color_picker, hidden_color_picker, output_color_picker]
    update_button.click(fn=update_graph, inputs=graph_inputs, outputs=output_plot)
    # Draw the default network as soon as the page loads; otherwise the plot
    # stays blank until the user first presses "Update".
    demo.load(fn=update_graph, inputs=graph_inputs, outputs=output_plot)

# Run the application
demo.launch()
|