zouzx commited on
Commit
cb6c3df
1 Parent(s): 436ca04

update save path

Browse files
Files changed (1) hide show
  1. app.py +26 -16
app.py CHANGED
@@ -27,6 +27,8 @@ MODEL_CKPT_PATH = "code/checkpoints/tgs_lvis_100v_rel.ckpt"
27
  CONFIG = "code/configs/single-rel.yaml"
28
  EXP_ROOT_DIR = "./outputs-gradio"
29
 
 
 
30
  gpu = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
31
  device = "cuda:{}".format(gpu) if torch.cuda.is_available() else "cpu"
32
 
@@ -52,7 +54,11 @@ HEADER = """
52
 
53
  TGS enables fast reconstruction from single-view image in a few seconds based on a hybrid Triplane-Gaussian 3D representation.
54
 
55
- This model is trained on Objaverse-LVIS (~40K synthetic objects) only. And note that we normalize the input camera pose to a pre-set viewpoint during training stage following LRM, rather than directly using camera pose of input camera as implemented in our original paper.
 
 
 
 
56
  """
57
 
58
  def preprocess(image_path, save_path=None, lower_contrast=False):
@@ -71,41 +77,43 @@ def preprocess(image_path, save_path=None, lower_contrast=False):
71
  return save_path
72
 
73
  def init_trial_dir():
74
- if not os.path.exists(EXP_ROOT_DIR):
75
- os.makedirs(EXP_ROOT_DIR, exist_ok=True)
76
  trial_dir = tempfile.TemporaryDirectory(dir=EXP_ROOT_DIR).name
77
- system.set_save_dir(trial_dir)
78
  return trial_dir
79
 
80
  @torch.no_grad()
81
  def infer(image_path: str,
82
  cam_dist: float,
 
83
  only_3dgs: bool = False):
84
  data_cfg = deepcopy(base_cfg.data)
85
  data_cfg.only_3dgs = only_3dgs
86
  data_cfg.cond_camera_distance = cam_dist
 
87
  data_cfg.image_list = [image_path]
88
  dm = tgs.find(base_cfg.data_cls)(data_cfg)
89
 
90
  dm.setup()
91
  for batch_idx, batch in enumerate(dm.test_dataloader()):
92
  batch = todevice(batch, device)
93
- system.test_step(batch, batch_idx, save_3dgs=only_3dgs)
94
  if not only_3dgs:
95
- system.on_test_epoch_end()
96
 
97
  def run(image_path: str,
98
- cam_dist: float):
99
- infer(image_path, cam_dist, only_3dgs=True)
100
- save_path = system.get_save_dir()
101
  gs = glob.glob(os.path.join(save_path, "*.ply"))[0]
 
102
  return gs
103
 
104
  def run_video(image_path: str,
105
- cam_dist: float):
106
- infer(image_path, cam_dist)
107
- save_path = system.get_save_dir()
108
  video = glob.glob(os.path.join(save_path, "*.mp4"))[0]
 
109
  return video
110
 
111
  def launch(port):
@@ -118,8 +126,8 @@ def launch(port):
118
  with gr.Row(variant='panel'):
119
  with gr.Column(scale=1):
120
  input_image = gr.Image(value=None, width=512, height=512, type="filepath", label="Input Image")
121
- camera_dist_slider = gr.Slider(1.0, 4.0, value=1.6, step=0.1, label="Camera Distance")
122
- img_run_btn = gr.Button("Reconstruction")
123
 
124
  gr.Examples(
125
  examples=[
@@ -148,6 +156,7 @@ def launch(port):
148
  output_video = gr.Video(value=None, width="auto", label="Rendered Video", autoplay=True)
149
  output_3dgs = Model3DGS(value=None, label="3D Model")
150
 
 
151
  img_run_btn.click(
152
  fn=preprocess,
153
  inputs=[input_image],
@@ -155,13 +164,14 @@ def launch(port):
155
  concurrency_limit=1,
156
  ).success(
157
  fn=init_trial_dir,
 
158
  concurrency_limit=1,
159
  ).success(fn=run,
160
- inputs=[seg_image, camera_dist_slider],
161
  outputs=[output_3dgs],
162
  concurrency_limit=1
163
  ).success(fn=run_video,
164
- inputs=[seg_image, camera_dist_slider],
165
  outputs=[output_video],
166
  concurrency_limit=1)
167
 
 
27
  CONFIG = "code/configs/single-rel.yaml"
28
  EXP_ROOT_DIR = "./outputs-gradio"
29
 
30
+ os.makedirs(EXP_ROOT_DIR, exist_ok=True)
31
+
32
  gpu = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
33
  device = "cuda:{}".format(gpu) if torch.cuda.is_available() else "cpu"
34
 
 
54
 
55
  TGS enables fast reconstruction from single-view image in a few seconds based on a hybrid Triplane-Gaussian 3D representation.
56
 
57
+ This model is trained on Objaverse-LVIS (**~40K** synthetic objects) only. And note that we normalize the input camera pose to a pre-set viewpoint during training stage following LRM, rather than directly using camera pose of input camera as implemented in our original paper.
58
+
59
+ **Tips:**
60
+ 1. If you find the result is unsatisfied, please try to change the camera distance. It perhaps improves the results.
61
+ 2. Please wait until the completion of the reconstruction of the previous model before proceeding with the next one, otherwise, it may cause bug. We will fix it soon.
62
  """
63
 
64
  def preprocess(image_path, save_path=None, lower_contrast=False):
 
77
  return save_path
78
 
79
  def init_trial_dir():
 
 
80
  trial_dir = tempfile.TemporaryDirectory(dir=EXP_ROOT_DIR).name
81
+ os.makedirs(trial_dir, exist_ok=True)
82
  return trial_dir
83
 
84
  @torch.no_grad()
85
  def infer(image_path: str,
86
  cam_dist: float,
87
+ save_path: str,
88
  only_3dgs: bool = False):
89
  data_cfg = deepcopy(base_cfg.data)
90
  data_cfg.only_3dgs = only_3dgs
91
  data_cfg.cond_camera_distance = cam_dist
92
+ data_cfg.eval_camera_distance = cam_dist
93
  data_cfg.image_list = [image_path]
94
  dm = tgs.find(base_cfg.data_cls)(data_cfg)
95
 
96
  dm.setup()
97
  for batch_idx, batch in enumerate(dm.test_dataloader()):
98
  batch = todevice(batch, device)
99
+ system.test_step(save_path, batch, batch_idx, save_3dgs=only_3dgs)
100
  if not only_3dgs:
101
+ system.on_test_epoch_end(save_path)
102
 
103
  def run(image_path: str,
104
+ cam_dist: float,
105
+ save_path: str):
106
+ infer(image_path, cam_dist, save_path, only_3dgs=True)
107
  gs = glob.glob(os.path.join(save_path, "*.ply"))[0]
108
+ # print("save gs", gs)
109
  return gs
110
 
111
  def run_video(image_path: str,
112
+ cam_dist: float,
113
+ save_path: str):
114
+ infer(image_path, cam_dist, save_path)
115
  video = glob.glob(os.path.join(save_path, "*.mp4"))[0]
116
+ # print("save video", video)
117
  return video
118
 
119
  def launch(port):
 
126
  with gr.Row(variant='panel'):
127
  with gr.Column(scale=1):
128
  input_image = gr.Image(value=None, width=512, height=512, type="filepath", label="Input Image")
129
+ camera_dist_slider = gr.Slider(1.0, 4.0, value=1.9, step=0.1, label="Camera Distance")
130
+ img_run_btn = gr.Button("Reconstruction", variant="primary")
131
 
132
  gr.Examples(
133
  examples=[
 
156
  output_video = gr.Video(value=None, width="auto", label="Rendered Video", autoplay=True)
157
  output_3dgs = Model3DGS(value=None, label="3D Model")
158
 
159
+ trial_dir = gr.State()
160
  img_run_btn.click(
161
  fn=preprocess,
162
  inputs=[input_image],
 
164
  concurrency_limit=1,
165
  ).success(
166
  fn=init_trial_dir,
167
+ outputs=[trial_dir],
168
  concurrency_limit=1,
169
  ).success(fn=run,
170
+ inputs=[seg_image, camera_dist_slider, trial_dir],
171
  outputs=[output_3dgs],
172
  concurrency_limit=1
173
  ).success(fn=run_video,
174
+ inputs=[seg_image, camera_dist_slider, trial_dir],
175
  outputs=[output_video],
176
  concurrency_limit=1)
177