filapro commited on
Commit
1cfe6b6
1 Parent(s): 420f5d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -97
app.py CHANGED
@@ -4,6 +4,7 @@ import trimesh
4
  import traceback
5
  import numpy as np
6
  import gradio as gr
 
7
  from multiprocessing import Process, Queue
8
 
9
  import torch
@@ -148,113 +149,124 @@ def py_string_to_mesh_file_safe(py_string, mesh_path):
148
  if process.is_alive():
149
  process.terminate()
150
  process.join()
151
- raise RuntimeError('Process is alive after 3 seconds')
152
 
153
  if not queue.empty():
154
- raise RuntimeError(queue.get())
155
-
156
-
157
- def run_gpu(model, input_ids, attention_mask, point_cloud, pad_token_id):
158
- if torch.cuda.is_available():
159
- model = model.cuda()
160
- with torch.no_grad():
161
- batch_ids = model.generate(
162
- input_ids=torch.tensor(input_ids).unsqueeze(0).to(model.device),
163
- attention_mask=torch.tensor(attention_mask).unsqueeze(0).to(model.device),
164
- point_cloud=torch.tensor(point_cloud.astype(np.float32)).unsqueeze(0).to(model.device),
165
- max_new_tokens=768,
166
- pad_token_id=pad_token_id).cpu()
167
- return batch_ids
168
-
169
-
170
- def run_test(in_mesh_path, seed, results):
171
- mesh = trimesh.load(in_mesh_path)
172
- mesh.apply_translation(-(mesh.bounds[0] + mesh.bounds[1]) / 2.0)
173
- mesh.apply_scale(2.0 / max(mesh.extents))
174
- np.random.seed(seed)
175
- point_cloud = mesh_to_point_cloud(mesh)
176
-
177
- pcd_path = '/tmp/pcd.obj'
178
- trimesh.points.PointCloud(point_cloud[:, :3]).export(pcd_path)
179
- results.append(pcd_path)
180
-
181
- tokenizer = AutoTokenizer.from_pretrained(
182
- 'Qwen/Qwen2-1.5B',
183
- pad_token='<|im_end|>',
184
- padding_side='left')
185
- model = CADRecode.from_pretrained(
186
- 'filapro/cad-recode',
187
- torch_dtype='auto').eval()
188
-
189
- input_ids = [tokenizer.pad_token_id] * len(point_cloud) + [tokenizer('<|im_start|>')['input_ids'][0]]
190
- attention_mask = [-1] * len(point_cloud) + [1]
191
- batch_ids = run_gpu(model, input_ids, attention_mask, point_cloud, tokenizer.pad_token_id)
192
- py_string = tokenizer.batch_decode(batch_ids)[0]
193
- begin = py_string.find('<|im_start|>') + 12
194
- end = py_string.find('<|endoftext|>')
195
- py_string = py_string[begin: end]
196
- results.append(py_string)
197
-
198
- out_mesh_path = '/tmp/mesh.stl'
199
- py_string_to_mesh_file_safe(py_string, out_mesh_path)
200
- results.append(out_mesh_path)
201
 
202
 
203
  @spaces.GPU(duration=20)
204
- def run_test_safe(in_mesh_path, seed):
205
- results, log = list(), str()
206
  try:
207
- run_test(in_mesh_path, seed, results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  except:
209
- log += 'Status: FAILED\n' + traceback.format_exc()
210
- return results + [None] * (3 - len(results)) + [log]
211
-
212
 
213
- os.environ['TOKENIZERS_PARALLELISM'] = 'False'
214
 
 
 
 
 
 
 
 
215
 
216
- with gr.Blocks() as demo:
217
- with gr.Row():
218
- gr.Markdown('## CAD-Recode Demo\n'
219
- 'Upload mesh or select from examples and press Run! Mesh ⇾ 256 points ⇾ Python code by CAD-Recode ⇾ CAD model.')
220
 
221
- with gr.Row(equal_height=True):
222
- in_model = gr.Model3D(label='1. Input Mesh', interactive=True)
223
- point_model = gr.Model3D(label='2. Sampled Point Cloud', display_mode='point_cloud', interactive=False)
224
- out_model = gr.Model3D(
225
- label='4. Result CAD Model', interactive=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  )
227
-
228
- with gr.Row():
229
- with gr.Column():
230
- with gr.Row():
231
- seed_slider = gr.Slider(label='Random Seed', value=42, interactive=True)
232
- with gr.Row():
233
- _ = gr.Examples(
234
- examples=[
235
- ['./data/49215_5368e45e_0000.stl', 42],
236
- ['./data/00882236.stl', 6],
237
- ['./data/User Library-engrenage.stl', 18],
238
- ['./data/00010900.stl', 42],
239
- ['./data/21492_8bd34fc1_0008.stl', 42],
240
- ['./data/00375556.stl', 96],
241
- ['./data/49121_adb01620_0000.stl', 42]],
242
- example_labels=[
243
- 'fusion360_table1', 'deepcad_star', 'cc3d_gear', 'deepcad_barrels',
244
- 'fusion360_gear', 'deepcad_house', 'fusion360_table2'],
245
- inputs=[in_model, seed_slider],
246
- cache_examples=False)
247
- with gr.Row():
248
- run_button = gr.Button('Run')
249
- _ = gr.LoginButton()
250
-
251
- with gr.Column():
252
- out_code = gr.Code(language='python', label='3. Generated Python Code', wrap_lines=True, interactive=False)
253
-
254
- with gr.Column():
255
- log_textbox = gr.Textbox(label='Log', placeholder='Status: OK', interactive=False)
256
-
257
- run_button.click(
258
- run_test_safe, inputs=[in_model, seed_slider], outputs=[point_model, out_code, out_model, log_textbox])
259
 
260
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import traceback
5
  import numpy as np
6
  import gradio as gr
7
+ from functools import partial
8
  from multiprocessing import Process, Queue
9
 
10
  import torch
 
149
  if process.is_alive():
150
  process.terminate()
151
  process.join()
152
+ raise gr.Error('Process is alive after 3 seconds')
153
 
154
  if not queue.empty():
155
+ raise gr.Error(queue.get())
156
+
157
+
158
+ def run_point_cloud(in_mesh_path, seed):
159
+ try:
160
+ mesh = trimesh.load(in_mesh_path)
161
+ mesh.apply_translation(-(mesh.bounds[0] + mesh.bounds[1]) / 2.0)
162
+ mesh.apply_scale(2.0 / max(mesh.extents))
163
+ np.random.seed(seed)
164
+ point_cloud = mesh_to_point_cloud(mesh)
165
+ pcd_path = '/tmp/pcd.obj'
166
+ trimesh.points.PointCloud(point_cloud[:, :3]).export(pcd_path)
167
+ return point_cloud, pcd_path
168
+ except:
169
+ raise gr.Error(traceback.format_exc())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
 
172
  @spaces.GPU(duration=20)
173
+ def run_cad_recode(point_cloud):
 
174
  try:
175
+ input_ids = [tokenizer.pad_token_id] * len(point_cloud) + [tokenizer('<|im_start|>')['input_ids'][0]]
176
+ attention_mask = [-1] * len(point_cloud) + [1]
177
+ if torch.cuda.is_available():
178
+ model = cad_recode.cuda()
179
+ with torch.no_grad():
180
+ batch_ids = cad_recode.generate(
181
+ input_ids=torch.tensor(input_ids).unsqueeze(0).to(model.device),
182
+ attention_mask=torch.tensor(attention_mask).unsqueeze(0).to(model.device),
183
+ point_cloud=torch.tensor(point_cloud.astype(np.float32)).unsqueeze(0).to(model.device),
184
+ max_new_tokens=768,
185
+ pad_token_id=tokenizer.pad_token_id).cpu()
186
+ py_string = tokenizer.batch_decode(batch_ids)[0]
187
+ begin = py_string.find('<|im_start|>') + 12
188
+ end = py_string.find('<|endoftext|>')
189
+ py_string = py_string[begin: end]
190
+ return py_string, py_string
191
  except:
192
+ raise gr.Error(traceback.format_exc())
 
 
193
 
 
194
 
195
+ def run_mesh(py_string):
196
+ try:
197
+ out_mesh_path = '/tmp/mesh.stl'
198
+ py_string_to_mesh_file_safe(py_string, out_mesh_path)
199
+ return out_mesh_path
200
+ except:
201
+ raise gr.Error(traceback.format_exc())
202
 
 
 
 
 
203
 
204
+ def run():
205
+ with gr.Blocks() as demo:
206
+ with gr.Row():
207
+ gr.Markdown('## CAD-Recode Demo\n'
208
+ 'Upload mesh or select from examples and press Run! Mesh ⇾ 256 points ⇾ Python code by CAD-Recode ⇾ CAD model.')
209
+
210
+ with gr.Row(equal_height=True):
211
+ in_model = gr.Model3D(label='1. Input Mesh', interactive=True)
212
+ point_model = gr.Model3D(label='2. Sampled Point Cloud', display_mode='point_cloud', interactive=False)
213
+ out_model = gr.Model3D(
214
+ label='4. Result CAD Model', interactive=False
215
+ )
216
+
217
+ with gr.Row():
218
+ with gr.Column():
219
+ with gr.Row():
220
+ seed_slider = gr.Slider(label='Random Seed', value=42, interactive=True)
221
+ with gr.Row():
222
+ gr.Examples(
223
+ examples=[
224
+ ['./data/49215_5368e45e_0000.stl', 42],
225
+ ['./data/00882236.stl', 6],
226
+ ['./data/User Library-engrenage.stl', 18],
227
+ ['./data/00010900.stl', 42],
228
+ ['./data/21492_8bd34fc1_0008.stl', 42],
229
+ ['./data/00375556.stl', 53], # todo: 96?
230
+ ['./data/49121_adb01620_0000.stl', 42]],
231
+ example_labels=[
232
+ 'fusion360_table1', 'deepcad_star', 'cc3d_gear', 'deepcad_barrels',
233
+ 'fusion360_gear', 'deepcad_house', 'fusion360_table2'],
234
+ inputs=[in_model, seed_slider],
235
+ cache_examples=False)
236
+ with gr.Row():
237
+ run_button = gr.Button('Run')
238
+
239
+ with gr.Column():
240
+ out_code = gr.Code(language='python', label='3. Generated Python Code', wrap_lines=True, interactive=False)
241
+
242
+ with gr.Column():
243
+ pass
244
+
245
+ state = gr.State()
246
+ run_button.click(
247
+ run_point_cloud,
248
+ inputs=[in_model, seed_slider],
249
+ outputs=[state, point_model]
250
+ ).success(
251
+ run_cad_recode,
252
+ inputs=[state],
253
+ outputs=[state, out_code]
254
+ ).success(
255
+ run_mesh,
256
+ inputs=[state],
257
+ outputs=[out_model]
258
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
+ demo.launch(show_error=True)
261
+
262
+
263
+ tokenizer = AutoTokenizer.from_pretrained(
264
+ 'Qwen/Qwen2-1.5B',
265
+ pad_token='<|im_end|>',
266
+ padding_side='left')
267
+ cad_recode = CADRecode.from_pretrained(
268
+ 'filapro/cad-recode',
269
+ torch_dtype='auto').eval()
270
+
271
+ os.environ['TOKENIZERS_PARALLELISM'] = 'False'
272
+ run()