cbensimon HF staff commited on
Commit
0d0baeb
·
verified ·
1 Parent(s): 9ae28d8

Update freesplatter/webui/runner.py

Browse files
Files changed (1) hide show
  1. freesplatter/webui/runner.py +17 -16
freesplatter/webui/runner.py CHANGED
@@ -1,4 +1,3 @@
1
- import spaces
2
  import os
3
  import json
4
  import uuid
@@ -157,17 +156,14 @@ class FreeSplatterRunner:
157
  image,
158
  do_rembg=True,
159
  ):
160
- torch.cuda.empty_cache()
161
-
162
  if do_rembg:
163
  image = remove_background(image, self.rembg)
164
 
165
  return image
166
 
167
- @spaces.GPU
168
  def run_img_to_3d(
169
  self,
170
- image_rgba,
171
  model='Zero123++ v1.2',
172
  diffusion_steps=30,
173
  guidance_scale=4.0,
@@ -177,7 +173,10 @@ class FreeSplatterRunner:
177
  mesh_reduction=0.5,
178
  cache_dir=None,
179
  ):
180
- torch.cuda.empty_cache()
 
 
 
181
 
182
  self.output_dir = os.path.join(cache_dir, f'output_{uuid.uuid4()}')
183
  os.makedirs(self.output_dir, exist_ok=True)
@@ -226,6 +225,10 @@ class FreeSplatterRunner:
226
  images = images[[0, 2, 4, 5, 3, 1]]
227
  alphas = alphas[[0, 2, 4, 5, 3, 1]]
228
  images_vis = v2.functional.to_pil_image(rearrange(images, 'nm c h w -> c h (nm w)'))
 
 
 
 
229
  images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
230
  alphas = v2.functional.resize(alphas, 512, interpolation=0, antialias=True).clamp(0, 1)
231
 
@@ -237,12 +240,12 @@ class FreeSplatterRunner:
237
  images, alphas = images[view_indices], alphas[view_indices]
238
  legends = [f'V{i}' if i != 0 else 'Input' for i in view_indices]
239
 
240
- gs_vis_path, video_path, mesh_fine_path, fig = self.run_freesplatter_object(
241
- images, alphas, legends=legends, gs_type=gs_type, mesh_reduction=mesh_reduction)
 
 
242
 
243
- return images_vis, gs_vis_path, video_path, mesh_fine_path, fig
244
 
245
- @spaces.GPU
246
  def run_views_to_3d(
247
  self,
248
  image_files,
@@ -251,7 +254,6 @@ class FreeSplatterRunner:
251
  mesh_reduction=0.5,
252
  cache_dir=None,
253
  ):
254
- torch.cuda.empty_cache()
255
 
256
  self.output_dir = os.path.join(cache_dir, f'output_{uuid.uuid4()}')
257
  os.makedirs(self.output_dir, exist_ok=True)
@@ -300,7 +302,6 @@ class FreeSplatterRunner:
300
  gs_type='2DGS',
301
  mesh_reduction=0.5,
302
  ):
303
- torch.cuda.empty_cache()
304
  device = self.device
305
 
306
  freesplatter = self.freesplatter_2dgs if gs_type == '2DGS' else self.freesplatter
@@ -316,11 +317,13 @@ class FreeSplatterRunner:
316
  c2ws_pred, focals_pred = freesplatter.estimate_poses(images, gaussians, masks=alphas, use_first_focal=True, pnp_iter=10)
317
  fig = self.visualize_cameras_object(images, c2ws_pred, focals_pred, legends=legends)
318
  t2 = time.time()
 
319
 
320
  # save gaussians
321
  gs_vis_path = os.path.join(self.output_dir, 'gs_vis.ply')
322
  save_gaussian(gaussians, gs_vis_path, freesplatter, opacity_threshold=5e-3, pad_2dgs_scale=True)
323
  print(f'Save gaussian at {gs_vis_path}')
 
324
 
325
  # render video
326
  with torch.inference_mode():
@@ -339,6 +342,7 @@ class FreeSplatterRunner:
339
  save_video(video_frames, video_path, fps=30)
340
  print(f'Save video at {video_path}')
341
  t3 = time.time()
 
342
 
343
  # extract mesh
344
  with torch.inference_mode():
@@ -454,7 +458,7 @@ class FreeSplatterRunner:
454
  print(f'Generate mesh: {t4-t3:.2f} seconds.')
455
  print(f'Optimize mesh: {t5-t4:.2f} seconds.')
456
 
457
- return gs_vis_path, video_path, mesh_fine_path, fig
458
 
459
  def visualize_cameras_object(
460
  self,
@@ -494,14 +498,12 @@ class FreeSplatterRunner:
494
  return fig
495
 
496
  # FreeSplatter-S
497
- @spaces.GPU
498
  def run_views_to_scene(
499
  self,
500
  image1,
501
  image2,
502
  cache_dir=None,
503
  ):
504
- torch.cuda.empty_cache()
505
 
506
  self.output_dir = os.path.join(cache_dir, f'output_{uuid.uuid4()}')
507
  os.makedirs(self.output_dir, exist_ok=True)
@@ -531,7 +533,6 @@ class FreeSplatterRunner:
531
  images,
532
  legends=None,
533
  ):
534
- torch.cuda.empty_cache()
535
 
536
  freesplatter = self.freesplatter_scene
537
 
 
 
1
  import os
2
  import json
3
  import uuid
 
156
  image,
157
  do_rembg=True,
158
  ):
 
 
159
  if do_rembg:
160
  image = remove_background(image, self.rembg)
161
 
162
  return image
163
 
 
164
  def run_img_to_3d(
165
  self,
166
+ image,
167
  model='Zero123++ v1.2',
168
  diffusion_steps=30,
169
  guidance_scale=4.0,
 
173
  mesh_reduction=0.5,
174
  cache_dir=None,
175
  ):
176
+ image_rgba = self.run_segmentation(image)
177
+
178
+ res = [image_rgba]
179
+ yield res + [None] * (6 - len(res))
180
 
181
  self.output_dir = os.path.join(cache_dir, f'output_{uuid.uuid4()}')
182
  os.makedirs(self.output_dir, exist_ok=True)
 
225
  images = images[[0, 2, 4, 5, 3, 1]]
226
  alphas = alphas[[0, 2, 4, 5, 3, 1]]
227
  images_vis = v2.functional.to_pil_image(rearrange(images, 'nm c h w -> c h (nm w)'))
228
+
229
+ res += [images_vis]
230
+ yield res + [None] * (6 - len(res))
231
+
232
  images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
233
  alphas = v2.functional.resize(alphas, 512, interpolation=0, antialias=True).clamp(0, 1)
234
 
 
240
  images, alphas = images[view_indices], alphas[view_indices]
241
  legends = [f'V{i}' if i != 0 else 'Input' for i in view_indices]
242
 
243
+ for item in self.run_freesplatter_object(
244
+ images, alphas, legends=legends, gs_type=gs_type, mesh_reduction=mesh_reduction):
245
+ res += [item]
246
+ yield res + [None] * (6 - len(res))
247
 
 
248
 
 
249
  def run_views_to_3d(
250
  self,
251
  image_files,
 
254
  mesh_reduction=0.5,
255
  cache_dir=None,
256
  ):
 
257
 
258
  self.output_dir = os.path.join(cache_dir, f'output_{uuid.uuid4()}')
259
  os.makedirs(self.output_dir, exist_ok=True)
 
302
  gs_type='2DGS',
303
  mesh_reduction=0.5,
304
  ):
 
305
  device = self.device
306
 
307
  freesplatter = self.freesplatter_2dgs if gs_type == '2DGS' else self.freesplatter
 
317
  c2ws_pred, focals_pred = freesplatter.estimate_poses(images, gaussians, masks=alphas, use_first_focal=True, pnp_iter=10)
318
  fig = self.visualize_cameras_object(images, c2ws_pred, focals_pred, legends=legends)
319
  t2 = time.time()
320
+ yield fig
321
 
322
  # save gaussians
323
  gs_vis_path = os.path.join(self.output_dir, 'gs_vis.ply')
324
  save_gaussian(gaussians, gs_vis_path, freesplatter, opacity_threshold=5e-3, pad_2dgs_scale=True)
325
  print(f'Save gaussian at {gs_vis_path}')
326
+ yield gs_vis_path
327
 
328
  # render video
329
  with torch.inference_mode():
 
342
  save_video(video_frames, video_path, fps=30)
343
  print(f'Save video at {video_path}')
344
  t3 = time.time()
345
+ yield video_path
346
 
347
  # extract mesh
348
  with torch.inference_mode():
 
458
  print(f'Generate mesh: {t4-t3:.2f} seconds.')
459
  print(f'Optimize mesh: {t5-t4:.2f} seconds.')
460
 
461
+ yield mesh_fine_path
462
 
463
  def visualize_cameras_object(
464
  self,
 
498
  return fig
499
 
500
  # FreeSplatter-S
 
501
  def run_views_to_scene(
502
  self,
503
  image1,
504
  image2,
505
  cache_dir=None,
506
  ):
 
507
 
508
  self.output_dir = os.path.join(cache_dir, f'output_{uuid.uuid4()}')
509
  os.makedirs(self.output_dir, exist_ok=True)
 
533
  images,
534
  legends=None,
535
  ):
 
536
 
537
  freesplatter = self.freesplatter_scene
538