lemonaddie commited on
Commit
aae9b16
·
verified ·
1 Parent(s): b6ddde2

Delete models/depth_normal_pipeline.py

Browse files
Files changed (1) hide show
  1. models/depth_normal_pipeline.py +0 -361
models/depth_normal_pipeline.py DELETED
@@ -1,361 +0,0 @@
1
- # A reimplemented version in public environments by Xiao Fu and Mu Hu
2
-
3
- from typing import Any, Dict, Union
4
-
5
- import torch
6
- from torch.utils.data import DataLoader, TensorDataset
7
- import numpy as np
8
- from tqdm.auto import tqdm
9
- from PIL import Image
10
- from diffusers import (
11
- DiffusionPipeline,
12
- DDIMScheduler,
13
- AutoencoderKL,
14
- )
15
- from models.unet_2d_condition import UNet2DConditionModel
16
- from diffusers.utils import BaseOutput
17
- from transformers import CLIPTextModel, CLIPTokenizer
18
-
19
- from utils.image_util import resize_max_res,chw2hwc,colorize_depth_maps
20
- from utils.colormap import kitti_colormap
21
- from utils.depth_ensemble import ensemble_depths
22
- from utils.batch_size import find_batch_size
23
- import cv2
24
-
25
- class DepthNormalPipelineOutput(BaseOutput):
26
- """
27
- Output class for Marigold monocular depth prediction pipeline.
28
-
29
- Args:
30
- depth_np (`np.ndarray`):
31
- Predicted depth map, with depth values in the range of [0, 1].
32
- depth_colored (`PIL.Image.Image`):
33
- Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
34
- normal_np (`np.ndarray`):
35
- Predicted normal map, with depth values in the range of [0, 1].
36
- normal_colored (`PIL.Image.Image`):
37
- Colorized normal map, with the shape of [3, H, W] and values in [0, 1].
38
- uncertainty (`None` or `np.ndarray`):
39
- Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
40
- """
41
- depth_np: np.ndarray
42
- depth_colored: Image.Image
43
- normal_np: np.ndarray
44
- normal_colored: Image.Image
45
- uncertainty: Union[None, np.ndarray]
46
-
47
- class DepthNormalEstimationPipeline(DiffusionPipeline):
48
- # two hyper-parameters
49
- latent_scale_factor = 0.18215
50
-
51
- def __init__(self,
52
- unet:UNet2DConditionModel,
53
- vae:AutoencoderKL,
54
- scheduler:DDIMScheduler,
55
- text_encoder:CLIPTextModel,
56
- tokenizer:CLIPTokenizer,
57
- ):
58
- super().__init__()
59
-
60
- self.register_modules(
61
- unet=unet,
62
- vae=vae,
63
- scheduler=scheduler,
64
- text_encoder=text_encoder,
65
- tokenizer=tokenizer,
66
- )
67
- self.empty_text_embed = None
68
-
69
- @torch.no_grad()
70
- def __call__(self,
71
- input_image:Image,
72
- denoising_steps: int = 10,
73
- ensemble_size: int = 10,
74
- processing_res: int = 768,
75
- match_input_res:bool =True,
76
- batch_size:int = 0,
77
- domain: str = "indoor",
78
- color_map: str="Spectral",
79
- show_progress_bar:bool = True,
80
- ensemble_kwargs: Dict = None,
81
- ) -> DepthNormalPipelineOutput:
82
-
83
- # inherit from thea Diffusion Pipeline
84
- device = self.device
85
- input_size = input_image.size
86
-
87
- # adjust the input resolution.
88
- if not match_input_res:
89
- assert (
90
- processing_res is not None
91
- )," Value Error: `resize_output_back` is only valid with "
92
-
93
- assert processing_res >=0
94
- assert denoising_steps >=1
95
- assert ensemble_size >=1
96
-
97
- # --------------- Image Processing ------------------------
98
- # Resize image
99
- if processing_res >0:
100
- input_image = resize_max_res(
101
- input_image, max_edge_resolution=processing_res
102
- )
103
-
104
- # Convert the image to RGB, to 1. reomve the alpha channel.
105
- input_image = input_image.convert("RGB")
106
- image = np.array(input_image)
107
-
108
- # Normalize RGB Values.
109
- rgb = np.transpose(image,(2,0,1))
110
- rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
111
- rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
112
- rgb_norm = rgb_norm.to(device)
113
-
114
- assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
115
-
116
- # ----------------- predicting depth -----------------
117
- duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
118
- single_rgb_dataset = TensorDataset(duplicated_rgb)
119
-
120
- # find the batch size
121
- if batch_size>0:
122
- _bs = batch_size
123
- else:
124
- _bs = 1
125
-
126
- single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=_bs, shuffle=False)
127
-
128
- # predicted the depth
129
- depth_pred_ls = []
130
- normal_pred_ls = []
131
-
132
- if show_progress_bar:
133
- iterable_bar = tqdm(
134
- single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
135
- )
136
- else:
137
- iterable_bar = single_rgb_loader
138
-
139
- for batch in iterable_bar:
140
- (batched_image, )= batch # here the image is still around 0-1
141
-
142
- depth_pred_raw, normal_pred_raw = self.single_infer(
143
- input_rgb=batched_image,
144
- num_inference_steps=denoising_steps,
145
- domain=domain,
146
- show_pbar=show_progress_bar,
147
- )
148
- depth_pred_ls.append(depth_pred_raw.detach().clone())
149
- normal_pred_ls.append(normal_pred_raw.detach().clone())
150
-
151
- depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze() #(10,224,768)
152
- normal_preds = torch.concat(normal_pred_ls, axis=0).squeeze()
153
- torch.cuda.empty_cache() # clear vram cache for ensembling
154
-
155
- # ----------------- Test-time ensembling -----------------
156
- if ensemble_size > 1:
157
- depth_pred, pred_uncert = ensemble_depths(
158
- depth_preds, **(ensemble_kwargs or {})
159
- )
160
- normal_pred = normal_preds[0]
161
- else:
162
- depth_pred = depth_preds
163
- normal_pred = normal_preds
164
- pred_uncert = None
165
-
166
- # ----------------- Post processing -----------------
167
- # Scale prediction to [0, 1]
168
- min_d = torch.min(depth_pred)
169
- max_d = torch.max(depth_pred)
170
- depth_pred = (depth_pred - min_d) / (max_d - min_d)
171
-
172
- # Convert to numpy
173
- depth_pred = depth_pred.cpu().numpy().astype(np.float32)
174
- normal_pred = normal_pred.cpu().numpy().astype(np.float32)
175
-
176
- # Resize back to original resolution
177
- if match_input_res:
178
- pred_img = Image.fromarray(depth_pred)
179
- pred_img = pred_img.resize(input_size)
180
- depth_pred = np.asarray(pred_img)
181
- normal_pred = cv2.resize(chw2hwc(normal_pred), input_size, interpolation = cv2.INTER_NEAREST)
182
-
183
- # Clip output range: current size is the original size
184
- depth_pred = depth_pred.clip(0, 1)
185
- normal_pred = normal_pred.clip(-1, 1)
186
-
187
- # Colorize
188
- depth_colored = colorize_depth_maps(
189
- depth_pred, 0, 1, cmap=color_map
190
- ).squeeze() # [3, H, W], value in (0, 1)
191
- depth_colored = (depth_colored * 255).astype(np.uint8)
192
- depth_colored_hwc = chw2hwc(depth_colored)
193
- depth_colored_img = Image.fromarray(depth_colored_hwc)
194
-
195
- normal_colored = ((normal_pred + 1)/2 * 255).astype(np.uint8)
196
- normal_colored_img = Image.fromarray(normal_colored)
197
-
198
- return DepthNormalPipelineOutput(
199
- depth_np = depth_pred,
200
- depth_colored = depth_colored_img,
201
- normal_np = normal_pred,
202
- normal_colored = normal_colored_img,
203
- uncertainty=pred_uncert,
204
- )
205
-
206
- def __encode_empty_text(self):
207
- """
208
- Encode text embedding for empty prompt
209
- """
210
- prompt = ""
211
- text_inputs = self.tokenizer(
212
- prompt,
213
- padding="do_not_pad",
214
- max_length=self.tokenizer.model_max_length,
215
- truncation=True,
216
- return_tensors="pt",
217
- )
218
- text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) #[1,2]
219
- # print(text_input_ids.shape)
220
- self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) #[1,2,1024]
221
-
222
-
223
- @torch.no_grad()
224
- def single_infer(self,input_rgb:torch.Tensor,
225
- num_inference_steps:int,
226
- domain:str,
227
- show_pbar:bool,):
228
-
229
- device = input_rgb.device
230
-
231
- # Set timesteps: inherit from the diffuison pipeline
232
- self.scheduler.set_timesteps(num_inference_steps, device=device) # here the numbers of the steps is only 10.
233
- timesteps = self.scheduler.timesteps # [T]
234
-
235
- # encode image
236
- rgb_latent = self.encode_RGB(input_rgb)
237
-
238
- # Initial depth map (Guassian noise)
239
- geo_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype).repeat(2,1,1,1)
240
- rgb_latent = rgb_latent.repeat(2,1,1,1)
241
-
242
- # Batched empty text embedding
243
- if self.empty_text_embed is None:
244
- self.__encode_empty_text()
245
-
246
- batch_empty_text_embed = self.empty_text_embed.repeat(
247
- (rgb_latent.shape[0], 1, 1)
248
- ) # [B, 2, 1024]
249
-
250
- # hybrid hierarchical switcher
251
- geo_class = torch.tensor([[0., 1.], [1, 0]], device=device, dtype=self.dtype)
252
- geo_embedding = torch.cat([torch.sin(geo_class), torch.cos(geo_class)], dim=-1)
253
-
254
- if domain == "indoor":
255
- domain_class = torch.tensor([[1., 0., 0]], device=device, dtype=self.dtype).repeat(2,1)
256
- elif domain == "outdoor":
257
- domain_class = torch.tensor([[0., 1., 0]], device=device, dtype=self.dtype).repeat(2,1)
258
- elif domain == "object":
259
- domain_class = torch.tensor([[0., 0., 1]], device=device, dtype=self.dtype).repeat(2,1)
260
- domain_embedding = torch.cat([torch.sin(domain_class), torch.cos(domain_class)], dim=-1)
261
-
262
- class_embedding = torch.cat((geo_embedding, domain_embedding), dim=-1)
263
-
264
- # Denoising loop
265
- if show_pbar:
266
- iterable = tqdm(
267
- enumerate(timesteps),
268
- total=len(timesteps),
269
- leave=False,
270
- desc=" " * 4 + "Diffusion denoising",
271
- )
272
- else:
273
- iterable = enumerate(timesteps)
274
-
275
- for i, t in iterable:
276
- unet_input = torch.cat([rgb_latent, geo_latent], dim=1)
277
-
278
- # predict the noise residual
279
- noise_pred = self.unet(
280
- unet_input, t.repeat(2), encoder_hidden_states=batch_empty_text_embed, class_labels=class_embedding
281
- ).sample # [B, 4, h, w]
282
-
283
- # compute the previous noisy sample x_t -> x_t-1
284
- geo_latent = self.scheduler.step(noise_pred, t, geo_latent).prev_sample
285
-
286
- geo_latent = geo_latent
287
- torch.cuda.empty_cache()
288
-
289
- depth = self.decode_depth(geo_latent[0][None])
290
- depth = torch.clip(depth, -1.0, 1.0)
291
- depth = (depth + 1.0) / 2.0
292
-
293
- normal = self.decode_normal(geo_latent[1][None])
294
- normal /= (torch.norm(normal, p=2, dim=1, keepdim=True)+1e-5)
295
-
296
- return depth, normal
297
-
298
-
299
- def encode_RGB(self, rgb_in: torch.Tensor) -> torch.Tensor:
300
- """
301
- Encode RGB image into latent.
302
-
303
- Args:
304
- rgb_in (`torch.Tensor`):
305
- Input RGB image to be encoded.
306
-
307
- Returns:
308
- `torch.Tensor`: Image latent.
309
- """
310
-
311
- # encode
312
- h = self.vae.encoder(rgb_in)
313
-
314
- moments = self.vae.quant_conv(h)
315
- mean, logvar = torch.chunk(moments, 2, dim=1)
316
- # scale latent
317
- rgb_latent = mean * self.latent_scale_factor
318
-
319
- return rgb_latent
320
-
321
- def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
322
- """
323
- Decode depth latent into depth map.
324
-
325
- Args:
326
- depth_latent (`torch.Tensor`):
327
- Depth latent to be decoded.
328
-
329
- Returns:
330
- `torch.Tensor`: Decoded depth map.
331
- """
332
-
333
- # scale latent
334
- depth_latent = depth_latent / self.latent_scale_factor
335
- # decode
336
- z = self.vae.post_quant_conv(depth_latent)
337
- stacked = self.vae.decoder(z)
338
- # mean of output channels
339
- depth_mean = stacked.mean(dim=1, keepdim=True)
340
- return depth_mean
341
-
342
- def decode_normal(self, normal_latent: torch.Tensor) -> torch.Tensor:
343
- """
344
- Decode normal latent into normal map.
345
-
346
- Args:
347
- normal_latent (`torch.Tensor`):
348
- Depth latent to be decoded.
349
-
350
- Returns:
351
- `torch.Tensor`: Decoded normal map.
352
- """
353
-
354
- # scale latent
355
- normal_latent = normal_latent / self.latent_scale_factor
356
- # decode
357
- z = self.vae.post_quant_conv(normal_latent)
358
- normal = self.vae.decoder(z)
359
- return normal
360
-
361
-