Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import spaces | |
import trimesh | |
import traceback | |
import numpy as np | |
import gradio as gr | |
from functools import partial | |
from multiprocessing import Process, Queue | |
import torch | |
from torch import nn | |
from transformers import ( | |
AutoTokenizer, Qwen2ForCausalLM, Qwen2Model, PreTrainedModel) | |
from transformers.modeling_outputs import CausalLMOutputWithPast | |
class FourierPointEncoder(nn.Module):
    """Embed points with Fourier (sin/cos) features of their xyz coordinates.

    The first 3 channels of each point are treated as xyz. They are kept
    as-is and also expanded into sin/cos features at 8 power-of-two
    frequencies (1, 2, ..., 128). All remaining channels are appended
    unchanged and the result is linearly projected to ``hidden_size``.
    The projection is 54 wide, i.e. 3 + 2 * 3 * 8 + 3, so exactly 3 extra
    channels per point are expected (the face normal, in this app).
    """

    def __init__(self, hidden_size):
        super().__init__()
        num_bands = 8
        bands = torch.pow(2.0, torch.arange(num_bands, dtype=torch.float32))
        # Not persistent: recomputed on construction, never stored in checkpoints.
        self.register_buffer('frequencies', bands, persistent=False)
        raw_xyz, extra = 3, 3
        in_features = raw_xyz + 2 * raw_xyz * num_bands + extra  # == 54
        self.projection = nn.Linear(in_features, hidden_size)

    def forward(self, points):
        xyz = points[..., :3]
        extra = points[..., 3:]
        # (..., 3) -> (..., 3, F) -> (..., 3 * F), coordinate-major ordering.
        angles = (xyz.unsqueeze(-1) * self.frequencies).flatten(start_dim=-2)
        features = torch.cat((xyz, angles.sin(), angles.cos(), extra), dim=-1)
        return self.projection(features)
class CADRecode(Qwen2ForCausalLM):
    """Qwen2 causal LM whose prompt starts with an embedded point cloud.

    Positions whose ``attention_mask`` entry is ``-1`` are placeholders. On
    the first forward pass (no cached keys yet) their token embeddings are
    overwritten with the output of ``FourierPointEncoder`` applied to
    ``point_cloud``, and the placeholder mask entries are set to 1.
    """

    def __init__(self, config):
        # Qwen2ForCausalLM.__init__ is bypassed on purpose; its submodules are
        # built by hand so the point encoder can live alongside them.
        PreTrainedModel.__init__(self, config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Build the point encoder in float32 whatever the current default dtype
        # is, then restore the previous default. The old code forced bfloat16
        # here, which leaked into the whole process regardless of what was
        # active before.
        previous_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float32)
        try:
            self.point_encoder = FourierPointEncoder(config.hidden_size)
        finally:
            torch.set_default_dtype(previous_dtype)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                point_cloud=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                cache_position=None):
        """Run the decoder, splicing point embeddings into the prompt on the first pass.

        Args:
            input_ids: token ids; placeholder positions may hold any id (the
                app uses the pad id) because their embeddings are replaced.
            attention_mask: 1 for real tokens, -1 for point-cloud placeholders.
                NOTE: modified in place (-1 -> 1) on the first pass.
            point_cloud: per-point features for FourierPointEncoder; the number
                of points must equal the number of -1 entries per sequence.
            labels: optional targets; if given, a next-token cross-entropy
                loss over shifted logits is computed.

        Returns:
            CausalLMOutputWithPast with float32 logits, or a plain tuple when
            ``return_dict`` resolves to False.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # concatenate point and text embeddings (only before anything is cached)
        if past_key_values is None or past_key_values.get_seq_length() == 0:
            assert inputs_embeds is None
            inputs_embeds = self.model.embed_tokens(input_ids)
            # Cast to the embedding dtype rather than hard-coding bfloat16:
            # masked assignment requires matching dtypes, so a float32 model
            # previously failed here.
            point_embeds = self.point_encoder(point_cloud).to(inputs_embeds.dtype)
            inputs_embeds[attention_mask == -1] = point_embeds.reshape(-1, point_embeds.shape[2])
            # In place on purpose: the caller's mask tensor (presumably reused
            # by generate() on later decoding steps -- confirm) must not keep
            # the -1 markers once the point embeddings are in the cache.
            attention_mask[attention_mask == -1] = 1
            input_ids = None
            position_ids = None
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position)
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Forward ``point_cloud`` from generate() kwargs into every model call."""
        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
        model_inputs['point_cloud'] = kwargs['point_cloud']
        return model_inputs
def mesh_to_point_cloud(mesh, n_points=256):
    """Sample ``n_points`` surface points of ``mesh`` with their face normals.

    Each returned row is a sampled position followed by the normal of the
    face it was drawn from. Rows are sorted by z, then y, then x.
    """
    positions, face_index = trimesh.sample.sample_surface(mesh, n_points)
    normals = mesh.face_normals[face_index]
    samples = np.hstack((np.asarray(positions), normals))
    # np.lexsort uses its LAST key row as the primary key: rows here are
    # (x, y, z), so the primary sort key is z, then y, then x.
    order = np.lexsort(samples[:, :3].T)
    return samples[order]
def py_string_to_mesh_file(py_string, mesh_path, queue):
    """Execute generated CadQuery code and export its result ``r`` as a mesh.

    Used as the target of the worker Process in py_string_to_mesh_file_safe.
    The code must leave a global ``r`` whose ``.val()`` can be tessellated.
    On success the mesh is exported to ``mesh_path`` and nothing is put on
    ``queue``; on any failure the formatted traceback is put on ``queue``.
    """
    namespace = globals()
    try:
        # SECURITY: py_string is model output executed verbatim with this
        # module's globals; process isolation plus a timeout is the only guard.
        exec(py_string, namespace)
        shape = namespace['r'].val()
        verts, tris = shape.tessellate(0.001, 0.1)
        coords = [(vert.x, vert.y, vert.z) for vert in verts]
        trimesh.Trimesh(coords, tris).export(mesh_path)
    except BaseException:
        # Deliberately broad: SystemExit etc. raised by generated code is
        # reported back to the parent as well.
        queue.put(traceback.format_exc())
def py_string_to_mesh_file_safe(py_string, mesh_path, timeout=5):
    """Run py_string_to_mesh_file in a separate process, bounded by ``timeout`` seconds.

    CadQuery code predicted by the LLM may be unsafe, hang, or leak memory,
    so it is executed in a separate Process that is terminated if it is
    still alive after ``timeout`` seconds.

    Raises:
        gr.Error: if the worker is still alive after ``timeout`` seconds, if it
            reported a traceback, or if it exited with a nonzero exit code
            without reporting one (e.g. killed by a signal).
    """
    queue = Queue()
    process = Process(
        target=py_string_to_mesh_file,
        args=(py_string, mesh_path, queue))
    process.start()
    process.join(timeout)
    if process.is_alive():
        process.terminate()
        process.join()
        # Report the timeout actually used (the old message said 3 s while
        # waiting 5 s).
        raise gr.Error(f'Process is alive after {timeout} seconds')
    if not queue.empty():
        raise gr.Error(queue.get())
    if process.exitcode != 0:
        # The worker died without reporting a traceback. mesh_path may be
        # missing or left over from an earlier call, so don't report success.
        raise gr.Error(f'Process exited with code {process.exitcode}')
def run_point_cloud(in_mesh_path, seed):
    """Load a mesh, normalize it and sample the point cloud fed to CAD-Recode.

    The mesh is translated so its bounding-box center is at the origin and
    uniformly scaled so its largest extent is 2.

    Returns:
        (point_cloud, pcd_path): the array from mesh_to_point_cloud, and the
        path of an .obj file holding only the sampled xyz positions (used for
        display).

    Raises:
        gr.Error: wrapping the traceback of any failure.
    """
    try:
        mesh = trimesh.load(in_mesh_path)
        # NOTE(review): trimesh.load may return a Scene for multi-part files,
        # which mesh_to_point_cloud cannot sample -- confirm inputs are single meshes.
        mesh.apply_translation(-(mesh.bounds[0] + mesh.bounds[1]) / 2.0)
        mesh.apply_scale(2.0 / max(mesh.extents))
        # Seeds numpy's global RNG, presumably the source of randomness for
        # trimesh's surface sampler in the installed version -- confirm.
        np.random.seed(seed)
        point_cloud = mesh_to_point_cloud(mesh)
        pcd_path = '/tmp/pcd.obj'
        trimesh.points.PointCloud(point_cloud[:, :3]).export(pcd_path)
        return point_cloud, pcd_path
    except Exception:
        # Exception (not a bare except) so KeyboardInterrupt/SystemExit still
        # propagate instead of becoming a user-facing error.
        raise gr.Error(traceback.format_exc())
def run_cad_recode(point_cloud):
    """Generate CadQuery Python code for ``point_cloud`` with the CAD-Recode model.

    The prompt is one placeholder token per point (attention mask -1, which
    CADRecode.forward replaces with point embeddings) followed by
    ``<|im_start|>``. The decoded text between ``<|im_start|>`` and
    ``<|endoftext|>`` (or the end of the output if generation stopped at
    ``max_new_tokens``) is returned.

    Returns:
        (py_string, py_string): once for the gr.State, once for the code viewer.

    Raises:
        gr.Error: wrapping the traceback of any failure.
    """
    start_marker = '<|im_start|>'
    end_marker = '<|endoftext|>'
    try:
        n_points = len(point_cloud)
        start_token_id = tokenizer(start_marker)['input_ids'][0]
        input_ids = [tokenizer.pad_token_id] * n_points + [start_token_id]
        attention_mask = [-1] * n_points + [1]
        # Previously `model` was only bound when CUDA was available, so the
        # CPU path always failed with a NameError.
        model = cad_recode.cuda() if torch.cuda.is_available() else cad_recode
        with torch.no_grad():
            batch_ids = model.generate(
                input_ids=torch.tensor(input_ids).unsqueeze(0).to(model.device),
                attention_mask=torch.tensor(attention_mask).unsqueeze(0).to(model.device),
                point_cloud=torch.tensor(point_cloud.astype(np.float32)).unsqueeze(0).to(model.device),
                max_new_tokens=768,
                pad_token_id=tokenizer.pad_token_id).cpu()
        py_string = tokenizer.batch_decode(batch_ids)[0]
        begin = py_string.find(start_marker) + len(start_marker)
        end = py_string.find(end_marker, begin)
        if end == -1:
            # No end marker (token limit reached): keep everything. Slicing to
            # -1 would silently drop the last character.
            end = len(py_string)
        py_string = py_string[begin:end]
        return py_string, py_string
    except Exception:
        raise gr.Error(traceback.format_exc())
def run_mesh(py_string):
    """Turn generated CadQuery code into a mesh file and return its path.

    Raises:
        gr.Error: the error from py_string_to_mesh_file_safe unchanged, or a
            new one wrapping the traceback of any other failure.
    """
    try:
        out_mesh_path = '/tmp/mesh.stl'
        py_string_to_mesh_file_safe(py_string, out_mesh_path)
        return out_mesh_path
    except gr.Error:
        # Already user-facing (e.g. the worker's CadQuery traceback); wrapping
        # it again buried that message inside a second, escaped traceback.
        raise
    except Exception:
        raise gr.Error(traceback.format_exc())
def run():
    """Build the CAD-Recode Gradio demo and launch it.

    The Run button chains three steps, each starting only if the previous one
    succeeded: run_point_cloud -> run_cad_recode -> run_mesh. The intermediate
    point cloud and generated code travel through a shared gr.State.
    """
    intro = ('## CAD-Recode Demo\n'
             'Upload mesh or select from examples and press Run! Mesh ⇾ 256 points ⇾ Python code by CAD-Recode ⇾ CAD model.')
    example_rows = [
        ['./data/49215_5368e45e_0000.stl', 42],
        ['./data/00882236.stl', 6],
        ['./data/User Library-engrenage.stl', 18],
        ['./data/00010900.stl', 42],
        ['./data/21492_8bd34fc1_0008.stl', 42],
        ['./data/00375556.stl', 53],  # todo: 96?
        ['./data/49121_adb01620_0000.stl', 42],
    ]
    example_names = [
        'fusion360_table1', 'deepcad_star', 'cc3d_gear', 'deepcad_barrels',
        'fusion360_gear', 'deepcad_house', 'fusion360_table2',
    ]

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown(intro)

        # Top row: input mesh, sampled points, final CAD result.
        with gr.Row(equal_height=True):
            in_model = gr.Model3D(label='1. Input Mesh', interactive=True)
            point_model = gr.Model3D(label='2. Sampled Point Cloud', display_mode='point_cloud', interactive=False)
            out_model = gr.Model3D(label='4. Result CAD Model', interactive=False)

        # Bottom row: controls, generated code, and an empty spacer column.
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    seed_slider = gr.Slider(label='Random Seed', value=42, interactive=True)
                with gr.Row():
                    gr.Examples(
                        examples=example_rows,
                        example_labels=example_names,
                        inputs=[in_model, seed_slider],
                        cache_examples=False)
                with gr.Row():
                    run_button = gr.Button('Run')
            with gr.Column():
                out_code = gr.Code(language='python', label='3. Generated Python Code', wrap_lines=True, interactive=False)
            with gr.Column():
                pass  # intentionally empty, keeps the three-column layout

        state = gr.State()

        sample_step = run_button.click(
            run_point_cloud,
            inputs=[in_model, seed_slider],
            outputs=[state, point_model])
        code_step = sample_step.success(
            run_cad_recode,
            inputs=[state],
            outputs=[state, out_code])
        code_step.success(
            run_mesh,
            inputs=[state],
            outputs=[out_model])

    demo.launch(show_error=True)
if __name__ == '__main__':
    # Guarded so that multiprocessing workers started with the spawn or
    # forkserver method (which re-import the main module) do not reload the
    # model and launch another Gradio server. The assignments below still
    # create the module-level globals used by run_cad_recode.
    # Set before the tokenizer is created; presumably meant to silence the HF
    # tokenizers warning when mesh workers are forked later.
    os.environ['TOKENIZERS_PARALLELISM'] = 'False'
    tokenizer = AutoTokenizer.from_pretrained(
        'Qwen/Qwen2-1.5B',
        pad_token='<|im_end|>',
        padding_side='left')
    cad_recode = CADRecode.from_pretrained(
        'filapro/cad-recode',
        torch_dtype='auto').eval()
    run()