Aman Sharma committed
Commit 8462fc9 · 1 Parent(s): 74ecb89

gradio demo added

app.py ADDED
@@ -0,0 +1,46 @@
+ import gradio as gr
+ from sketch_nn.designer import NeuralNetworkDesigner
+ import tempfile
+ import os
+ import cv2
+ import numpy as np
+
+ def generate_nn_code(image):
+     designer = NeuralNetworkDesigner()
+
+     temp_dir = tempfile.mkdtemp()
+     temp_path = os.path.join(temp_dir, "input_image.png")
+
+     if isinstance(image, np.ndarray):
+         # It's a numpy array (captured or uploaded image)
+         cv2.imwrite(temp_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
+     else:
+         # It's a file path (should not happen with the current setup, but just in case)
+         os.rename(image, temp_path)
+
+     output_file = os.path.join(temp_dir, "custom_nn.py")
+     designer.design_network(temp_path, output_file)
+     with open(output_file, 'r') as f:
+         code = f.read()
+     return output_file, code
+
+ # Check whether this Gradio version supports the 'sources' parameter (Gradio 4.x)
+ try:
+     image_input = gr.Image(sources=["upload", "webcam"], type="numpy", label="Upload or Capture Flowchart")
+ except TypeError:
+     # Fallback for older Gradio versions
+     image_input = gr.Image(type="numpy", label="Upload Flowchart")
+
+ iface = gr.Interface(
+     fn=generate_nn_code,
+     inputs=[image_input],
+     outputs=[
+         gr.File(label="Download Generated Code"),
+         gr.Code(language="python", label="Generated PyTorch Code")
+     ],
+     title="Sketch NN: Neural Network Designer",
+     description="Upload a flowchart image or capture one with your webcam to generate PyTorch code for your neural network architecture."
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
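Running `python app.py` launches the Gradio demo locally. For a quick smoke test without the UI, the handler can also be called directly; a minimal sketch, assuming a blank white test image (a real flowchart photo gives meaningful output):

import numpy as np
from app import generate_nn_code

# Hypothetical smoke test: push a blank RGB image through the same path the Gradio UI uses.
dummy = np.full((512, 512, 3), 255, dtype=np.uint8)
file_path, code = generate_nn_code(dummy)
print(file_path)  # path to the generated custom_nn.py
print(code)       # generated PyTorch source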
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio
+ torch
+ opencv-python
+ numpy
+ pytesseract
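Note that pytesseract is only a Python wrapper; the Tesseract OCR engine itself must be installed on the system (it is not pulled in by pip). A quick sanity check:

import pytesseract

# Raises pytesseract.TesseractNotFoundError if the Tesseract binary is not on PATH.
print(pytesseract.get_tesseract_version())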
sketch_nn/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Import the main class from the designer module
+ from .designer import NeuralNetworkDesigner
+
+ # Make utility functions directly accessible from the package root
+ from .utils import save_uploaded_file  # defined in utils.py
+
+ # Package version
+ __version__ = "0.1.0"
+
+ # Control what gets imported with "from sketch_nn import *"
+ __all__ = ['NeuralNetworkDesigner', 'save_uploaded_file']
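With these exports, downstream code can import from the package root instead of the individual submodules; for example:

# Equivalent to importing from sketch_nn.designer and sketch_nn.utils directly.
from sketch_nn import NeuralNetworkDesigner, save_uploaded_file

designer = NeuralNetworkDesigner()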
sketch_nn/designer.py ADDED
@@ -0,0 +1,166 @@
+ import cv2
+ import numpy as np
+ import pytesseract
+ import torch
+ import torch.nn as nn
+
+ class NeuralNetworkDesigner:
+     def __init__(self):
+         self.layer_maps = {}
+
+     def process_image(self, image_path):
+         # Read the image
+         image = cv2.imread(image_path)
+         gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+
+         # Threshold the image
+         _, binary = cv2.threshold(gray, 225, 255, cv2.THRESH_BINARY_INV)
+
+         # Find contours
+         contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+         # Sort contours from top to bottom
+         contours = sorted(contours, key=lambda c: cv2.boundingRect(c)[1])
+
+         for i, contour in enumerate(contours):
+             x, y, w, h = cv2.boundingRect(contour)
+             roi = gray[y:y+h, x:x+w]
+
+             # Perform OCR on the ROI
+             text = pytesseract.image_to_string(roi).strip()
+             self.parse_layer_info(i, text)
+
+     def parse_layer_info(self, layer_index, text):
+         lines = text.split('\n')
+         layer_info = {'type': 'Unknown', 'text': text}
+
+         try:
+             if 'Input' in lines[0]:
+                 layer_info['type'] = 'Input'
+                 layer_info['channels'] = int(lines[1]) if len(lines) > 1 else None
+             elif 'Conv' in lines[0]:
+                 layer_info['type'] = 'Conv2d'
+                 layer_info['out_channels'] = int(lines[1]) if len(lines) > 1 else None
+                 layer_info['kernel_size'] = int(lines[2]) if len(lines) > 2 else None
+             elif any(x in lines[0] for x in ['MaxPool', 'AvgPool']):
+                 layer_info['type'] = 'MaxPool2d' if 'Max' in lines[0] else 'AvgPool2d'
+                 layer_info['kernel_size'] = int(lines[1]) if len(lines) > 1 else None
+             elif 'Linear' in lines[0]:
+                 layer_info['type'] = 'Linear'
+                 if len(lines) > 1 and '*' in lines[1]:
+                     layer_info['in_features'] = lines[1]
+                 layer_info['out_features'] = int(lines[-1]) if lines[-1].isdigit() else None
+             elif 'BatchNorm' in lines[0]:
+                 layer_info['type'] = 'BatchNorm2d'
+                 layer_info['num_features'] = int(lines[1]) if len(lines) > 1 else None
+             elif any(x in lines[0] for x in ['ReLU', 'LeakyReLU', 'Sigmoid', 'Tanh']):
+                 layer_info['type'] = lines[0]
+             elif 'Dropout' in lines[0]:
+                 layer_info['type'] = 'Dropout'
+                 layer_info['p'] = float(lines[1]) if len(lines) > 1 else 0.5
+             elif 'Transformer' in lines[0]:
+                 layer_info['type'] = 'Transformer'
+                 layer_info['d_model'] = int(lines[1]) if len(lines) > 1 else 512
+                 layer_info['nhead'] = int(lines[2]) if len(lines) > 2 else 8
+             elif 'Attention' in lines[0]:
+                 layer_info['type'] = 'MultiheadAttention'
+                 layer_info['embed_dim'] = int(lines[1]) if len(lines) > 1 else 512
+                 layer_info['num_heads'] = int(lines[2]) if len(lines) > 2 else 8
+             elif 'LSTM' in lines[0] or 'GRU' in lines[0]:
+                 layer_info['type'] = lines[0]
+                 layer_info['hidden_size'] = int(lines[1]) if len(lines) > 1 else 256
+                 layer_info['num_layers'] = int(lines[2]) if len(lines) > 2 else 1
+         except ValueError as e:
+             print(f"Error parsing layer {layer_index}: {e}")
+
+         self.layer_maps[layer_index] = layer_info
+         print(f"Parsed layer {layer_index}: {layer_info}")  # Debug print
+
+     def generate_pytorch_code(self):
+         code = "import torch\nimport torch.nn as nn\n\n"
+         code += "class CustomNN(nn.Module):\n"
+         code += "    def __init__(self):\n"
+         code += "        super(CustomNN, self).__init__()\n"
+
+         forward_code = "    def forward(self, x):\n"
+
+         in_channels = None
+         for i, layer_info in sorted(self.layer_maps.items()):
+             if layer_info['type'] == 'Input':
+                 in_channels = layer_info.get('channels', 3)
+                 continue
+
+             if layer_info['type'] == 'Conv2d':
+                 out_channels = layer_info.get('out_channels', 64)
+                 kernel_size = layer_info.get('kernel_size', 3)
+                 code += f"        self.conv{i} = nn.Conv2d({in_channels}, {out_channels}, kernel_size={kernel_size}, padding=1)\n"
+                 forward_code += f"        x = self.conv{i}(x)\n"
+                 in_channels = out_channels
+
+             elif layer_info['type'] in ['MaxPool2d', 'AvgPool2d']:
+                 kernel_size = layer_info.get('kernel_size', 2)
+                 code += f"        self.pool{i} = nn.{layer_info['type']}(kernel_size={kernel_size})\n"
+                 forward_code += f"        x = self.pool{i}(x)\n"
+
+             elif layer_info['type'] == 'Linear':
+                 out_features = layer_info.get('out_features')
+                 if i == 1 or (i > 1 and self.layer_maps[i-1]['type'] not in ['Linear', 'Flatten']):
+                     code += f"        self.flatten = nn.Flatten()\n"
+                     forward_code += f"        x = self.flatten(x)\n"
+                     in_features = layer_info.get('in_features', 'x.shape[1]')
+                 else:
+                     in_features = self.layer_maps[i-1].get('out_features', 64)
+                 code += f"        self.fc{i} = nn.Linear({in_features}, {out_features})\n"
+                 forward_code += f"        x = self.fc{i}(x)\n"
+
+             elif layer_info['type'] == 'BatchNorm2d':
+                 num_features = layer_info.get('num_features', in_channels)
+                 code += f"        self.bn{i} = nn.BatchNorm2d({num_features})\n"
+                 forward_code += f"        x = self.bn{i}(x)\n"
+
+             elif layer_info['type'] in ['ReLU', 'LeakyReLU', 'Sigmoid', 'Tanh']:
+                 code += f"        self.act{i} = nn.{layer_info['type']}()\n"
+                 forward_code += f"        x = self.act{i}(x)\n"
+
+             elif layer_info['type'] == 'Dropout':
+                 p = layer_info.get('p', 0.5)
+                 code += f"        self.dropout{i} = nn.Dropout(p={p})\n"
+                 forward_code += f"        x = self.dropout{i}(x)\n"
+
+             elif layer_info['type'] == 'Transformer':
+                 d_model = layer_info.get('d_model', 512)
+                 nhead = layer_info.get('nhead', 8)
+                 code += f"        self.transformer{i} = nn.Transformer(d_model={d_model}, nhead={nhead})\n"
+                 forward_code += f"        x = self.transformer{i}(x)\n"
+
+             elif layer_info['type'] == 'MultiheadAttention':
+                 embed_dim = layer_info.get('embed_dim', 512)
+                 num_heads = layer_info.get('num_heads', 8)
+                 code += f"        self.attention{i} = nn.MultiheadAttention(embed_dim={embed_dim}, num_heads={num_heads})\n"
+                 forward_code += f"        x, _ = self.attention{i}(x, x, x)\n"
+
+             elif layer_info['type'] in ['LSTM', 'GRU']:
+                 hidden_size = layer_info.get('hidden_size', 256)
+                 num_layers = layer_info.get('num_layers', 1)
+                 code += f"        self.rnn{i} = nn.{layer_info['type']}(input_size={in_channels}, hidden_size={hidden_size}, num_layers={num_layers}, batch_first=True)\n"
+                 forward_code += f"        x, _ = self.rnn{i}(x)\n"
+
+             elif layer_info['type'] == 'Unknown':
+                 print(f"Warning: Unknown layer type at index {i}. Layer info: {layer_info}")
+
+         code += "\n" + forward_code
+         code += "        return x\n"
+
+         return code
+
+     def write_to_file(self, code, filename):
+         with open(filename, 'w') as f:
+             f.write(code)
+
+     def design_network(self, image_path, output_file):
+         self.process_image(image_path)
+         pytorch_code = self.generate_pytorch_code()
+         self.write_to_file(pytorch_code, output_file)
+         # print(f"Neural network code has been generated and saved to '{output_file}'")
+         # print("\nGenerated PyTorch Code:")
+         # print(pytorch_code)
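A typical end-to-end call, assuming Tesseract is installed and a hand-drawn flowchart saved as flowchart.png (hypothetical filename):

from sketch_nn.designer import NeuralNetworkDesigner

designer = NeuralNetworkDesigner()
# Reads the flowchart, OCRs each detected box top to bottom, and writes the model code.
designer.design_network("flowchart.png", "custom_nn.py")

# The generated file defines CustomNN(nn.Module); if it is on the import path it can be used directly.
from custom_nn import CustomNN
model = CustomNN()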
sketch_nn/utils.py ADDED
@@ -0,0 +1,8 @@
+ import os
+ import tempfile
+
+ def save_uploaded_file(uploaded_file):
+     temp_dir = tempfile.mkdtemp()
+     temp_path = os.path.join(temp_dir, "uploaded_image.png")
+     uploaded_file.save(temp_path)
+     return temp_path
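save_uploaded_file expects an object exposing a .save(path) method, as web-framework upload wrappers typically do. A minimal sketch with a hypothetical stand-in upload object:

from sketch_nn.utils import save_uploaded_file

class FakeUpload:
    """Stand-in for an upload object that implements .save(path)."""
    def __init__(self, data: bytes):
        self._data = data

    def save(self, path):
        with open(path, "wb") as f:
            f.write(self._data)

path = save_uploaded_file(FakeUpload(b"fake image bytes"))
print(path)  # .../uploaded_image.png inside a fresh temporary directory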