filapro committed on
Commit
1090857
1 Parent(s): 608acbe

Update app.py

Files changed (1)
  1. app.py +3 -3
app.py CHANGED
@@ -174,8 +174,7 @@ def run_cad_recode(point_cloud):
     try:
         input_ids = [tokenizer.pad_token_id] * len(point_cloud) + [tokenizer('<|im_start|>')['input_ids'][0]]
         attention_mask = [-1] * len(point_cloud) + [1]
-        if torch.cuda.is_available():
-            model = cad_recode.cuda()
+        model = cad_recode.cuda()
         with torch.no_grad():
             batch_ids = cad_recode.generate(
                 input_ids=torch.tensor(input_ids).unsqueeze(0).to(model.device),
@@ -266,7 +265,8 @@ tokenizer = AutoTokenizer.from_pretrained(
     padding_side='left')
 cad_recode = CADRecode.from_pretrained(
     'filapro/cad-recode',
-    torch_dtype='auto').eval()
+    torch_dtype='auto',
+    attn_implementation='flash_attention_2').eval()
 
 os.environ['TOKENIZERS_PARALLELISM'] = 'False'
 run()
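
For context, this commit does two things: it drops the torch.cuda.is_available() guard so the model is moved to the GPU unconditionally before generation, and it loads the checkpoint with FlashAttention-2 (which requires the flash-attn package, a CUDA GPU, and a half-precision dtype). The sketch below illustrates that loading pattern with the standard transformers API only; the model id and prompt are placeholders, since CADRecode is a custom class defined inside this app.py and the real app feeds point-cloud tokens rather than text.

# Minimal sketch of the post-commit loading pattern (not the Space's actual code).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = 'Qwen/Qwen2-1.5B'  # placeholder checkpoint, not 'filapro/cad-recode'

tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,               # FlashAttention-2 needs fp16/bf16
    attn_implementation='flash_attention_2',  # needs flash-attn installed and a CUDA GPU
).eval()
model = model.cuda()  # mirrors the commit: no torch.cuda.is_available() guard

inputs = tokenizer('Hello', return_tensors='pt').to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0]))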