"""Gradio app that draws a fully connected 3-layer MLP with random edge weights."""

import re

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

# Initial neuron counts per layer (also the sliders' starting values).
input_size = 3
hidden_size = 4
output_size = 2

# Initial layer colours. Stored as "#rrggbb" because gr.ColorPicker is backed by
# an HTML colour input, which only accepts that format (a CSS name such as
# "lightgreen" would be displayed as black). Matplotlib accepts the same strings.
input_color = "#90ee90"  # lightgreen
hidden_color = "#87ceeb"  # skyblue
output_color = "#fa8072"  # salmon

# Graph that update_graph() clears and rebuilds on every call.
G = nx.DiGraph()

# (node-name prefix, value stored in each node's "layer" attribute), in
# drawing order from left (x = 0) to right.
_LAYERS = (("I", "input"), ("H", "hidden"), ("O", "output"))

# Matches CSS "rgb(r, g, b)" / "rgba(r, g, b, a)" colour strings.
_CSS_RGB_RE = re.compile(
    r"rgba?\(\s*([\d.]+)\s*,\s*([\d.]+)\s*,\s*([\d.]+)\s*(?:,\s*([\d.]+)\s*)?\)"
)


def _to_mpl_color(color):
    """Return *color* in a form Matplotlib accepts.

    Hex and named colours are returned unchanged. CSS ``rgb()``/``rgba()``
    strings (which some Gradio ColorPicker versions may emit — NOTE(review):
    confirm for the deployed Gradio version) are converted to an RGBA tuple in
    the 0-1 range, because Matplotlib does not parse that syntax.
    """
    match = _CSS_RGB_RE.fullmatch(color.strip()) if isinstance(color, str) else None
    if match is None:
        return color
    r, g, b, a = match.groups()
    alpha = 1.0 if a is None else min(float(a), 1.0)
    return (
        min(float(r) / 255, 1.0),
        min(float(g) / 255, 1.0),
        min(float(b) / 255, 1.0),
        alpha,
    )


def _layer_positions(prefix, size, x):
    """Return ``{node_name: (x, y)}`` for one layer, spread evenly over y in [0, 1].

    The first neuron sits at y = 1 and the last at y = 0. A single-neuron layer
    is centred at y = 0.5 rather than dividing by ``size - 1 == 0``.
    """
    if size == 1:
        return {f"{prefix}1": (x, 0.5)}
    return {f"{prefix}{i + 1}": (x, 1 - i / (size - 1)) for i in range(size)}


def update_graph(input_size, hidden_size, output_size,
                 input_color, hidden_color, output_color):
    """Rebuild the global MLP graph ``G`` and return a Matplotlib figure of it.

    Parameters
    ----------
    input_size, hidden_size, output_size : int or float
        Neuron count for each layer. Gradio sliders deliver numbers that may be
        floats, so each is converted with ``int()``.
    input_color, hidden_color, output_color : str
        Node colour for each layer (hex, CSS name, or CSS ``rgb()``/``rgba()``).

    Returns
    -------
    matplotlib.figure.Figure
        The (already closed) figure, suitable as a ``gr.Plot`` output.
    """
    sizes = (int(input_size), int(hidden_size), int(output_size))
    colors = {
        "input": _to_mpl_color(input_color),
        "hidden": _to_mpl_color(hidden_color),
        "output": _to_mpl_color(output_color),
    }
    layers = [(prefix, name, size) for (prefix, name), size in zip(_LAYERS, sizes)]

    G.clear()
    rng = np.random.default_rng()

    # All nodes first, layer by layer (input -> hidden -> output).
    for prefix, name, size in layers:
        G.add_nodes_from((f"{prefix}{i + 1}" for i in range(size)), layer=name)

    # Fully connect each layer to the next with random weights in [0, 1).
    for (src_prefix, _, src_size), (dst_prefix, _, dst_size) in zip(layers, layers[1:]):
        for i in range(src_size):
            for j in range(dst_size):
                G.add_edge(f"{src_prefix}{i + 1}", f"{dst_prefix}{j + 1}",
                           weight=rng.random())

    # Layers are placed in columns x = 0, 1, 2.
    pos = {}
    for x, (prefix, _, size) in enumerate(layers):
        pos.update(_layer_positions(prefix, size, x))

    # Draw onto an explicit figure/axes instead of pyplot's implicit
    # "current" figure, so titles and labels cannot land on another figure.
    fig, ax = plt.subplots(figsize=(10, 6))
    nx.draw(
        G, pos, ax=ax, with_labels=True, node_size=2000,
        # Colour looked up per node from its layer, independent of node order.
        node_color=[colors[G.nodes[n]["layer"]] for n in G.nodes],
        font_size=12, font_weight="bold", arrows=True,
    )
    nx.draw_networkx_edge_labels(
        G, pos, ax=ax,
        edge_labels={(u, v): f"{d['weight']:.2f}" for u, v, d in G.edges(data=True)},
    )
    ax.set_title("Visual MLP", fontsize=16)
    ax.axis("off")
    fig.tight_layout()
    # Close so pyplot does not keep a reference to every figure ever drawn;
    # the returned Figure object is still rendered by gr.Plot.
    plt.close(fig)
    return fig


# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("### MLP Model Visualizer")
    # step=1: neuron counts must be whole numbers (without it Gradio picks a
    # fractional step from the range and int() silently truncates).
    input_slider = gr.Slider(minimum=2, maximum=10, value=input_size, step=1,
                             label="Input Layer")
    hidden_slider = gr.Slider(minimum=2, maximum=10, value=hidden_size, step=1,
                              label="Hidden Layer")
    output_slider = gr.Slider(minimum=2, maximum=10, value=output_size, step=1,
                              label="Output Layer")
    with gr.Row():
        input_color_picker = gr.ColorPicker(value=input_color, label="Input Layer Colour")
        hidden_color_picker = gr.ColorPicker(value=hidden_color, label="Hidden Layer Colour")
        output_color_picker = gr.ColorPicker(value=output_color, label="Output Layer Colour")
    output_plot = gr.Plot(label="MLP Model Graph")
    update_button = gr.Button("Update")
    update_button.click(
        fn=update_graph,
        inputs=[input_slider, hidden_slider, output_slider,
                input_color_picker, hidden_color_picker, output_color_picker],
        outputs=output_plot,
    )

# Run the application only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()