lemonaddie commited on
Commit
e82b0fa
·
verified ·
1 Parent(s): 5c03ece

Delete models/depth_normal_pipeline_clip_cfg.py

Browse files
models/depth_normal_pipeline_clip_cfg.py DELETED
@@ -1,375 +0,0 @@
1
- # A reimplemented version in public environments by Xiao Fu and Mu Hu
2
-
3
- from typing import Any, Dict, Union
4
-
5
- import torch
6
- from torch.utils.data import DataLoader, TensorDataset
7
- import numpy as np
8
- from tqdm.auto import tqdm
9
- from PIL import Image
10
- from diffusers import (
11
- DiffusionPipeline,
12
- DDIMScheduler,
13
- AutoencoderKL,
14
- )
15
- from models.unet_2d_condition import UNet2DConditionModel
16
- from diffusers.utils import BaseOutput
17
- from transformers import CLIPTextModel, CLIPTokenizer
18
- from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
19
- import torchvision.transforms.functional as TF
20
- from torchvision.transforms import InterpolationMode
21
-
22
- from utils.image_util import resize_max_res,chw2hwc,colorize_depth_maps
23
- from utils.colormap import kitti_colormap
24
- from utils.depth_ensemble import ensemble_depths
25
- from utils.normal_ensemble import ensemble_normals
26
- from utils.batch_size import find_batch_size
27
- import cv2
28
-
29
- class DepthNormalPipelineOutput(BaseOutput):
30
- """
31
- Output class for Marigold monocular depth prediction pipeline.
32
-
33
- Args:
34
- depth_np (`np.ndarray`):
35
- Predicted depth map, with depth values in the range of [0, 1].
36
- depth_colored (`PIL.Image.Image`):
37
- Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
38
- normal_np (`np.ndarray`):
39
- Predicted normal map, with depth values in the range of [0, 1].
40
- normal_colored (`PIL.Image.Image`):
41
- Colorized normal map, with the shape of [3, H, W] and values in [0, 1].
42
- uncertainty (`None` or `np.ndarray`):
43
- Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
44
- """
45
- depth_np: np.ndarray
46
- depth_colored: Image.Image
47
- normal_np: np.ndarray
48
- normal_colored: Image.Image
49
- uncertainty: Union[None, np.ndarray]
50
-
51
- class DepthNormalEstimationPipeline(DiffusionPipeline):
52
- # two hyper-parameters
53
- latent_scale_factor = 0.18215
54
-
55
- def __init__(self,
56
- unet:UNet2DConditionModel,
57
- vae:AutoencoderKL,
58
- scheduler:DDIMScheduler,
59
- image_encoder:CLIPVisionModelWithProjection,
60
- feature_extractor:CLIPImageProcessor,
61
- ):
62
- super().__init__()
63
-
64
- self.register_modules(
65
- unet=unet,
66
- vae=vae,
67
- scheduler=scheduler,
68
- image_encoder=image_encoder,
69
- feature_extractor=feature_extractor,
70
- )
71
- self.img_embed = None
72
-
73
- @torch.no_grad()
74
- def __call__(self,
75
- input_image:Image,
76
- denoising_steps: int = 10,
77
- ensemble_size: int = 10,
78
- guidance_scale: int = 1,
79
- processing_res: int = 768,
80
- match_input_res:bool =True,
81
- batch_size:int = 0,
82
- domain: str = "indoor",
83
- color_map: str="Spectral",
84
- show_progress_bar:bool = True,
85
- ensemble_kwargs: Dict = None,
86
- ) -> DepthNormalPipelineOutput:
87
-
88
- # inherit from thea Diffusion Pipeline
89
- device = self.device
90
- input_size = input_image.size
91
-
92
- # adjust the input resolution.
93
- if not match_input_res:
94
- assert (
95
- processing_res is not None
96
- )," Value Error: `resize_output_back` is only valid with "
97
-
98
- assert processing_res >=0
99
- assert denoising_steps >=1
100
- assert ensemble_size >=1
101
-
102
- # --------------- Image Processing ------------------------
103
- # Resize image
104
- if processing_res >0:
105
- input_image = resize_max_res(
106
- input_image, max_edge_resolution=processing_res
107
- )
108
-
109
- # Convert the image to RGB, to 1. reomve the alpha channel.
110
- input_image = input_image.convert("RGB")
111
- image = np.array(input_image)
112
-
113
- # Normalize RGB Values.
114
- rgb = np.transpose(image,(2,0,1))
115
- rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
116
- rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
117
- rgb_norm = rgb_norm.to(device)
118
-
119
- assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
120
-
121
- # ----------------- predicting depth -----------------
122
- duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
123
- single_rgb_dataset = TensorDataset(duplicated_rgb)
124
-
125
- # find the batch size
126
- if batch_size>0:
127
- _bs = batch_size
128
- else:
129
- _bs = 1
130
-
131
- single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=_bs, shuffle=False)
132
-
133
- # predicted the depth
134
- depth_pred_ls = []
135
- normal_pred_ls = []
136
-
137
- if show_progress_bar:
138
- iterable_bar = tqdm(
139
- single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
140
- )
141
- else:
142
- iterable_bar = single_rgb_loader
143
-
144
- for batch in iterable_bar:
145
- (batched_image, )= batch # here the image is still around 0-1
146
-
147
- depth_pred_raw, normal_pred_raw = self.single_infer(
148
- input_rgb=batched_image,
149
- num_inference_steps=denoising_steps,
150
- domain=domain,
151
- guidance_scale=guidance_scale,
152
- show_pbar=show_progress_bar,
153
- )
154
- depth_pred_ls.append(depth_pred_raw.detach().clone())
155
- normal_pred_ls.append(normal_pred_raw.detach().clone())
156
-
157
- depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze()
158
- normal_preds = torch.concat(normal_pred_ls, axis=0).squeeze()
159
- torch.cuda.empty_cache() # clear vram cache for ensembling
160
-
161
- # ----------------- Test-time ensembling -----------------
162
- if ensemble_size > 1:
163
- depth_pred, pred_uncert = ensemble_depths(
164
- depth_preds, **(ensemble_kwargs or {})
165
- )
166
- normal_pred = ensemble_normals(normal_preds)
167
- else:
168
- depth_pred = depth_preds
169
- normal_pred = normal_preds
170
- pred_uncert = None
171
-
172
- # ----------------- Post processing -----------------
173
- # Scale prediction to [0, 1]
174
- min_d = torch.min(depth_pred)
175
- max_d = torch.max(depth_pred)
176
- depth_pred = (depth_pred - min_d) / (max_d - min_d)
177
-
178
- # Convert to numpy
179
- depth_pred = depth_pred.cpu().numpy().astype(np.float32)
180
- normal_pred = normal_pred.cpu().numpy().astype(np.float32)
181
-
182
- # Resize back to original resolution
183
- if match_input_res:
184
- pred_img = Image.fromarray(depth_pred)
185
- pred_img = pred_img.resize(input_size)
186
- depth_pred = np.asarray(pred_img)
187
- normal_pred = cv2.resize(chw2hwc(normal_pred), input_size, interpolation = cv2.INTER_NEAREST)
188
-
189
- # Clip output range: current size is the original size
190
- depth_pred = depth_pred.clip(0, 1)
191
- normal_pred = normal_pred.clip(-1, 1)
192
-
193
- # Colorize
194
- depth_colored = colorize_depth_maps(
195
- depth_pred, 0, 1, cmap=color_map
196
- ).squeeze() # [3, H, W], value in (0, 1)
197
- depth_colored = (depth_colored * 255).astype(np.uint8)
198
- depth_colored_hwc = chw2hwc(depth_colored)
199
- depth_colored_img = Image.fromarray(depth_colored_hwc)
200
-
201
- normal_colored = ((normal_pred + 1)/2 * 255).astype(np.uint8)
202
- normal_colored_img = Image.fromarray(normal_colored)
203
-
204
- return DepthNormalPipelineOutput(
205
- depth_np = depth_pred,
206
- depth_colored = depth_colored_img,
207
- normal_np = normal_pred,
208
- normal_colored = normal_colored_img,
209
- uncertainty=pred_uncert,
210
- )
211
-
212
- def __encode_img_embed(self, rgb):
213
- """
214
- Encode clip embeddings for img
215
- """
216
- clip_image_mean = torch.as_tensor(self.feature_extractor.image_mean)[:,None,None].to(device=self.device, dtype=self.dtype)
217
- clip_image_std = torch.as_tensor(self.feature_extractor.image_std)[:,None,None].to(device=self.device, dtype=self.dtype)
218
-
219
- img_in_proc = TF.resize((rgb +1)/2,
220
- (self.feature_extractor.crop_size['height'], self.feature_extractor.crop_size['width']),
221
- interpolation=InterpolationMode.BICUBIC,
222
- antialias=True
223
- )
224
- # do the normalization in float32 to preserve precision
225
- img_in_proc = ((img_in_proc.float() - clip_image_mean) / clip_image_std).to(self.dtype)
226
- img_embed = self.image_encoder(img_in_proc).image_embeds.unsqueeze(1).to(self.dtype)
227
-
228
- self.img_embed = img_embed
229
-
230
-
231
- @torch.no_grad()
232
- def single_infer(self,input_rgb:torch.Tensor,
233
- num_inference_steps:int,
234
- domain:str,
235
- guidance_scale: int,
236
- show_pbar:bool,):
237
-
238
- device = input_rgb.device
239
-
240
- # Set timesteps: inherit from the diffuison pipeline
241
- self.scheduler.set_timesteps(num_inference_steps, device=device) # here the numbers of the steps is only 10.
242
- timesteps = self.scheduler.timesteps # [T]
243
-
244
- # encode image
245
- rgb_latent = self.encode_RGB(input_rgb)
246
-
247
- # Initial depth map (Guassian noise)
248
- geo_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype).repeat(2,1,1,1)
249
- rgb_latent = rgb_latent.repeat(2,1,1,1)
250
-
251
- # Batched img embedding
252
- if self.img_embed is None:
253
- self.__encode_img_embed(input_rgb)
254
-
255
- batch_img_embed = self.img_embed.repeat(
256
- (rgb_latent.shape[0], 1, 1)
257
- ) # [B, 1, 768]
258
-
259
- batch_img_embed = torch.cat((torch.zeros_like(batch_img_embed), batch_img_embed), dim=0)
260
- rgb_latent = torch.cat((torch.zeros_like(rgb_latent), rgb_latent), dim=0)
261
-
262
- # hybrid switcher
263
- geo_class = torch.tensor([[0., 1.], [1, 0]], device=device, dtype=self.dtype)
264
- geo_embedding = torch.cat([torch.sin(geo_class), torch.cos(geo_class)], dim=-1)
265
-
266
- if domain == "indoor":
267
- domain_class = torch.tensor([[1., 0., 0]], device=device, dtype=self.dtype).repeat(2,1)
268
- elif domain == "outdoor":
269
- domain_class = torch.tensor([[0., 1., 0]], device=device, dtype=self.dtype).repeat(2,1)
270
- elif domain == "object":
271
- domain_class = torch.tensor([[0., 0., 1]], device=device, dtype=self.dtype).repeat(2,1)
272
- domain_embedding = torch.cat([torch.sin(domain_class), torch.cos(domain_class)], dim=-1)
273
-
274
- class_embedding = torch.cat((geo_embedding, domain_embedding), dim=-1)
275
-
276
- # Denoising loop
277
- if show_pbar:
278
- iterable = tqdm(
279
- enumerate(timesteps),
280
- total=len(timesteps),
281
- leave=False,
282
- desc=" " * 4 + "Diffusion denoising",
283
- )
284
- else:
285
- iterable = enumerate(timesteps)
286
-
287
- for i, t in iterable:
288
- unet_input = torch.cat((rgb_latent, geo_latent.repeat(2,1,1,1)), dim=1)
289
- # predict the noise residual
290
- noise_pred = self.unet(unet_input, t.repeat(4), encoder_hidden_states=batch_img_embed, class_labels=class_embedding.repeat(2,1)).sample
291
- noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
292
- #guidance_scale = 3.
293
- guidance_scale = guidance_scale
294
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
295
-
296
- # compute the previous noisy sample x_t -> x_t-1
297
- geo_latent = self.scheduler.step(noise_pred, t, geo_latent).prev_sample
298
-
299
- geo_latent = geo_latent
300
- torch.cuda.empty_cache()
301
-
302
- depth = self.decode_depth(geo_latent[0][None])
303
- depth = torch.clip(depth, -1.0, 1.0)
304
- depth = (depth + 1.0) / 2.0
305
-
306
- normal = self.decode_normal(geo_latent[1][None])
307
- normal /= (torch.norm(normal, p=2, dim=1, keepdim=True)+1e-5)
308
- normal *= -1.
309
-
310
- return depth, normal
311
-
312
-
313
- def encode_RGB(self, rgb_in: torch.Tensor) -> torch.Tensor:
314
- """
315
- Encode RGB image into latent.
316
-
317
- Args:
318
- rgb_in (`torch.Tensor`):
319
- Input RGB image to be encoded.
320
-
321
- Returns:
322
- `torch.Tensor`: Image latent.
323
- """
324
-
325
- # encode
326
- h = self.vae.encoder(rgb_in)
327
-
328
- moments = self.vae.quant_conv(h)
329
- mean, logvar = torch.chunk(moments, 2, dim=1)
330
- # scale latent
331
- rgb_latent = mean * self.latent_scale_factor
332
-
333
- return rgb_latent
334
-
335
- def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
336
- """
337
- Decode depth latent into depth map.
338
-
339
- Args:
340
- depth_latent (`torch.Tensor`):
341
- Depth latent to be decoded.
342
-
343
- Returns:
344
- `torch.Tensor`: Decoded depth map.
345
- """
346
-
347
- # scale latent
348
- depth_latent = depth_latent / self.latent_scale_factor
349
- # decode
350
- z = self.vae.post_quant_conv(depth_latent)
351
- stacked = self.vae.decoder(z)
352
- # mean of output channels
353
- depth_mean = stacked.mean(dim=1, keepdim=True)
354
- return depth_mean
355
-
356
- def decode_normal(self, normal_latent: torch.Tensor) -> torch.Tensor:
357
- """
358
- Decode normal latent into normal map.
359
-
360
- Args:
361
- normal_latent (`torch.Tensor`):
362
- Depth latent to be decoded.
363
-
364
- Returns:
365
- `torch.Tensor`: Decoded normal map.
366
- """
367
-
368
- # scale latent
369
- normal_latent = normal_latent / self.latent_scale_factor
370
- # decode
371
- z = self.vae.post_quant_conv(normal_latent)
372
- normal = self.vae.decoder(z)
373
- return normal
374
-
375
-