# cad-recode / app.py
# Source: Hugging Face Space by filapro, commit 02332b1 ("Update app.py", verified), 10.1 kB.
import os
import spaces
import trimesh
import traceback
import numpy as np
import gradio as gr
from multiprocessing import Process, Queue
import torch
from torch import nn
from transformers import (
AutoTokenizer, Qwen2ForCausalLM, Qwen2Model, PreTrainedModel)
from transformers.modeling_outputs import CausalLMOutputWithPast
class FourierPointEncoder(nn.Module):
    """Embed points with xyz Fourier features followed by a linear projection.

    Each xyz coordinate is multiplied by 8 octave frequencies (1, 2, ..., 128).
    The raw xyz, the sines and the cosines of those products, and the remaining
    per-point channels (``points[..., 3:]``) are concatenated and projected to
    ``hidden_size``.

    The projection is fixed at 3 + 3*8*2 + 3 = 54 input features. Points must
    therefore carry exactly 6 channels: xyz plus 3 more, e.g. the face normals
    produced by ``mesh_to_point_cloud``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Octave frequencies 2**0 .. 2**7. They are not persistent, so they are
        # rebuilt on construction rather than stored in the state_dict.
        self.register_buffer(
            'frequencies',
            2.0 ** torch.arange(8, dtype=torch.float32),
            persistent=False)
        self.projection = nn.Linear(54, hidden_size)

    def forward(self, points):
        xyz = points[..., :3]
        # (..., 3, 8) -> (..., 24): every coordinate at every frequency.
        phases = (xyz[..., None] * self.frequencies).flatten(start_dim=-2)
        features = torch.cat(
            (xyz, torch.sin(phases), torch.cos(phases), points[..., 3:]),
            dim=-1)
        return self.projection(features)
class CADRecode(Qwen2ForCausalLM):
    """Qwen2 causal LM whose prompt contains point-cloud embeddings.

    Prompt positions whose ``attention_mask`` value is -1 are placeholders. On
    the first (uncached) forward pass, their token embeddings are overwritten
    with the output of ``point_encoder`` before the decoder runs.
    """

    def __init__(self, config):
        # PreTrainedModel.__init__ is called directly, bypassing
        # Qwen2ForCausalLM.__init__. The decoder and LM head are then built by
        # hand, plus the extra point encoder.
        PreTrainedModel.__init__(self, config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Build the point encoder while the default dtype is float32, so its
        # parameters are created in float32. from_pretrained may still cast
        # afterwards — TODO confirm.
        torch.set_default_dtype(torch.float32)
        self.point_encoder = FourierPointEncoder(config.hidden_size)
        # NOTE(review): this changes the PROCESS-WIDE default dtype and leaves it
        # at bfloat16 after construction. Every later torch factory call without
        # an explicit dtype is affected. Presumably chosen to match the
        # checkpoint dtype — confirm.
        torch.set_default_dtype(torch.bfloat16)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                point_cloud=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                cache_position=None):
        """Run the decoder, splicing point embeddings into the prompt on the first step.

        Arguments mirror ``Qwen2ForCausalLM.forward``, plus ``point_cloud``.
        ``point_cloud`` must be at least rank 3 (the encoder output is indexed
        with ``shape[2]``); presumably (batch, n_points, 6) — TODO confirm.

        On the first step, the number of -1 entries in ``attention_mask`` must
        equal the number of encoded points across the batch. Otherwise the
        reshaped point embeddings do not fit the masked positions.

        Side effect: ``attention_mask`` is modified in place (-1 -> 1).

        Returns a ``CausalLMOutputWithPast``, or a tuple when ``return_dict`` is
        false. Logits are float32. ``loss`` is the shifted next-token
        cross-entropy when ``labels`` is given, else None.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # First step (no cached keys yet): write the point embeddings into the
        # placeholder positions of the prompt. They replace the placeholder
        # token embeddings rather than being appended.
        if past_key_values is None or past_key_values.get_seq_length() == 0:
            # Embeddings are built here from input_ids; callers must not pass their own.
            assert inputs_embeds is None
            inputs_embeds = self.model.embed_tokens(input_ids)
            # Cast to bfloat16, presumably to match the decoder's embedding dtype.
            point_embeds = self.point_encoder(point_cloud).bfloat16()
            inputs_embeds[attention_mask == -1] = point_embeds.reshape(-1, point_embeds.shape[2])
            # In place: mutates the caller's tensor. Presumably generate() keeps
            # reusing (and extending) this same mask on later steps, so the -1
            # markers must not survive — TODO confirm before changing.
            attention_mask[attention_mask == -1] = 1
            # Feed the decoder embeddings only.
            input_ids = None
            # Presumably lets Qwen2Model derive positions itself — TODO confirm.
            position_ids = None
        # outputs[0] holds the final hidden states used for the logits below.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position)
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Move labels to the logits' device (needed when the model is split
            # across devices).
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
        if not return_dict:
            # Tuple form: (loss?, logits, *remaining decoder outputs).
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Pass the ``point_cloud`` generate() kwarg through to every forward call.

        Raises KeyError if generate() was called without ``point_cloud``.
        """
        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
        model_inputs['point_cloud'] = kwargs['point_cloud']
        return model_inputs
def mesh_to_point_cloud(mesh, n_points=256):
    """Sample ``n_points`` surface points of ``mesh`` together with face normals.

    Returns a ``(n_points, 6)`` array. Columns 0-2 are the sampled xyz
    positions; columns 3-5 are the normal of the face each point was drawn
    from. Rows are sorted lexicographically by z, then y, then x, giving a
    position-based ordering.
    """
    samples, face_ids = trimesh.sample.sample_surface(mesh, n_points)
    normals = mesh.face_normals[face_ids]
    cloud = np.hstack((np.asarray(samples), normals))
    # np.lexsort uses its LAST key row as the primary key: z, then y, then x.
    order = np.lexsort(cloud[:, :3].T)
    return cloud[order]
def py_string_to_mesh_file(py_string, mesh_path, queue):
    """Execute generated CadQuery code and export its result ``r`` as a mesh file.

    Meant to be the target of a child process. ``py_string`` is executed in
    this module's global namespace and must bind a name ``r`` whose ``.val()``
    supports ``tessellate``.

    Any failure is reported by putting the formatted traceback on ``queue``
    instead of raising. This includes SystemExit or KeyboardInterrupt raised by
    the executed code. Nothing is put on ``queue`` on success.
    """
    try:
        # SECURITY: py_string is model-generated (untrusted) code and runs with
        # full interpreter privileges. Isolation relies solely on the caller
        # running this in a separate, time-limited process.
        exec(py_string, globals())
        shape = globals()['r'].val()
        points, triangles = shape.tessellate(0.001, 0.1)
        xyz = [(p.x, p.y, p.z) for p in points]
        trimesh.Trimesh(xyz, triangles).export(mesh_path)
    except BaseException:  # identical to a bare ``except:``
        queue.put(traceback.format_exc())
def py_string_to_mesh_file_safe(py_string, mesh_path, timeout=5):
    """Execute generated CadQuery code in a child process and export the mesh.

    CadQuery code predicted by the LLM may be unsafe, hang, or leak memory, so
    it runs in a separate process that is terminated after ``timeout`` seconds.

    Args:
        py_string: Python source expected to bind a CadQuery object ``r``.
        mesh_path: Destination file for the tessellated mesh.
        timeout: Seconds to wait for the child before terminating it.

    Raises:
        RuntimeError: if the child times out, reports an exception, or exits
            abnormally (e.g. killed by a signal inside native code).
    """
    queue = Queue()
    process = Process(
        target=py_string_to_mesh_file,
        args=(py_string, mesh_path, queue))
    process.start()
    process.join(timeout)
    if process.is_alive():
        process.terminate()
        process.join()
        raise RuntimeError(f'Process is alive after {timeout} seconds')
    if not queue.empty():
        raise RuntimeError(queue.get())
    # A hard crash in native code leaves the queue empty. Check the exit code
    # too instead of silently reporting success.
    if process.exitcode != 0:
        raise RuntimeError(f'Process exited with code {process.exitcode}')
@spaces.GPU(duration=20)
def run_gpu(model, input_ids, attention_mask, point_cloud, pad_token_id):
    """Generate CAD-code token ids for a single point cloud.

    Moves ``model`` to CUDA when available. Each per-sample input is wrapped
    into a batch of one on the model's device. Generation is greedy-by-default
    ``model.generate`` with up to 768 new tokens.

    Returns:
        The generated id tensor on CPU. Presumably (1, prompt + new tokens), as
        HF ``generate`` returns for decoder-only models — TODO confirm.
    """
    if torch.cuda.is_available():
        model = model.cuda()
    device = model.device
    prompt_ids = torch.tensor(input_ids).unsqueeze(0).to(device)
    prompt_mask = torch.tensor(attention_mask).unsqueeze(0).to(device)
    points = torch.tensor(point_cloud.astype(np.float32)).unsqueeze(0).to(device)
    with torch.no_grad():
        sequences = model.generate(
            input_ids=prompt_ids,
            attention_mask=prompt_mask,
            point_cloud=points,
            max_new_tokens=768,
            pad_token_id=pad_token_id)
        return sequences.cpu()
def run_test(in_mesh_path, seed, results):
    """Run the mesh -> point cloud -> CadQuery code -> CAD mesh pipeline.

    Artifacts are appended to ``results`` as soon as each is produced:
    [point cloud path, generated code, output mesh path]. A caller can
    therefore still show partial output if a later stage raises.

    Args:
        in_mesh_path: Path of the input mesh (anything ``trimesh.load`` reads).
        seed: Seed for NumPy's global RNG, which drives the surface sampling.
        results: List receiving the artifacts listed above (mutated).

    Raises:
        RuntimeError: if the generated code finished without producing a mesh
            (or any error propagated from loading, inference, or execution).
    """
    mesh = trimesh.load(in_mesh_path)
    # Center on the bounding box and scale the longest side to 2.
    mesh.apply_translation(-(mesh.bounds[0] + mesh.bounds[1]) / 2.0)
    mesh.apply_scale(2.0 / max(mesh.extents))
    # UI slider values may arrive as floats (e.g. 42.0); NumPy's legacy
    # seeding rejects non-integer seeds.
    np.random.seed(int(seed))
    point_cloud = mesh_to_point_cloud(mesh)
    pcd_path = '/tmp/pcd.obj'
    trimesh.points.PointCloud(point_cloud[:, :3]).export(pcd_path)
    results.append(pcd_path)

    tokenizer = AutoTokenizer.from_pretrained(
        'Qwen/Qwen2-1.5B',
        pad_token='<|im_end|>',
        padding_side='left')
    model = CADRecode.from_pretrained(
        'filapro/cad-recode',
        torch_dtype='auto').eval()

    # Prompt: one placeholder token per point, followed by <|im_start|>. An
    # attention-mask value of -1 marks the slots that CADRecode.forward fills
    # with point embeddings.
    start_token = '<|im_start|>'
    input_ids = [tokenizer.pad_token_id] * len(point_cloud) + [tokenizer(start_token)['input_ids'][0]]
    attention_mask = [-1] * len(point_cloud) + [1]
    batch_ids = run_gpu(model, input_ids, attention_mask, point_cloud, tokenizer.pad_token_id)

    decoded = tokenizer.batch_decode(batch_ids)[0]
    begin = decoded.find(start_token)
    begin = 0 if begin == -1 else begin + len(start_token)
    # If generation hit max_new_tokens there is no <|endoftext|>. Keep the whole
    # tail rather than slicing to -1, which would drop its last character.
    end = decoded.find('<|endoftext|>', begin)
    if end == -1:
        end = len(decoded)
    py_string = decoded[begin:end]
    results.append(py_string)

    out_mesh_path = '/tmp/mesh.stl'
    # Remove any mesh left by a previous request, so a failed run can never be
    # reported with a stale result.
    try:
        os.remove(out_mesh_path)
    except FileNotFoundError:
        pass
    py_string_to_mesh_file_safe(py_string, out_mesh_path)
    if not os.path.exists(out_mesh_path):
        raise RuntimeError('Generated code finished without writing a mesh')
    results.append(out_mesh_path)
def run_test_safe(in_mesh_path, seed):
    """UI callback: run the pipeline and always return exactly four outputs.

    Returns [point cloud path, generated code, CAD mesh path, log]. Outputs
    not produced because a stage failed are None. The log is empty on success;
    on failure it holds 'Status: FAILED' followed by the traceback.
    """
    results, log = [], ''
    try:
        run_test(in_mesh_path, seed, results)
    except Exception:  # top-level UI boundary: report the error, don't crash
        log += 'Status: FAILED\n' + traceback.format_exc()
    return results + [None] * (3 - len(results)) + [log]
def run():
    """Build the CAD-Recode Gradio interface and start the web server."""
    # (label, mesh path, seed) for every bundled example.
    example_table = [
        ('fusion360_table1', './data/49215_5368e45e_0000.stl', 42),
        ('deepcad_star', './data/00882236.stl', 6),
        ('cc3d_gear', './data/User Library-engrenage.stl', 18),
        ('deepcad_barrels', './data/00010900.stl', 42),
        ('fusion360_gear', './data/21492_8bd34fc1_0008.stl', 42),
        ('deepcad_house', './data/00375556.stl', 96),
        ('fusion360_table2', './data/49121_adb01620_0000.stl', 42),
    ]
    example_labels = [label for label, _, _ in example_table]
    example_inputs = [[path, seed] for _, path, seed in example_table]

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## CAD-Recode Demo\n'
                        'Upload mesh or select from examples and press Run! Mesh ⇾ 256 points ⇾ Python code by CAD-Recode ⇾ CAD model.')
        with gr.Row(equal_height=True):
            in_model = gr.Model3D(label='1. Input Mesh', interactive=True)
            point_model = gr.Model3D(
                label='2. Sampled Point Cloud',
                display_mode='point_cloud',
                interactive=False)
            out_model = gr.Model3D(label='4. Result CAD Model', interactive=False)
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    seed_slider = gr.Slider(label='Random Seed', value=42, interactive=True)
                with gr.Row():
                    gr.Examples(
                        examples=example_inputs,
                        example_labels=example_labels,
                        inputs=[in_model, seed_slider],
                        cache_examples=False)
                with gr.Row():
                    run_button = gr.Button('Run')
            with gr.Column():
                out_code = gr.Code(
                    language='python',
                    label='3. Generated Python Code',
                    wrap_lines=True,
                    interactive=False)
            with gr.Column():
                log_textbox = gr.Textbox(label='Log', placeholder='Status: OK', interactive=False)
        pipeline_inputs = [in_model, seed_slider]
        pipeline_outputs = [point_model, out_code, out_model, log_textbox]
        run_button.click(run_test_safe, inputs=pipeline_inputs, outputs=pipeline_outputs)
    demo.launch()
# Presumably set to avoid the tokenizers library's parallelism problems when
# this process later starts worker processes (see py_string_to_mesh_file_safe).
os.environ['TOKENIZERS_PARALLELISM'] = 'False'

if __name__ == '__main__':
    # Guard the entry point. Under the 'spawn'/'forkserver' multiprocessing
    # start methods, each CAD-execution worker re-imports this file, and an
    # unguarded run() would start another Gradio server in every worker.
    run()