import os
import spaces
import trimesh
import traceback
import numpy as np
import gradio as gr
from functools import partial
from multiprocessing import Process, Queue

import torch
from torch import nn
from transformers import (
    AutoTokenizer, Qwen2ForCausalLM, Qwen2Model, PreTrainedModel)
from transformers.modeling_outputs import CausalLMOutputWithPast


class FourierPointEncoder(nn.Module):
    """Embed points with a sin/cos Fourier encoding followed by a linear layer.

    The first three channels of each point are multiplied by 8 frequencies
    (2**0 .. 2**7) and passed through sin and cos. The projection expects
    54 input features: 3 raw + 24 sin + 24 cos + the remaining channels
    (3 when the input has 6 channels, as produced by `mesh_to_point_cloud`).
    """

    def __init__(self, hidden_size):
        super().__init__()
        frequencies = 2.0 ** torch.arange(8, dtype=torch.float32)
        # Non-persistent: the buffer is rebuilt here, not stored in checkpoints.
        self.register_buffer('frequencies', frequencies, persistent=False)
        self.projection = nn.Linear(54, hidden_size)

    def forward(self, points):
        """Return the projected encoding of `points` (last dim -> hidden_size)."""
        x = points[..., :3]
        x = (x.unsqueeze(-1) * self.frequencies).view(*x.shape[:-1], -1)
        x = torch.cat((points[..., :3], x.sin(), x.cos()), dim=-1)
        x = self.projection(torch.cat((x, points[..., 3:]), dim=-1))
        return x


class CADRecode(Qwen2ForCausalLM):
    """Qwen2 causal LM whose prompt is prefixed with point-cloud embeddings.

    Prompt positions marked with ``attention_mask == -1`` are placeholders
    that `forward` overwrites with the output of `FourierPointEncoder`.
    """

    def __init__(self, config):
        PreTrainedModel.__init__(self, config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # The point encoder is created in float32. NOTE: this changes the
        # process-wide default dtype and leaves it at bfloat16 afterwards.
        torch.set_default_dtype(torch.float32)
        self.point_encoder = FourierPointEncoder(config.hidden_size)
        torch.set_default_dtype(torch.bfloat16)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                point_cloud=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                cache_position=None):
        """Run the decoder, injecting point embeddings on the first step.

        When there is no (or an empty) KV cache, token embeddings at the
        positions where ``attention_mask == -1`` are replaced by the encoded
        `point_cloud`, and those mask entries are set to 1 in place.

        Returns a `CausalLMOutputWithPast`, or a tuple when `return_dict`
        is false; `loss` is only computed when `labels` is given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # concatenate point and text embeddings
        if past_key_values is None or past_key_values.get_seq_length() == 0:
            assert inputs_embeds is None
            inputs_embeds = self.model.embed_tokens(input_ids)
            point_embeds = self.point_encoder(point_cloud).bfloat16()
            inputs_embeds[attention_mask == -1] = point_embeds.reshape(-1, point_embeds.shape[2])
            attention_mask[attention_mask == -1] = 1
            # The decoder must consume the patched embeddings, not the ids.
            input_ids = None
            position_ids = None

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position)

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Extend the parent's generation inputs with the `point_cloud` kwarg."""
        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
        model_inputs['point_cloud'] = kwargs['point_cloud']
        return model_inputs


def mesh_to_point_cloud(mesh, n_points=256):
    """Sample `n_points` surface points with their face normals from `mesh`.

    Returns an array whose rows are sampled positions concatenated with the
    normal of the sampled face, sorted lexicographically by the third, then
    second, then first column.
    """
    vertices, faces = trimesh.sample.sample_surface(mesh, n_points)
    point_cloud = np.concatenate((
        np.asarray(vertices),
        mesh.face_normals[faces]
    ), axis=1)
    # np.lexsort uses the LAST key as the primary sort key.
    ids = np.lexsort((point_cloud[:, 0], point_cloud[:, 1], point_cloud[:, 2]))
    point_cloud = point_cloud[ids]
    return point_cloud


def py_string_to_mesh_file(py_string, mesh_path, queue):
    """Execute generated code and export the resulting shape to `mesh_path`.

    The executed code is expected to define a global ``r`` exposing
    ``.val()``. Any failure is reported by putting the formatted traceback
    on `queue` instead of raising.
    """
    try:
        # SECURITY: this executes model-generated Python without sandboxing.
        # Only call it from the throwaway child process created by
        # py_string_to_mesh_file_safe.
        exec(py_string, globals())
        compound = globals()['r'].val()
        vertices, faces = compound.tessellate(0.001, 0.1)
        mesh = trimesh.Trimesh([(v.x, v.y, v.z) for v in vertices], faces)
        mesh.export(mesh_path)
    except BaseException:
        # Deliberately catch-all: this runs in a child process, and even
        # SystemExit raised by the generated code must reach the parent.
        queue.put(traceback.format_exc())


def py_string_to_mesh_file_safe(py_string, mesh_path, timeout=5):
    """Run `py_string_to_mesh_file` in a child process with a time limit.

    Args:
        py_string: Python source to execute.
        mesh_path: Destination path of the exported mesh.
        timeout: Seconds to wait before terminating the child process.

    Raises:
        gr.Error: If the child is still running after `timeout` seconds, or
            if it reported a traceback.
    """
    # CadQuery code predicted by LLM may be unsafe and cause memory leaks.
    # That's why we execute it in a separate Process with timeout.
    queue = Queue()
    process = Process(
        target=py_string_to_mesh_file,
        args=(py_string, mesh_path, queue))
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        raise gr.Error(f'Process is alive after {timeout} seconds')

    if not queue.empty():
        raise gr.Error(queue.get())


def run_point_cloud(in_mesh_path, seed):
    """Load a mesh, normalize it, and sample a seeded point cloud.

    The mesh is centered on its bounding box and scaled so that its largest
    extent equals 2. The sampled positions are also exported to
    '/tmp/pcd.obj' for display.

    Returns:
        (point_cloud, pcd_path)

    Raises:
        gr.Error: With the formatted traceback of any failure.
    """
    try:
        mesh = trimesh.load(in_mesh_path)
        mesh.apply_translation(-(mesh.bounds[0] + mesh.bounds[1]) / 2.0)
        mesh.apply_scale(2.0 / max(mesh.extents))
        np.random.seed(seed)
        point_cloud = mesh_to_point_cloud(mesh)
        pcd_path = '/tmp/pcd.obj'
        trimesh.points.PointCloud(point_cloud[:, :3]).export(pcd_path)
        return point_cloud, pcd_path
    except Exception:
        raise gr.Error(traceback.format_exc())


@spaces.GPU(duration=20)
def run_cad_recode(point_cloud):
    """Generate Python code for `point_cloud` with the CAD-Recode model.

    Returns the generated code twice (once for the state, once for the code
    widget).

    Raises:
        gr.Error: With the formatted traceback of any failure.
    """
    try:
        # One pad-token placeholder per point, then the start token. Mask
        # value -1 marks the placeholders replaced in CADRecode.forward.
        input_ids = [tokenizer.pad_token_id] * len(point_cloud) + [tokenizer('<|im_start|>')['input_ids'][0]]
        attention_mask = [-1] * len(point_cloud) + [1]

        # Fall back to the model as loaded when CUDA is unavailable;
        # previously `model` was left undefined on that path.
        model = cad_recode.cuda() if torch.cuda.is_available() else cad_recode

        with torch.no_grad():
            batch_ids = model.generate(
                input_ids=torch.tensor(input_ids).unsqueeze(0).to(model.device),
                attention_mask=torch.tensor(attention_mask).unsqueeze(0).to(model.device),
                point_cloud=torch.tensor(point_cloud.astype(np.float32)).unsqueeze(0).to(model.device),
                max_new_tokens=768,
                pad_token_id=tokenizer.pad_token_id).cpu()
        py_string = tokenizer.batch_decode(batch_ids)[0]

        start_marker = '<|im_start|>'
        end_marker = '<|endoftext|>'
        start = py_string.find(start_marker)
        begin = start + len(start_marker) if start != -1 else 0
        end = py_string.find(end_marker, begin)
        if end == -1:
            # No end token (e.g. generation stopped at max_new_tokens):
            # keep the whole tail instead of slicing off the last character.
            end = len(py_string)
        py_string = py_string[begin: end]
        return py_string, py_string
    except Exception:
        raise gr.Error(traceback.format_exc())


def run_mesh(py_string):
    """Execute the generated code and return the path of the exported STL.

    Raises:
        gr.Error: Propagated unchanged from the sandboxed execution, or
            wrapping the traceback of any other failure.
    """
    try:
        out_mesh_path = '/tmp/mesh.stl'
        py_string_to_mesh_file_safe(py_string, out_mesh_path)
        return out_mesh_path
    except gr.Error:
        # Already user-facing; do not bury it in a second traceback.
        raise
    except Exception:
        raise gr.Error(traceback.format_exc())


def run():
    """Build the Gradio UI, wire the three-stage pipeline, and launch it."""
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## CAD-Recode Demo\n'
                        'Upload mesh or select from examples and press Run! Mesh ⇾ 256 points ⇾ Python code by CAD-Recode ⇾ CAD model.')

        with gr.Row(equal_height=True):
            in_model = gr.Model3D(label='1. Input Mesh', interactive=True)
            point_model = gr.Model3D(label='2. Sampled Point Cloud', display_mode='point_cloud', interactive=False)
            out_model = gr.Model3D(
                label='4. Result CAD Model',
                interactive=False
            )

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    seed_slider = gr.Slider(label='Random Seed', value=42, interactive=True)
                with gr.Row():
                    gr.Examples(
                        examples=[
                            ['./data/49215_5368e45e_0000.stl', 42],
                            ['./data/00882236.stl', 6],
                            ['./data/User Library-engrenage.stl', 18],
                            ['./data/00010900.stl', 42],
                            ['./data/21492_8bd34fc1_0008.stl', 42],
                            ['./data/00375556.stl', 53],  # todo: 96?
                            ['./data/49121_adb01620_0000.stl', 42]],
                        example_labels=[
                            'fusion360_table1',
                            'deepcad_star',
                            'cc3d_gear',
                            'deepcad_barrels',
                            'fusion360_gear',
                            'deepcad_house',
                            'fusion360_table2'],
                        inputs=[in_model, seed_slider],
                        cache_examples=False)
                with gr.Row():
                    run_button = gr.Button('Run')

            with gr.Column():
                out_code = gr.Code(language='python', label='3. Generated Python Code', wrap_lines=True, interactive=False)

            with gr.Column():
                pass

        # `state` carries the point cloud into stage 2 and the generated
        # code into stage 3; each stage runs only if the previous succeeded.
        state = gr.State()
        run_button.click(
            run_point_cloud, inputs=[in_model, seed_slider], outputs=[state, point_model]
        ).success(
            run_cad_recode, inputs=[state], outputs=[state, out_code]
        ).success(
            run_mesh, inputs=[state], outputs=[out_model]
        )

    demo.launch(show_error=True)


tokenizer = AutoTokenizer.from_pretrained(
    'Qwen/Qwen2-1.5B',
    pad_token='<|im_end|>',
    padding_side='left')
cad_recode = CADRecode.from_pretrained(
    'filapro/cad-recode',
    torch_dtype='auto').eval()
os.environ['TOKENIZERS_PARALLELISM'] = 'False'
run()