filapro committed on
Commit f70260b
1 Parent(s): 81672ab

Create app.py

Files changed (1)
  1. app.py +261 -0
app.py ADDED
@@ -0,0 +1,261 @@
+ import os
+ import spaces
+ import trimesh
+ import traceback
+ import numpy as np
+ import gradio as gr
+ from multiprocessing import Process, Queue
+
+ import torch
+ from torch import nn
+ from transformers import (
+     AutoTokenizer, Qwen2ForCausalLM, Qwen2Model, PreTrainedModel)
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
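+ # Pipeline (as described in the demo UI below): input mesh -> 256 sampled
+ # surface points -> CAD-Recode generates CadQuery Python code -> the code is
+ # executed in a sandboxed process to rebuild the CAD model as a mesh.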
+ class FourierPointEncoder(nn.Module):
+     """Encodes 6D points (xyz + normal) into LLM hidden states via Fourier features."""
+
+     def __init__(self, hidden_size):
+         super().__init__()
+         frequencies = 2.0 ** torch.arange(8, dtype=torch.float32)
+         self.register_buffer('frequencies', frequencies, persistent=False)
+         # Input is 54-dimensional: raw xyz (3) + sin/cos of 8 frequencies per
+         # coordinate (2 * 3 * 8 = 48) + the remaining features, i.e. normals (3).
+         self.projection = nn.Linear(54, hidden_size)
+
+     def forward(self, points):
+         x = points[..., :3]
+         x = (x.unsqueeze(-1) * self.frequencies).view(*x.shape[:-1], -1)
+         x = torch.cat((points[..., :3], x.sin(), x.cos()), dim=-1)
+         x = self.projection(torch.cat((x, points[..., 3:]), dim=-1))
+         return x
+
+
+ class CADRecode(Qwen2ForCausalLM):
+     def __init__(self, config):
+         PreTrainedModel.__init__(self, config)
+         self.model = Qwen2Model(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Keep the point encoder in float32 while the rest of the model is
+         # created in bfloat16.
+         torch.set_default_dtype(torch.float32)
+         self.point_encoder = FourierPointEncoder(config.hidden_size)
+         torch.set_default_dtype(torch.bfloat16)
+
+     def forward(self,
+                 input_ids=None,
+                 attention_mask=None,
+                 point_cloud=None,
+                 position_ids=None,
+                 past_key_values=None,
+                 inputs_embeds=None,
+                 labels=None,
+                 use_cache=None,
+                 output_attentions=None,
+                 output_hidden_states=None,
+                 return_dict=None,
+                 cache_position=None):
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # On the prefill pass, concatenate point and text embeddings: positions
+         # marked with -1 in the attention mask are placeholder tokens whose
+         # embeddings are overwritten with point-cloud embeddings.
+         if past_key_values is None or past_key_values.get_seq_length() == 0:
+             assert inputs_embeds is None
+             inputs_embeds = self.model.embed_tokens(input_ids)
+             point_embeds = self.point_encoder(point_cloud).bfloat16()
+             inputs_embeds[attention_mask == -1] = point_embeds.reshape(-1, point_embeds.shape[2])
+             attention_mask[attention_mask == -1] = 1
+             input_ids = None
+             position_ids = None
+
+         # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position)
+
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)
+         logits = logits.float()
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = nn.CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Enable model parallelism
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions)
+
+     def prepare_inputs_for_generation(self, *args, **kwargs):
+         model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
+         model_inputs['point_cloud'] = kwargs['point_cloud']
+         return model_inputs
+
+
+ def mesh_to_point_cloud(mesh, n_points=256):
+     # sample_surface returns the sampled points and the index of the face each
+     # point was sampled from; the face index is used to look up the normal.
+     vertices, faces = trimesh.sample.sample_surface(mesh, n_points)
+     point_cloud = np.concatenate((
+         np.asarray(vertices),
+         mesh.face_normals[faces]
+     ), axis=1)
+     # Sort points by z, then y, then x for a deterministic ordering.
+     ids = np.lexsort((point_cloud[:, 0], point_cloud[:, 1], point_cloud[:, 2]))
+     point_cloud = point_cloud[ids]
+     return point_cloud
+
+
+ def py_string_to_mesh_file(py_string, mesh_path, queue):
+     try:
+         # Execute the generated CadQuery code; it is expected to define a
+         # workplane named `r`.
+         exec(py_string, globals())
+         compound = globals()['r'].val()
+         vertices, faces = compound.tessellate(0.001, 0.1)
+         mesh = trimesh.Trimesh([(v.x, v.y, v.z) for v in vertices], faces)
+         mesh.export(mesh_path)
+     except Exception:
+         queue.put(traceback.format_exc())
+
+
+ def py_string_to_mesh_file_safe(py_string, mesh_path):
+     # CadQuery code predicted by the LLM may be unsafe and can hang or leak
+     # memory. That's why we execute it in a separate process with a timeout.
+     queue = Queue()
+     process = Process(
+         target=py_string_to_mesh_file,
+         args=(py_string, mesh_path, queue))
+     process.start()
+     process.join(3)
+
+     if process.is_alive():
+         process.terminate()
+         process.join()
+         raise RuntimeError('Process is still alive after 3 seconds')
+
+     # The worker reports failures by putting a traceback on the queue.
+     if not queue.empty():
+         raise RuntimeError(queue.get())
+
+
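+ # Usage sketch (hypothetical code string; any snippet that defines a CadQuery
+ # workplane `r` works):
+ # py_string_to_mesh_file_safe(
+ #     'import cadquery as cq\nr = cq.Workplane().box(1, 1, 1)', '/tmp/box.stl')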
+ @spaces.GPU(duration=20)
+ def run_gpu(model, input_ids, attention_mask, point_cloud, pad_token_id):
+     # On HF Spaces the decorator requests a GPU for this call; fall back to
+     # CPU when CUDA is unavailable (e.g. when running locally).
+     if torch.cuda.is_available():
+         model = model.cuda()
+     with torch.no_grad():
+         batch_ids = model.generate(
+             input_ids=torch.tensor(input_ids).unsqueeze(0).to(model.device),
+             attention_mask=torch.tensor(attention_mask).unsqueeze(0).to(model.device),
+             point_cloud=torch.tensor(point_cloud.astype(np.float32)).unsqueeze(0).to(model.device),
+             max_new_tokens=768,
+             pad_token_id=pad_token_id).cpu()
+     return batch_ids
+
+
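+ # Prompt layout consumed above: one pad token per point (their embeddings are
+ # replaced by point embeddings via the -1 attention-mask trick in
+ # CADRecode.forward), followed by '<|im_start|>' to start code generation.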
+ def run_test(in_mesh_path, seed, results):
+     # Normalize the mesh: center at the origin and scale to fit in [-1, 1].
+     mesh = trimesh.load(in_mesh_path)
+     mesh.apply_translation(-(mesh.bounds[0] + mesh.bounds[1]) / 2.0)
+     mesh.apply_scale(2.0 / max(mesh.extents))
+     np.random.seed(seed)
+     point_cloud = mesh_to_point_cloud(mesh)
+
+     pcd_path = '/tmp/pcd.obj'
+     trimesh.points.PointCloud(point_cloud[:, :3]).export(pcd_path)
+     results.append(pcd_path)
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         'Qwen/Qwen2-1.5B',
+         pad_token='<|im_end|>',
+         padding_side='left')
+     model = CADRecode.from_pretrained(
+         'filapro/cad-recode',
+         torch_dtype='auto').eval()
+
+     # One pad token per point, then '<|im_start|>' to begin generation.
+     input_ids = [tokenizer.pad_token_id] * len(point_cloud) + [tokenizer('<|im_start|>')['input_ids'][0]]
+     attention_mask = [-1] * len(point_cloud) + [1]
+     batch_ids = run_gpu(model, input_ids, attention_mask, point_cloud, tokenizer.pad_token_id)
+     py_string = tokenizer.batch_decode(batch_ids)[0]
+     # Keep only the code between '<|im_start|>' (12 characters) and '<|endoftext|>'.
+     begin = py_string.find('<|im_start|>') + 12
+     end = py_string.find('<|endoftext|>')
+     py_string = py_string[begin:end]
+     results.append(py_string)
+
+     out_mesh_path = '/tmp/mesh.stl'
+     py_string_to_mesh_file_safe(py_string, out_mesh_path)
+     results.append(out_mesh_path)
+
+
+ def run_test_safe(in_mesh_path, seed):
+     # Collect partial results so that intermediate outputs are still shown if
+     # a later stage fails; pad to the three expected outputs plus a log.
+     results, log = list(), str()
+     try:
+         run_test(in_mesh_path, seed, results)
+     except Exception:
+         log += 'Status: FAILED\n' + traceback.format_exc()
+     return results + [None] * (3 - len(results)) + [log]
+
+
+ def run():
+     with gr.Blocks() as demo:
+         with gr.Row():
+             gr.Markdown('## CAD-Recode Demo\n'
+                         'Upload a mesh or select one of the examples, then press Run! '
+                         'Mesh ⇾ 256 points ⇾ Python code by CAD-Recode ⇾ CAD model.')
+
+         with gr.Row(equal_height=True):
+             in_model = gr.Model3D(label='1. Input Mesh', interactive=True)
+             point_model = gr.Model3D(label='2. Sampled Point Cloud', display_mode='point_cloud', interactive=False)
+             out_model = gr.Model3D(label='4. Result CAD Model', interactive=False)
+
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     seed_slider = gr.Slider(label='Random Seed', value=42, interactive=True)
+                 with gr.Row():
+                     _ = gr.Examples(
+                         examples=[
+                             ['./data/49215_5368e45e_0000.stl', 42],
+                             ['./data/00882236.stl', 6],
+                             ['./data/User Library-engrenage.stl', 18],
+                             ['./data/00010900.stl', 42],
+                             ['./data/21492_8bd34fc1_0008.stl', 42],
+                             ['./data/00375556.stl', 53],
+                             ['./data/49121_adb01620_0000.stl', 42]],
+                         example_labels=[
+                             'fusion360_table1', 'deepcad_star', 'cc3d_gear', 'deepcad_barrels',
+                             'fusion360_gear', 'deepcad_house', 'fusion360_table2'],
+                         inputs=[in_model, seed_slider],
+                         cache_examples=False)
+                 with gr.Row():
+                     run_button = gr.Button('Run')
+
+             with gr.Column():
+                 out_code = gr.Code(language='python', label='3. Generated Python Code', wrap_lines=True, interactive=False)
+
+             with gr.Column():
+                 log_textbox = gr.Textbox(label='Log', placeholder='Status: OK', interactive=False)
+
+         run_button.click(
+             run_test_safe, inputs=[in_model, seed_slider], outputs=[point_model, out_code, out_model, log_textbox])
+
+     demo.launch()
+
+
+ os.environ['TOKENIZERS_PARALLELISM'] = 'False'
+ run()