import os
# NOTE(review): `spaces` is kept ahead of torch as in the original ordering;
# ZeroGPU Spaces are believed to require this — confirm before reordering.
import spaces
import trimesh
import traceback
import numpy as np
import gradio as gr
from multiprocessing import Process, Queue

import torch
from torch import nn
from transformers import (
    AutoTokenizer, Qwen2ForCausalLM, Qwen2Model, PreTrainedModel)
from transformers.modeling_outputs import CausalLMOutputWithPast


class FourierPointEncoder(nn.Module):
    """Embed per-point features with Fourier features and project them.

    The projection takes 54 input features:
    - the raw xyz (3);
    - sin and cos of xyz times 8 power-of-two frequencies (2 * 3 * 8 = 48);
    - the remaining per-point channels ``points[..., 3:]`` (3 — the face
      normals produced by ``mesh_to_point_cloud``).
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Frequencies 1, 2, 4, ..., 128; non-persistent, so not in checkpoints.
        frequencies = 2.0 ** torch.arange(8, dtype=torch.float32)
        self.register_buffer('frequencies', frequencies, persistent=False)
        self.projection = nn.Linear(54, hidden_size)

    def forward(self, points):
        x = points[..., :3]
        # (..., 3) -> (..., 3, 8) -> (..., 24)
        x = (x.unsqueeze(-1) * self.frequencies).view(*x.shape[:-1], -1)
        x = torch.cat((points[..., :3], x.sin(), x.cos()), dim=-1)
        x = self.projection(torch.cat((x, points[..., 3:]), dim=-1))
        return x


class CADRecode(Qwen2ForCausalLM):
    """Qwen2 causal LM whose prompt is prefixed with point-cloud embeddings.

    Prompt positions whose ``attention_mask`` value is -1 are placeholders.
    On the first (uncached) forward pass their token embeddings are
    overwritten with ``FourierPointEncoder`` outputs, and the mask entries
    are set to 1.
    """

    def __init__(self, config):
        # Submodules are built by hand (instead of Qwen2ForCausalLM.__init__)
        # so that the point encoder can be created in float32.
        PreTrainedModel.__init__(self, config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)
        # Switch the global default dtype only while building the encoder,
        # then restore whatever was active so construction leaves no global
        # side effect behind.
        previous_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float32)
        try:
            self.point_encoder = FourierPointEncoder(config.hidden_size)
        finally:
            torch.set_default_dtype(previous_dtype)

    def forward(self, input_ids=None, attention_mask=None, point_cloud=None,
                position_ids=None, past_key_values=None, inputs_embeds=None,
                labels=None, use_cache=None, output_attentions=None,
                output_hidden_states=None, return_dict=None,
                cache_position=None):
        output_attentions = (
            output_attentions if output_attentions is not None
            else self.config.output_attentions)
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None
            else self.config.output_hidden_states)
        return_dict = (
            return_dict if return_dict is not None
            else self.config.use_return_dict)

        # On the first pass (empty cache), concatenate point and text
        # embeddings by filling the -1 placeholder positions.
        if past_key_values is None or past_key_values.get_seq_length() == 0:
            assert inputs_embeds is None
            inputs_embeds = self.model.embed_tokens(input_ids)
            # Cast to the token-embedding dtype so the masked assignment
            # works whatever dtype the model was loaded in.
            point_embeds = self.point_encoder(point_cloud).to(
                inputs_embeds.dtype)
            inputs_embeds[attention_mask == -1] = point_embeds.reshape(
                -1, point_embeds.shape[2])
            # Modified in place on purpose. NOTE(review): presumably so that
            # the mask reused by generate() on later steps no longer contains
            # -1 — confirm before making this out-of-place.
            attention_mask[attention_mask == -1] = 1
            input_ids = None
            # Upstream position ids may have been derived from the -1 mask;
            # drop them so the decoder computes its own.
            position_ids = None

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position)

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens.
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Forward ``point_cloud`` (required kwarg) to every generation step."""
        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
        model_inputs['point_cloud'] = kwargs['point_cloud']
        return model_inputs


def mesh_to_point_cloud(mesh, n_points=256):
    """Sample ``n_points`` surface points with their face normals.

    Returns an array of shape ``(n_points, 6)``: xyz followed by the normal
    of the face each point was sampled from. Rows are sorted by z, then y,
    then x (``np.lexsort`` uses its last key as the primary key).
    Sampling uses NumPy's global RNG, so seed it beforehand for
    reproducibility.
    """
    vertices, faces = trimesh.sample.sample_surface(mesh, n_points)
    point_cloud = np.concatenate((
        np.asarray(vertices),
        mesh.face_normals[faces]
    ), axis=1)
    ids = np.lexsort((point_cloud[:, 0], point_cloud[:, 1], point_cloud[:, 2]))
    point_cloud = point_cloud[ids]
    return point_cloud


def py_string_to_mesh_file(py_string, mesh_path, queue):
    """Execute CadQuery code and export its ``r`` result as a mesh file.

    Meant to run in a child process (see ``py_string_to_mesh_file_safe``).
    On an exception the formatted traceback is put on ``queue``; on success
    nothing is put on it and the mesh is written to ``mesh_path``.
    """
    try:
        # SECURITY: ``py_string`` is model-generated and executed verbatim.
        # The separate process + timeout only contains crashes and hangs; it
        # is NOT a sandbox against malicious code.
        exec(py_string, globals())
        compound = globals()['r'].val()
        vertices, faces = compound.tessellate(0.001, 0.1)
        mesh = trimesh.Trimesh([(v.x, v.y, v.z) for v in vertices], faces)
        mesh.export(mesh_path)
    except Exception:
        queue.put(traceback.format_exc())


def py_string_to_mesh_file_safe(py_string, mesh_path, timeout=5):
    """Run ``py_string_to_mesh_file`` in a separate process with a timeout.

    Raises RuntimeError if any of these hold:
    - the child is still alive after ``timeout`` seconds;
    - the child reported a traceback;
    - the child exited with a non-zero code (e.g. a native crash);
    - the child finished without writing ``mesh_path``.
    """
    # CadQuery code predicted by the LLM may be unsafe and cause memory leaks.
    # That's why we execute it in a separate Process with a timeout.
    # Remove any previous output first so a silent failure can never return
    # a stale mesh from an earlier run.
    try:
        os.remove(mesh_path)
    except FileNotFoundError:
        pass
    queue = Queue()
    process = Process(
        target=py_string_to_mesh_file,
        args=(py_string, mesh_path, queue))
    process.start()
    process.join(timeout)
    if process.is_alive():
        process.terminate()
        process.join()
        raise RuntimeError(f'Process is alive after {timeout} seconds')
    if not queue.empty():
        raise RuntimeError(queue.get())
    if process.exitcode != 0:
        raise RuntimeError(f'Process exited with code {process.exitcode}')
    if not os.path.exists(mesh_path):
        raise RuntimeError(f'Process finished without writing {mesh_path}')


@spaces.GPU(duration=20)
def run_gpu(model, input_ids, attention_mask, point_cloud, pad_token_id):
    """Run ``model.generate`` on a single prompt and return the ids on CPU.

    The inputs are batched with ``unsqueeze(0)`` (batch size 1), and
    generation is capped at 768 new tokens. NOTE(review): the returned ids
    presumably include the prompt tokens; ``run_test`` relies on that to
    locate ``<|im_start|>``.
    """
    if torch.cuda.is_available():
        model = model.cuda()
    with torch.no_grad():
        batch_ids = model.generate(
            input_ids=torch.tensor(input_ids).unsqueeze(0).to(model.device),
            attention_mask=torch.tensor(
                attention_mask).unsqueeze(0).to(model.device),
            point_cloud=torch.tensor(
                point_cloud.astype(np.float32)).unsqueeze(0).to(model.device),
            max_new_tokens=768,
            pad_token_id=pad_token_id).cpu()
    return batch_ids


def _extract_generated_code(decoded, start_tag='<|im_start|>',
                            end_tag='<|endoftext|>'):
    """Return the text between ``start_tag`` and ``end_tag`` in ``decoded``.

    A missing ``end_tag`` (generation truncated by the token limit) keeps
    everything up to the end of the string. A missing ``start_tag`` starts
    from the beginning.
    """
    begin = decoded.find(start_tag)
    begin = 0 if begin == -1 else begin + len(start_tag)
    end = decoded.find(end_tag, begin)
    if end == -1:
        end = len(decoded)
    return decoded[begin:end]


def run_test(in_mesh_path, seed, results):
    """Mesh -> point cloud -> generated CadQuery code -> CAD mesh.

    Appends to ``results``, as each stage completes:
    1. the point-cloud file path;
    2. the generated code;
    3. the result mesh path.

    A caller can therefore still show partial output when a later stage
    raises.
    """
    if in_mesh_path is None:
        raise ValueError('No input mesh provided')
    # force='mesh' merges multi-geometry inputs (e.g. GLB scenes) into one
    # Trimesh so that surface sampling works.
    mesh = trimesh.load(in_mesh_path, force='mesh')
    if mesh.is_empty:
        raise ValueError('Input mesh is empty')
    # Center the bounding box at the origin; scale its longest side to 2.
    mesh.apply_translation(-(mesh.bounds[0] + mesh.bounds[1]) / 2.0)
    mesh.apply_scale(2.0 / max(mesh.extents))
    np.random.seed(int(seed))
    point_cloud = mesh_to_point_cloud(mesh)

    pcd_path = '/tmp/pcd.obj'
    trimesh.points.PointCloud(point_cloud[:, :3]).export(pcd_path)
    results.append(pcd_path)

    # NOTE(review): the tokenizer and model are reloaded on every request.
    # Caching them would be faster; check how @spaces.GPU treats a model
    # reused across calls before changing this.
    tokenizer = AutoTokenizer.from_pretrained(
        'Qwen/Qwen2-1.5B',
        pad_token='<|im_end|>',
        padding_side='left')
    model = CADRecode.from_pretrained(
        'filapro/cad-recode',
        torch_dtype='auto').eval()

    # Prompt: one pad-token placeholder per point (mask value -1, replaced by
    # point embeddings in CADRecode.forward) followed by <|im_start|>.
    input_ids = ([tokenizer.pad_token_id] * len(point_cloud)
                 + [tokenizer('<|im_start|>')['input_ids'][0]])
    attention_mask = [-1] * len(point_cloud) + [1]
    batch_ids = run_gpu(model, input_ids, attention_mask, point_cloud,
                        tokenizer.pad_token_id)
    py_string = _extract_generated_code(tokenizer.batch_decode(batch_ids)[0])
    results.append(py_string)

    out_mesh_path = '/tmp/mesh.stl'
    py_string_to_mesh_file_safe(py_string, out_mesh_path)
    results.append(out_mesh_path)


def run_test_safe(in_mesh_path, seed):
    """Call ``run_test``, turning any exception into a log message.

    Returns ``[pcd_path, code, mesh_path, log]``. Stages that did not
    complete are ``None``; the log is empty on success.
    """
    results, log = list(), str()
    try:
        run_test(in_mesh_path, seed, results)
    except Exception:
        log += 'Status: FAILED\n' + traceback.format_exc()
    return results + [None] * (3 - len(results)) + [log]


def run():
    """Build and launch the Gradio demo UI."""
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## CAD-Recode Demo\n'
                        'Upload mesh or select from examples and press Run! Mesh ⇾ 256 points ⇾ Python code by CAD-Recode ⇾ CAD model.')

        with gr.Row(equal_height=True):
            in_model = gr.Model3D(label='1. Input Mesh', interactive=True)
            point_model = gr.Model3D(
                label='2. Sampled Point Cloud',
                display_mode='point_cloud',
                interactive=False)
            out_model = gr.Model3D(
                label='4. Result CAD Model',
                interactive=False
            )

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    seed_slider = gr.Slider(
                        label='Random Seed', value=42, interactive=True)
                with gr.Row():
                    _ = gr.Examples(
                        examples=[
                            ['./data/49215_5368e45e_0000.stl', 42],
                            ['./data/00882236.stl', 6],
                            ['./data/User Library-engrenage.stl', 18],
                            ['./data/00010900.stl', 42],
                            ['./data/21492_8bd34fc1_0008.stl', 42],
                            ['./data/00375556.stl', 96],
                            ['./data/49121_adb01620_0000.stl', 42]],
                        example_labels=[
                            'fusion360_table1',
                            'deepcad_star',
                            'cc3d_gear',
                            'deepcad_barrels',
                            'fusion360_gear',
                            'deepcad_house',
                            'fusion360_table2'],
                        inputs=[in_model, seed_slider],
                        cache_examples=False)
                with gr.Row():
                    run_button = gr.Button('Run')

            with gr.Column():
                out_code = gr.Code(
                    language='python',
                    label='3. Generated Python Code',
                    wrap_lines=True,
                    interactive=False)

            with gr.Column():
                log_textbox = gr.Textbox(
                    label='Log', placeholder='Status: OK', interactive=False)

        run_button.click(
            run_test_safe,
            inputs=[in_model, seed_slider],
            outputs=[point_model, out_code, out_model, log_textbox])

    demo.launch()


os.environ['TOKENIZERS_PARALLELISM'] = 'False'

# Guarded so that a child process started with the "spawn" method (which
# re-imports the main module) does not relaunch the web server.
if __name__ == '__main__':
    run()