Spaces:
Runtime error
Runtime error
lemonaddie
commited on
Commit
•
9ba96aa
1
Parent(s):
f12775a
Update models/depth_normal_pipeline_clip.py
Browse files
models/depth_normal_pipeline_clip.py
CHANGED
@@ -79,6 +79,7 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
|
|
79 |
match_input_res:bool =True,
|
80 |
batch_size:int = 0,
|
81 |
domain: str = "indoor",
|
|
|
82 |
color_map: str="Spectral",
|
83 |
show_progress_bar:bool = True,
|
84 |
ensemble_kwargs: Dict = None,
|
@@ -147,6 +148,7 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
|
|
147 |
input_rgb=batched_image,
|
148 |
num_inference_steps=denoising_steps,
|
149 |
domain=domain,
|
|
|
150 |
show_pbar=show_progress_bar,
|
151 |
)
|
152 |
depth_pred_ls.append(depth_pred_raw.detach().clone())
|
@@ -230,6 +232,7 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
|
|
230 |
def single_infer(self,input_rgb:torch.Tensor,
|
231 |
num_inference_steps:int,
|
232 |
domain:str,
|
|
|
233 |
show_pbar:bool,):
|
234 |
|
235 |
device = input_rgb.device
|
@@ -242,6 +245,8 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
|
|
242 |
rgb_latent = self.encode_RGB(input_rgb)
|
243 |
|
244 |
# Initial depth map (Guassian noise)
|
|
|
|
|
245 |
geo_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype).repeat(2,1,1,1)
|
246 |
rgb_latent = rgb_latent.repeat(2,1,1,1)
|
247 |
|
|
|
79 |
match_input_res:bool =True,
|
80 |
batch_size:int = 0,
|
81 |
domain: str = "indoor",
|
82 |
+
seed: int = 0,
|
83 |
color_map: str="Spectral",
|
84 |
show_progress_bar:bool = True,
|
85 |
ensemble_kwargs: Dict = None,
|
|
|
148 |
input_rgb=batched_image,
|
149 |
num_inference_steps=denoising_steps,
|
150 |
domain=domain,
|
151 |
+
seed=seed,
|
152 |
show_pbar=show_progress_bar,
|
153 |
)
|
154 |
depth_pred_ls.append(depth_pred_raw.detach().clone())
|
|
|
232 |
def single_infer(self,input_rgb:torch.Tensor,
|
233 |
num_inference_steps:int,
|
234 |
domain:str,
|
235 |
+
seed: int,
|
236 |
show_pbar:bool,):
|
237 |
|
238 |
device = input_rgb.device
|
|
|
245 |
rgb_latent = self.encode_RGB(input_rgb)
|
246 |
|
247 |
# Initial depth map (Guassian noise)
|
248 |
+
if seed >= 0:
|
249 |
+
torch.manual_seed(0)
|
250 |
geo_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype).repeat(2,1,1,1)
|
251 |
rgb_latent = rgb_latent.repeat(2,1,1,1)
|
252 |
|