brandonsmart commited on
Commit
c65a25a
1 Parent(s): e059ab1

Attempting to solve pickle issue

Browse files
Files changed (1) hide show
  1. demo.py +6 -5
demo.py CHANGED
@@ -23,8 +23,10 @@ from mast3r.utils.misc import hash_md5
23
  import main
24
  import utils.export as export
25
 
26
- @spaces.GPU(duration=10)
27
- def get_reconstructed_scene(outdir, model, device, silent, image_size, ios_mode, filelist):
 
 
28
 
29
  assert len(filelist) == 1 or len(filelist) == 2, "Please provide one or two images"
30
  if ios_mode:
@@ -37,6 +39,7 @@ def get_reconstructed_scene(outdir, model, device, silent, image_size, ios_mode,
37
  img['img'] = img['img'].to(device)
38
  img['original_img'] = img['original_img'].to(device)
39
  img['true_shape'] = torch.from_numpy(img['true_shape'])
 
40
 
41
  output = model(imgs[0], imgs[1])
42
 
@@ -50,13 +53,11 @@ if __name__ == '__main__':
50
  image_size = 512
51
  silent = False
52
  ios_mode = True
53
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
54
 
55
  model_name = "brandonsmart/splatt3r_v1.0"
56
  filename = "epoch=19-step=1200.ckpt"
57
  weights_path = hf_hub_download(repo_id=model_name, filename=filename)
58
  model = main.MAST3RGaussians.load_from_checkpoint(weights_path, 'cpu')
59
- model = model.to(device)
60
  chkpt_tag = hash_md5(weights_path)
61
 
62
  # Define example inputs and their corresponding precalculated outputs
@@ -88,7 +89,7 @@ if __name__ == '__main__':
88
  cache_path = os.path.join(tmpdirname, chkpt_tag)
89
  os.makedirs(cache_path, exist_ok=True)
90
 
91
- recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size, ios_mode)
92
 
93
  if not ios_mode:
94
  for i in range(len(examples)):
 
23
  import main
24
  import utils.export as export
25
 
26
+ @spaces.GPU(duration=15)
27
+ def get_reconstructed_scene(outdir, model, silent, image_size, ios_mode, filelist):
28
+
29
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30
 
31
  assert len(filelist) == 1 or len(filelist) == 2, "Please provide one or two images"
32
  if ios_mode:
 
39
  img['img'] = img['img'].to(device)
40
  img['original_img'] = img['original_img'].to(device)
41
  img['true_shape'] = torch.from_numpy(img['true_shape'])
42
+ model = model.to(device)
43
 
44
  output = model(imgs[0], imgs[1])
45
 
 
53
  image_size = 512
54
  silent = False
55
  ios_mode = True
 
56
 
57
  model_name = "brandonsmart/splatt3r_v1.0"
58
  filename = "epoch=19-step=1200.ckpt"
59
  weights_path = hf_hub_download(repo_id=model_name, filename=filename)
60
  model = main.MAST3RGaussians.load_from_checkpoint(weights_path, 'cpu')
 
61
  chkpt_tag = hash_md5(weights_path)
62
 
63
  # Define example inputs and their corresponding precalculated outputs
 
89
  cache_path = os.path.join(tmpdirname, chkpt_tag)
90
  os.makedirs(cache_path, exist_ok=True)
91
 
92
+ recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, silent, image_size, ios_mode)
93
 
94
  if not ios_mode:
95
  for i in range(len(examples)):