flamehaze1115 commited on
Commit
89a7ccb
1 Parent(s): 95584fd

Upload 19 files

Browse files

update_pipeline_code

mv_diffusion_30/data/depth_utils.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ import numpy as np
3
+ import torch
4
+
5
+ def colorize_depth_maps(
6
+ depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
7
+ ):
8
+ """
9
+ Colorize depth maps.
10
+ """
11
+ assert len(depth_map.shape) >= 2, "Invalid dimension"
12
+
13
+ if isinstance(depth_map, torch.Tensor):
14
+ depth = depth_map.detach().squeeze().numpy()
15
+ elif isinstance(depth_map, np.ndarray):
16
+ depth = depth_map.copy().squeeze()
17
+ # reshape to [ (B,) H, W ]
18
+ if depth.ndim < 3:
19
+ depth = depth[np.newaxis, :, :]
20
+
21
+ # colorize
22
+ cm = matplotlib.colormaps[cmap]
23
+ depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
24
+ img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
25
+ img_colored_np = np.rollaxis(img_colored_np, 3, 1)
26
+
27
+ if valid_mask is not None:
28
+ if isinstance(depth_map, torch.Tensor):
29
+ valid_mask = valid_mask.detach().numpy()
30
+ valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
31
+ if valid_mask.ndim < 3:
32
+ valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
33
+ else:
34
+ valid_mask = valid_mask[:, np.newaxis, :, :]
35
+ valid_mask = np.repeat(valid_mask, 3, axis=1)
36
+ img_colored_np[~valid_mask] = 0
37
+
38
+ if isinstance(depth_map, torch.Tensor):
39
+ img_colored = torch.from_numpy(img_colored_np).float()
40
+ elif isinstance(depth_map, np.ndarray):
41
+ img_colored = img_colored_np
42
+
43
+ return img_colored
44
+
45
+
46
+ def scale_depth_to_model(depth, camera_type='ortho'):
47
+ """
48
+ Scale depth from the original range.
49
+ """
50
+ assert camera_type == 'ortho' or camera_type == 'persp'
51
+ w, h = depth.shape
52
+
53
+ if camera_type == 'ortho':
54
+ original_min = 9000
55
+ original_max = 17000
56
+ target_min = 2000
57
+ target_max = 62000
58
+
59
+ mask = depth != 0
60
+ # Scale depth to [0, 1]
61
+ depth_normalized = np.zeros([w, h])
62
+ depth_normalized[mask] = (depth[mask] - original_min) / (original_max - original_min)
63
+
64
+ # Scale depth to [2000, 60000]
65
+ scaled_depth = np.zeros([w, h])
66
+ scaled_depth[mask] = depth_normalized[mask] * (target_max - target_min) + target_min
67
+
68
+ else:
69
+ original_min = 4000
70
+ original_max = 13000
71
+ target_min = 2000
72
+ target_max = 62000
73
+
74
+ mask = depth != 0
75
+ # Scale depth to [0, 1]
76
+ depth_normalized = np.zeros([w, h])
77
+ depth_normalized[mask] = (depth[mask] - original_min) / (original_max - original_min)
78
+
79
+ # Scale depth to [2000, 60000]
80
+ scaled_depth = np.zeros([w, h])
81
+ scaled_depth[mask] = depth_normalized[mask] * (target_max - target_min) + target_min
82
+
83
+ scaled_depth[scaled_depth > 62000] = 0
84
+ scaled_depth = scaled_depth / 65535. # [0, 1]
85
+
86
+ return scaled_depth
87
+
88
+ def rescale_depth_to_world(scaled_depth, camera_type='ortho'):
89
+ """
90
+ Rescale depth from the scaled range back to the original range.
91
+ """
92
+ assert camera_type == 'ortho' or camera_type == 'persp'
93
+ scaled_depth = scaled_depth * 65535.
94
+ w, h = scaled_depth.shape
95
+
96
+ if camera_type == 'ortho':
97
+ original_min = 9000
98
+ original_max = 17000
99
+ target_min = 2000
100
+ target_max = 62000
101
+
102
+ mask = scaled_depth != 0
103
+ rescaled_depth_norm = np.zeros([w, h])
104
+ # Rescale depth to [0, 1]
105
+ rescaled_depth_norm[mask] = (scaled_depth[mask] - target_min) / (target_max - target_min)
106
+
107
+ # Rescale depth to [9000, 17000]
108
+ rescaled_depth = np.zeros([w, h])
109
+ rescaled_depth[mask] = rescaled_depth_norm[mask] * (original_max - original_min) + original_min
110
+
111
+ else:
112
+ original_min = 4000
113
+ original_max = 13000
114
+ target_min = 2000
115
+ target_max = 62000
116
+
117
+ mask = scaled_depth != 0
118
+ rescaled_depth_norm = np.zeros([w, h])
119
+ # Rescale depth to [0, 1]
120
+ rescaled_depth_norm[mask] = (scaled_depth[mask] - target_min) / (target_max - target_min)
121
+
122
+ # Rescale depth to [9000, 17000]
123
+ rescaled_depth = np.zeros([w, h])
124
+ rescaled_depth[mask] = rescaled_depth_norm[mask] * (original_max - original_min) + original_min
125
+
126
+ return rescaled_depth
mv_diffusion_30/data/fixed_poses/nine_views.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a30afc8a8c757429716f3be7ee58e7a9a5e0fb5ec5cb4d106bc04e43550ac2b
3
+ size 7385
mv_diffusion_30/data/fixed_poses/nine_views/000_back_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ -5.266582965850830078e-01 7.410295009613037109e-01 -4.165407419204711914e-01 -5.960464477539062500e-08
2
+ 5.865638996738198330e-08 4.900035560131072998e-01 8.717204332351684570e-01 -9.462351613365171943e-08
3
+ 8.500770330429077148e-01 4.590988159179687500e-01 -2.580644786357879639e-01 -1.300000071525573730e+00
mv_diffusion_30/data/fixed_poses/nine_views/000_back_left_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ -9.734988808631896973e-01 1.993551850318908691e-01 -1.120596975088119507e-01 -1.713633537292480469e-07
2
+ 3.790224578636980368e-09 4.900034964084625244e-01 8.717204928398132324e-01 1.772203575001185527e-07
3
+ 2.286916375160217285e-01 8.486189246177673340e-01 -4.770178496837615967e-01 -1.838477611541748047e+00
mv_diffusion_30/data/fixed_poses/nine_views/000_back_right_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 2.286914736032485962e-01 8.486190438270568848e-01 -4.770178198814392090e-01 1.564621925354003906e-07
2
+ -3.417914484771245043e-08 4.900034070014953613e-01 8.717205524444580078e-01 -7.293811421504869941e-08
3
+ 9.734990000724792480e-01 -1.993550658226013184e-01 1.120596155524253845e-01 -1.838477969169616699e+00
mv_diffusion_30/data/fixed_poses/nine_views/000_front_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 5.266583561897277832e-01 -7.410295009613037109e-01 4.165407419204711914e-01 0.000000000000000000e+00
2
+ 5.865638996738198330e-08 4.900035560131072998e-01 8.717204332351684570e-01 9.462351613365171943e-08
3
+ -8.500770330429077148e-01 -4.590988159179687500e-01 2.580645382404327393e-01 -1.300000071525573730e+00
mv_diffusion_30/data/fixed_poses/nine_views/000_front_left_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ -2.286916971206665039e-01 -8.486189842224121094e-01 4.770179092884063721e-01 -2.458691596984863281e-07
2
+ 9.085837859856837895e-09 4.900034666061401367e-01 8.717205524444580078e-01 1.205695667749751010e-07
3
+ -9.734990000724792480e-01 1.993551701307296753e-01 -1.120597645640373230e-01 -1.838477969169616699e+00
mv_diffusion_30/data/fixed_poses/nine_views/000_front_right_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 9.734989404678344727e-01 -1.993551850318908691e-01 1.120596975088119507e-01 -1.415610313415527344e-07
2
+ 3.790224578636980368e-09 4.900034964084625244e-01 8.717204928398132324e-01 -1.772203575001185527e-07
3
+ -2.286916375160217285e-01 -8.486189246177673340e-01 4.770178794860839844e-01 -1.838477611541748047e+00
mv_diffusion_30/data/fixed_poses/nine_views/000_left_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ -8.500771522521972656e-01 -4.590989053249359131e-01 2.580644488334655762e-01 0.000000000000000000e+00
2
+ -4.257411134744870651e-08 4.900034964084625244e-01 8.717204928398132324e-01 9.006067358541258727e-08
3
+ -5.266583561897277832e-01 7.410295605659484863e-01 -4.165408313274383545e-01 -1.300000071525573730e+00
mv_diffusion_30/data/fixed_poses/nine_views/000_right_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 8.500770330429077148e-01 4.590989053249359131e-01 -2.580644488334655762e-01 5.960464477539062500e-08
2
+ -4.257411134744870651e-08 4.900034964084625244e-01 8.717204928398132324e-01 -9.006067358541258727e-08
3
+ 5.266583561897277832e-01 -7.410295605659484863e-01 4.165407419204711914e-01 -1.300000071525573730e+00
mv_diffusion_30/data/fixed_poses/nine_views/000_top_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 9.958608150482177734e-01 7.923202216625213623e-02 -4.453715682029724121e-02 -3.098167056236889039e-09
2
+ -9.089154005050659180e-02 8.681122064590454102e-01 -4.879753291606903076e-01 5.784738377201392723e-08
3
+ -2.028124157504862524e-08 4.900035560131072998e-01 8.717204332351684570e-01 -1.300000071525573730e+00
mv_diffusion_30/data/multiview_image_dataset.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ import numpy as np
3
+ from omegaconf import DictConfig, ListConfig
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from pathlib import Path
7
+ import json
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ from einops import rearrange
11
+ from typing import Literal, Tuple, Optional, Any
12
+ import cv2
13
+ import random
14
+
15
+ import json
16
+ import os, sys
17
+ import math
18
+
19
+ from glob import glob
20
+
21
+ import PIL.Image
22
+ from .normal_utils import trans_normal, normal2img, img2normal
23
+ import pdb
24
+
25
+
26
+ import cv2
27
+ import numpy as np
28
+
29
+ def add_margin(pil_img, color=0, size=256):
30
+ width, height = pil_img.size
31
+ result = Image.new(pil_img.mode, (size, size), color)
32
+ result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
33
+ return result
34
+
35
+ def scale_and_place_object(image, scale_factor):
36
+ assert np.shape(image)[-1]==4 # RGBA
37
+
38
+ # Extract the alpha channel (transparency) and the object (RGB channels)
39
+ alpha_channel = image[:, :, 3]
40
+
41
+ # Find the bounding box coordinates of the object
42
+ coords = cv2.findNonZero(alpha_channel)
43
+ x, y, width, height = cv2.boundingRect(coords)
44
+
45
+ # Calculate the scale factor for resizing
46
+ original_height, original_width = image.shape[:2]
47
+
48
+ if width > height:
49
+ size = width
50
+ original_size = original_width
51
+ else:
52
+ size = height
53
+ original_size = original_height
54
+
55
+ scale_factor = min(scale_factor, size / (original_size+0.0))
56
+
57
+ new_size = scale_factor * original_size
58
+ scale_factor = new_size / size
59
+
60
+ # Calculate the new size based on the scale factor
61
+ new_width = int(width * scale_factor)
62
+ new_height = int(height * scale_factor)
63
+
64
+ center_x = original_width // 2
65
+ center_y = original_height // 2
66
+
67
+ paste_x = center_x - (new_width // 2)
68
+ paste_y = center_y - (new_height // 2)
69
+
70
+ # Resize the object (RGB channels) to the new size
71
+ rescaled_object = cv2.resize(image[y:y+height, x:x+width], (new_width, new_height))
72
+
73
+ # Create a new RGBA image with the resized image
74
+ new_image = np.zeros((original_height, original_width, 4), dtype=np.uint8)
75
+
76
+ new_image[paste_y:paste_y + new_height, paste_x:paste_x + new_width] = rescaled_object
77
+
78
+ return new_image
79
+
80
+ class InferenceImageDataset(Dataset):
81
+ def __init__(self,
82
+ root_dir: str,
83
+ num_views: int,
84
+ img_wh: Tuple[int, int],
85
+ bg_color: str,
86
+ crop_size: int = 224,
87
+ single_image: Optional[PIL.Image.Image] = None,
88
+ num_validation_samples: Optional[int] = None,
89
+ filepaths: Optional[list] = None,
90
+ cam_types: Optional[list] = None,
91
+ cond_type: Optional[str] = None,
92
+ load_cam_type: Optional[bool] = True
93
+ ) -> None:
94
+ """Create a dataset from a folder of images.
95
+ If you pass in a root directory it will be searched for images
96
+ ending in ext (ext can be a list)
97
+ """
98
+ self.root_dir = root_dir
99
+ self.num_views = num_views
100
+ self.img_wh = img_wh
101
+ self.crop_size = crop_size
102
+ self.bg_color = bg_color
103
+ self.cond_type = cond_type
104
+ self.load_cam_type = load_cam_type
105
+ self.cam_types = cam_types
106
+
107
+ if self.num_views == 4:
108
+ self.view_types = ['front', 'right', 'back', 'left']
109
+ elif self.num_views == 5:
110
+ self.view_types = ['front', 'front_right', 'right', 'back', 'left']
111
+ elif self.num_views == 6:
112
+ self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
113
+
114
+ self.fix_cam_pose_dir = "./mvdiffusion/data/fixed_poses/nine_views"
115
+
116
+ self.fix_cam_poses = self.load_fixed_poses() # world2cam matrix
117
+
118
+
119
+
120
+ if filepaths is None:
121
+ # Get a list of all files in the directory
122
+ file_list = os.listdir(self.root_dir)
123
+ self.cam_types = ['ortho'] * len(file_list) + ['persp']* len(file_list)
124
+ file_list = file_list * 2
125
+ else:
126
+ file_list = filepaths
127
+ print(filepaths, root_dir)
128
+ # Filter the files that end with .png or .jpg
129
+ self.file_list = [file for file in file_list]
130
+
131
+ self.bg_color = self.get_bg_color()
132
+
133
+
134
+
135
+
136
+ def __len__(self):
137
+ return len(self.file_list)
138
+
139
+ def load_fixed_poses(self):
140
+ poses = {}
141
+ for face in self.view_types:
142
+ RT = np.loadtxt(os.path.join(self.fix_cam_pose_dir,'%03d_%s_RT.txt'%(0, face)))
143
+ poses[face] = RT
144
+
145
+ return poses
146
+
147
+ def cartesian_to_spherical(self, xyz):
148
+ ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
149
+ xy = xyz[:,0]**2 + xyz[:,1]**2
150
+ z = np.sqrt(xy + xyz[:,2]**2)
151
+ theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
152
+ #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
153
+ azimuth = np.arctan2(xyz[:,1], xyz[:,0])
154
+ return np.array([theta, azimuth, z])
155
+
156
+ def get_T(self, target_RT, cond_RT):
157
+ R, T = target_RT[:3, :3], target_RT[:, -1]
158
+ T_target = -R.T @ T # change to cam2world
159
+
160
+ R, T = cond_RT[:3, :3], cond_RT[:, -1]
161
+ T_cond = -R.T @ T
162
+
163
+ theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
164
+ theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
165
+
166
+ d_theta = theta_target - theta_cond
167
+ d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
168
+ d_z = z_target - z_cond
169
+
170
+ # d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
171
+ return d_theta, d_azimuth
172
+
173
+ def get_bg_color(self):
174
+ if self.bg_color == 'white':
175
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
176
+ elif self.bg_color == 'black':
177
+ bg_color = np.array([0., 0., 0.], dtype=np.float32)
178
+ elif self.bg_color == 'gray':
179
+ bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
180
+ elif self.bg_color == 'random':
181
+ bg_color = np.random.rand(3)
182
+ elif isinstance(self.bg_color, float):
183
+ bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
184
+ else:
185
+ raise NotImplementedError
186
+ return bg_color
187
+
188
+
189
+ def load_image(self, img_path, bg_color, return_type='pt', Imagefile=None):
190
+ # pil always returns uint8
191
+ if Imagefile is None:
192
+ image_input = Image.open(img_path)
193
+ else:
194
+ image_input = Imagefile
195
+ image_size = self.img_wh[0]
196
+
197
+ # if self.crop_size!=-1:
198
+ # alpha_np = np.asarray(image_input)[:, :, 3]
199
+ # coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
200
+ # min_x, min_y = np.min(coords, 0)
201
+ # max_x, max_y = np.max(coords, 0)
202
+ # ref_img_ = image_input.crop((min_x, min_y, max_x, max_y))
203
+ # h, w = ref_img_.height, ref_img_.width
204
+ # scale = self.crop_size / max(h, w)
205
+ # h_, w_ = int(scale * h), int(scale * w)
206
+ # ref_img_ = ref_img_.resize((w_, h_))
207
+ # image_input = add_margin(ref_img_, size=image_size)
208
+ # else:
209
+ # image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
210
+ # image_input = image_input.resize((image_size, image_size))
211
+
212
+ # img = scale_and_place_object(img, self.scale_ratio)
213
+ img = np.array(image_input)
214
+ img = img.astype(np.float32) / 255. # [0, 1]
215
+ assert img.shape[-1] == 4 # RGBA
216
+
217
+ alpha = img[...,3:4]
218
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
219
+
220
+ if return_type == "np":
221
+ pass
222
+ elif return_type == "pt":
223
+ img = torch.from_numpy(img)
224
+ alpha = torch.from_numpy(alpha)
225
+ else:
226
+ raise NotImplementedError
227
+
228
+ return img, alpha
229
+
230
+
231
+ def __len__(self):
232
+ return len(self.file_list)
233
+
234
+ def __getitem__(self, index):
235
+
236
+ # image = self.all_images[index%len(self.all_images)]
237
+ # alpha = self.all_alphas[index%len(self.all_images)]
238
+ cam_type = self.cam_types[index%len(self.file_list)]
239
+ if self.file_list is not None:
240
+ filename = self.file_list[index%len(self.file_list)].replace(".png", "")
241
+ else:
242
+ filename = 'null'
243
+
244
+ cond_w2c = self.fix_cam_poses['front']
245
+
246
+ tgt_w2cs = [self.fix_cam_poses[view] for view in self.view_types]
247
+
248
+ elevations = []
249
+ azimuths = []
250
+
251
+ img_tensors_in = []
252
+ for view in self.view_types:
253
+ img_path = os.path.join(self.root_dir, filename, cam_type,"color_000_%s.png" % (view))
254
+ img_tensor, alpha = self.load_image(img_path, self.bg_color, return_type="pt")
255
+ img_tensor = img_tensor.permute(2, 0, 1)
256
+ img_tensors_in.append(img_tensor)
257
+
258
+ alpha_tensors_in = [
259
+ alpha.permute(2, 0, 1)
260
+ ] * self.num_views
261
+
262
+ for view, tgt_w2c in zip(self.view_types, tgt_w2cs):
263
+ # evelations, azimuths
264
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
265
+ elevations.append(elevation)
266
+ azimuths.append(azimuth)
267
+
268
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
269
+ # alpha_tensors_in = torch.stack(alpha_tensors_in, dim=0).float() # (Nv, 3, H, W)
270
+
271
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
272
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
273
+ elevations_cond = torch.as_tensor([0] * self.num_views).float()
274
+
275
+ normal_class = torch.tensor([1, 0]).float()
276
+ normal_task_embeddings = torch.stack([normal_class] * self.num_views, dim=0) # (Nv, 2)
277
+ color_class = torch.tensor([0, 1]).float()
278
+ depth_task_embeddings = torch.stack([color_class] * self.num_views, dim=0) # (Nv, 2)
279
+
280
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1) # (Nv, 3)
281
+
282
+ if cam_type == 'ortho':
283
+ cam_type_emb = torch.tensor([0, 1]).expand(self.num_views, -1)
284
+ else:
285
+ cam_type_emb = torch.tensor([1, 0]).expand(self.num_views, -1)
286
+
287
+ if self.load_cam_type:
288
+ camera_embeddings = torch.cat((camera_embeddings, cam_type_emb), dim=-1) # (Nv, 5)
289
+
290
+ out = {
291
+ 'elevations_cond': elevations_cond,
292
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
293
+ 'elevations': elevations,
294
+ 'azimuths': azimuths,
295
+ 'elevations_deg': torch.rad2deg(elevations),
296
+ 'azimuths_deg': torch.rad2deg(azimuths),
297
+ 'imgs_in': img_tensors_in,
298
+ 'alphas': alpha_tensors_in,
299
+ 'camera_embeddings': camera_embeddings,
300
+ 'normal_task_embeddings': normal_task_embeddings,
301
+ 'depth_task_embeddings': depth_task_embeddings,
302
+ 'filename': filename,
303
+ 'cam_type': cam_type
304
+ }
305
+
306
+ return out
307
+
308
+
mv_diffusion_30/data/normal_utils.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ def camNormal2worldNormal(rot_c2w, camNormal):
4
+ H,W,_ = camNormal.shape
5
+ normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
6
+
7
+ return normal_img
8
+
9
+ def worldNormal2camNormal(rot_w2c, normal_map_world):
10
+ H,W,_ = normal_map_world.shape
11
+ # normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
12
+
13
+ # faster version
14
+ # Reshape the normal map into a 2D array where each row represents a normal vector
15
+ normal_map_flat = normal_map_world.reshape(-1, 3)
16
+
17
+ # Transform the normal vectors using the transformation matrix
18
+ normal_map_camera_flat = np.dot(normal_map_flat, rot_w2c.T)
19
+
20
+ # Reshape the transformed normal map back to its original shape
21
+ normal_map_camera = normal_map_camera_flat.reshape(normal_map_world.shape)
22
+
23
+ return normal_map_camera
24
+
25
+ def trans_normal(normal, RT_w2c, RT_w2c_target):
26
+
27
+ # normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal)
28
+ # normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world)
29
+
30
+ relative_RT = np.matmul(RT_w2c_target[:3,:3], np.linalg.inv(RT_w2c[:3,:3]))
31
+ normal_target_cam = worldNormal2camNormal(relative_RT[:3,:3], normal)
32
+
33
+ return normal_target_cam
34
+
35
+ def img2normal(img):
36
+ return (img/255.)*2-1
37
+
38
+ def normal2img(normal):
39
+ return np.uint8((normal*0.5+0.5)*255)
40
+
41
+ def norm_normalize(normal, dim=-1):
42
+
43
+ normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6)
44
+
45
+ return normal
mv_diffusion_30/data/objaverse_dataset.py ADDED
@@ -0,0 +1,1359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ import numpy as np
3
+ from omegaconf import DictConfig, ListConfig
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from pathlib import Path
7
+ import json
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ from einops import rearrange
11
+ from typing import Literal, Tuple, Optional, Any
12
+ import cv2
13
+ import random
14
+
15
+ import json
16
+ import os, sys
17
+ import math
18
+
19
+ import PIL.Image
20
+ from .normal_utils import trans_normal, normal2img, img2normal
21
+ import pdb
22
+ from .depth_utils import scale_depth_to_model
23
+ import traceback
24
+
25
+
26
+ class ObjaverseDataset(Dataset):
27
+ def __init__(self,
28
+ root_dir_ortho: str,
29
+ root_dir_persp: str,
30
+ pred_ortho: bool,
31
+ pred_persp: bool,
32
+ num_views: int,
33
+ bg_color: Any,
34
+ img_wh: Tuple[int, int],
35
+ object_list: str,
36
+ groups_num: int=1,
37
+ validation: bool = False,
38
+ data_view_num: int = 6,
39
+ num_validation_samples: int = 64,
40
+ num_samples: Optional[int] = None,
41
+ invalid_list: Optional[str] = None,
42
+ trans_norm_system: bool = True, # if True, transform all normals map into the cam system of front view
43
+ augment_data: bool = False,
44
+ read_normal: bool = True,
45
+ read_color: bool = False,
46
+ read_depth: bool = False,
47
+ read_mask: bool = False,
48
+ pred_type: str = 'color',
49
+ suffix: str = 'png',
50
+ subscene_tag: int = 2,
51
+ load_cam_type: bool = False,
52
+ backup_scene: str = "0306b42594fb447ca574f597352d4b56",
53
+ ortho_crop_size: int = 360,
54
+ persp_crop_size: int = 440,
55
+ load_switcher: bool = True
56
+ ) -> None:
57
+ """Create a dataset from a folder of images.
58
+ If you pass in a root directory it will be searched for images
59
+ ending in ext (ext can be a list)
60
+ """
61
+ self.load_cam_type = load_cam_type
62
+ self.root_dir_ortho = Path(root_dir_ortho)
63
+ self.root_dir_persp = Path(root_dir_persp)
64
+ self.pred_ortho = pred_ortho
65
+ self.pred_persp = pred_persp
66
+ self.num_views = num_views
67
+ self.bg_color = bg_color
68
+ self.validation = validation
69
+ self.num_samples = num_samples
70
+ self.trans_norm_system = trans_norm_system
71
+ self.augment_data = augment_data
72
+ self.invalid_list = invalid_list
73
+ self.groups_num = groups_num
74
+ print("augment data: ", self.augment_data)
75
+ self.img_wh = img_wh
76
+ self.read_normal = read_normal
77
+ self.read_color = read_color
78
+ self.read_depth = read_depth
79
+ self.read_mask = read_mask
80
+ self.pred_type = pred_type # load type
81
+ self.suffix = suffix
82
+ self.subscene_tag = subscene_tag
83
+
84
+ self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
85
+ self.fix_cam_pose_dir = "./mvdiffusion/data/fixed_poses/nine_views"
86
+
87
+ self.fix_cam_poses = self.load_fixed_poses() # world2cam matrix
88
+ self.ortho_crop_size = ortho_crop_size
89
+ self.persp_crop_size = persp_crop_size
90
+ self.load_switcher = load_switcher
91
+
92
+ if object_list is not None:
93
+ with open(object_list) as f:
94
+ self.objects = json.load(f)
95
+ self.objects = [os.path.basename(o).replace(".glb", "") for o in self.objects]
96
+ else:
97
+ self.objects = os.listdir(self.root_dir)
98
+ self.objects = sorted(self.objects)
99
+
100
+ if self.invalid_list is not None:
101
+ with open(self.invalid_list) as f:
102
+ self.invalid_objects = json.load(f)
103
+ self.invalid_objects = [os.path.basename(o).replace(".glb", "") for o in self.invalid_objects]
104
+ else:
105
+ self.invalid_objects = []
106
+
107
+
108
+ self.all_objects = set(self.objects) - (set(self.invalid_objects) & set(self.objects))
109
+ self.all_objects = list(self.all_objects)
110
+
111
+ if not validation:
112
+ self.all_objects = self.all_objects[:-num_validation_samples]
113
+ else:
114
+ self.all_objects = self.all_objects[-num_validation_samples:]
115
+ if num_samples is not None:
116
+ self.all_objects = self.all_objects[:num_samples]
117
+
118
+ print("loading ", len(self.all_objects), " objects in the dataset")
119
+
120
+ if self.pred_type == 'color':
121
+ self.backup_data = self.__getitem_color__(0, backup_scene)
122
+ elif self.pred_type == 'normal_depth':
123
+ self.backup_data = self.__getitem_normal_depth__(0, backup_scene)
124
+ elif self.pred_type == 'mixed_rgb_normal_depth':
125
+ self.backup_data = self.__getitem_mixed__(0, backup_scene)
126
+ elif self.pred_type == 'mixed_color_normal':
127
+ self.backup_data = self.__getitem_image_normal_mixed__(0, backup_scene)
128
+ elif self.pred_type == 'mixed_rgb_noraml_mask':
129
+ self.backup_data = self.__getitem_mixed_rgb_noraml_mask__(0, backup_scene)
130
+ elif self.pred_type == 'joint_color_normal':
131
+ self.backup_data = self.__getitem_joint_rgb_noraml__(0, backup_scene)
132
+
133
+
134
+ def __len__(self):
135
+ return len(self.objects)*self.total_view
136
+
137
+ def load_fixed_poses(self):
138
+ poses = {}
139
+ for face in self.view_types:
140
+ RT = np.loadtxt(os.path.join(self.fix_cam_pose_dir,'%03d_%s_RT.txt'%(0, face)))
141
+ poses[face] = RT
142
+
143
+ return poses
144
+
145
+ def cartesian_to_spherical(self, xyz):
146
+ ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
147
+ xy = xyz[:,0]**2 + xyz[:,1]**2
148
+ z = np.sqrt(xy + xyz[:,2]**2)
149
+ theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
150
+ #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
151
+ azimuth = np.arctan2(xyz[:,1], xyz[:,0])
152
+ return np.array([theta, azimuth, z])
153
+
154
+ def get_T(self, target_RT, cond_RT):
155
+ R, T = target_RT[:3, :3], target_RT[:, -1]
156
+ T_target = -R.T @ T # change to cam2world
157
+
158
+ R, T = cond_RT[:3, :3], cond_RT[:, -1]
159
+ T_cond = -R.T @ T
160
+
161
+ theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
162
+ theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
163
+
164
+ d_theta = theta_target - theta_cond
165
+ d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
166
+ d_z = z_target - z_cond
167
+
168
+ # d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
169
+ return d_theta, d_azimuth
170
+
171
+ def get_bg_color(self):
172
+ if self.bg_color == 'white':
173
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
174
+ elif self.bg_color == 'black':
175
+ bg_color = np.array([0., 0., 0.], dtype=np.float32)
176
+ elif self.bg_color == 'gray':
177
+ bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
178
+ elif self.bg_color == 'random':
179
+ bg_color = np.random.rand(3)
180
+ elif self.bg_color == 'three_choices':
181
+ white = np.array([1., 1., 1.], dtype=np.float32)
182
+ black = np.array([0., 0., 0.], dtype=np.float32)
183
+ gray = np.array([0.5, 0.5, 0.5], dtype=np.float32)
184
+ bg_color = random.choice([white, black, gray])
185
+ elif isinstance(self.bg_color, float):
186
+ bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
187
+ else:
188
+ raise NotImplementedError
189
+ return bg_color
190
+
191
+
192
+
193
+ def load_mask(self, img_path, return_type='np'):
194
+ # not using cv2 as may load in uint16 format
195
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
196
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
197
+ # pil always returns uint8
198
+ img = np.array(Image.open(img_path).resize(self.img_wh))
199
+ img = np.float32(img > 0)
200
+
201
+ assert len(np.shape(img)) == 2
202
+
203
+ if return_type == "np":
204
+ pass
205
+ elif return_type == "pt":
206
+ img = torch.from_numpy(img)
207
+ else:
208
+ raise NotImplementedError
209
+
210
+ return img
211
+
212
+ def load_mask_from_rgba(self, img_path, camera_type):
213
+ img = Image.open(img_path)
214
+
215
+ if camera_type == 'ortho':
216
+ left = (img.width - self.ortho_crop_size) // 2
217
+ right = (img.width + self.ortho_crop_size) // 2
218
+ top = (img.height - self.ortho_crop_size) // 2
219
+ bottom = (img.height + self.ortho_crop_size) // 2
220
+ img = img.crop((left, top, right, bottom))
221
+ if camera_type == 'persp':
222
+ left = (img.width - self.persp_crop_size) // 2
223
+ right = (img.width + self.persp_crop_size) // 2
224
+ top = (img.height - self.persp_crop_size) // 2
225
+ bottom = (img.height + self.persp_crop_size) // 2
226
+ img = img.crop((left, top, right, bottom))
227
+
228
+ img = img.resize(self.img_wh)
229
+ img = np.array(img).astype(np.float32) / 255. # [0, 1]
230
+ assert img.shape[-1] == 4 # must RGBA
231
+
232
+ alpha = img[:, :, 3:]
233
+
234
+ if alpha.shape[-1] != 1:
235
+ alpha = alpha[:, :, None]
236
+
237
+ return alpha
238
+
239
+ def load_image(self, img_path, bg_color, alpha, return_type='np', camera_type=None, read_depth=False, center_crop_size=None):
240
+ # not using cv2 as may load in uint16 format
241
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
242
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
243
+ # pil always returns uint8
244
+ img = Image.open(img_path)
245
+ if center_crop_size == None:
246
+ if camera_type == 'ortho':
247
+ left = (img.width - self.ortho_crop_size) // 2
248
+ right = (img.width + self.ortho_crop_size) // 2
249
+ top = (img.height - self.ortho_crop_size) // 2
250
+ bottom = (img.height + self.ortho_crop_size) // 2
251
+ img = img.crop((left, top, right, bottom))
252
+ if camera_type == 'persp':
253
+ left = (img.width - self.persp_crop_size) // 2
254
+ right = (img.width + self.persp_crop_size) // 2
255
+ top = (img.height - self.persp_crop_size) // 2
256
+ bottom = (img.height + self.persp_crop_size) // 2
257
+ img = img.crop((left, top, right, bottom))
258
+ else:
259
+ center_crop_size = min(center_crop_size, 512)
260
+ left = (img.width - center_crop_size) // 2
261
+ right = (img.width + center_crop_size) // 2
262
+ top = (img.height - center_crop_size) // 2
263
+ bottom = (img.height + center_crop_size) // 2
264
+ img = img.crop((left, top, right, bottom))
265
+
266
+ img = img.resize(self.img_wh)
267
+ img = np.array(img).astype(np.float32) / 255. # [0, 1]
268
+ assert img.shape[-1] == 3 or img.shape[-1] == 4 # RGB or RGBA
269
+
270
+ if alpha is None and img.shape[-1] == 4:
271
+ alpha = img[:, :, 3:]
272
+ img = img[:, :, :3]
273
+
274
+ if alpha.shape[-1] != 1:
275
+ alpha = alpha[:, :, None]
276
+
277
+ if read_depth:
278
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
279
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
280
+
281
+ if return_type == "np":
282
+ pass
283
+ elif return_type == "pt":
284
+ img = torch.from_numpy(img)
285
+ else:
286
+ raise NotImplementedError
287
+
288
+ return img
289
+
290
+ def load_depth(self, img_path, bg_color, alpha, return_type='np', camera_type=None):
291
+ # not using cv2 as may load in uint16 format
292
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
293
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
294
+ # pil always returns uint8
295
+ depth_bg_color = np.array([1., 1., 1.], dtype=np.float32) # white for depth
296
+ depth_map = Image.open(img_path)
297
+
298
+ if camera_type == 'ortho':
299
+ left = (depth_map.width - self.ortho_crop_size) // 2
300
+ right = (depth_map.width + self.ortho_crop_size) // 2
301
+ top = (depth_map.height - self.ortho_crop_size) // 2
302
+ bottom = (depth_map.height + self.ortho_crop_size) // 2
303
+ depth_map = depth_map.crop((left, top, right, bottom))
304
+ if camera_type == 'persp':
305
+ left = (depth_map.width - self.persp_crop_size) // 2
306
+ right = (depth_map.width + self.persp_crop_size) // 2
307
+ top = (depth_map.height - self.persp_crop_size) // 2
308
+ bottom = (depth_map.height + self.persp_crop_size) // 2
309
+ depth_map = depth_map.crop((left, top, right, bottom))
310
+
311
+ depth_map = depth_map.resize(self.img_wh)
312
+ depth_map = np.array(depth_map)
313
+
314
+ # scale the depth map:
315
+ depth_map = scale_depth_to_model(depth_map.astype(np.float32))
316
+ # depth_map = depth_map / 65535. # [0, 1]
317
+ # depth_map[depth_map > 0.4] = 0
318
+ # depth_map = depth_map / 0.4
319
+
320
+ assert depth_map.ndim == 2 # depth
321
+ img = np.stack([depth_map]*3, axis=-1)
322
+
323
+ if alpha.shape[-1] != 1:
324
+ alpha = alpha[:, :, None]
325
+
326
+
327
+ # print(np.max(img[:, :, 0]))
328
+ # print(np.min(img[...,:3]), np.max(img[...,:3]))
329
+ img = img[...,:3] * alpha + depth_bg_color * (1 - alpha)
330
+
331
+ if return_type == "np":
332
+ pass
333
+ elif return_type == "pt":
334
+ img = torch.from_numpy(img)
335
+ else:
336
+ raise NotImplementedError
337
+
338
+ return img
339
+
340
+ def transform_mask_as_input(self, mask, return_type='np'):
341
+
342
+ # mask = mask * 255
343
+ # print(np.max(mask))
344
+
345
+ # mask = mask.resize(self.img_wh)
346
+ mask = np.squeeze(mask, axis=-1)
347
+ assert mask.ndim == 2 #
348
+ mask = np.stack([mask]*3, axis=-1)
349
+ if return_type == "np":
350
+ pass
351
+ elif return_type == "pt":
352
+ mask = torch.from_numpy(mask)
353
+ else:
354
+ raise NotImplementedError
355
+ return mask
356
+
357
+
358
+
359
+ def load_normal(self, img_path, bg_color, alpha, RT_w2c=None, RT_w2c_cond=None, return_type='np', camera_type=None, center_crop_size=None):
360
+ # not using cv2 as may load in uint16 format
361
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
362
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
363
+ # pil always returns uint8
364
+ # normal = Image.open(img_path)
365
+
366
+ img = Image.open(img_path)
367
+ if center_crop_size == None:
368
+ if camera_type == 'ortho':
369
+ left = (img.width - self.ortho_crop_size) // 2
370
+ right = (img.width + self.ortho_crop_size) // 2
371
+ top = (img.height - self.ortho_crop_size) // 2
372
+ bottom = (img.height + self.ortho_crop_size) // 2
373
+ img = img.crop((left, top, right, bottom))
374
+ if camera_type == 'persp':
375
+ left = (img.width - self.persp_crop_size) // 2
376
+ right = (img.width + self.persp_crop_size) // 2
377
+ top = (img.height - self.persp_crop_size) // 2
378
+ bottom = (img.height + self.persp_crop_size) // 2
379
+ img = img.crop((left, top, right, bottom))
380
+ else:
381
+ center_crop_size = min(center_crop_size, 512)
382
+ left = (img.width - center_crop_size) // 2
383
+ right = (img.width + center_crop_size) // 2
384
+ top = (img.height - center_crop_size) // 2
385
+ bottom = (img.height + center_crop_size) // 2
386
+ img = img.crop((left, top, right, bottom))
387
+
388
+ normal = np.array(img.resize(self.img_wh))
389
+
390
+ assert normal.shape[-1] == 3 or normal.shape[-1] == 4 # RGB or RGBA
391
+
392
+ if alpha is None and normal.shape[-1] == 4:
393
+ alpha = normal[:, :, 3:] / 255.
394
+ normal = normal[:, :, :3]
395
+
396
+ normal = trans_normal(img2normal(normal), RT_w2c, RT_w2c_cond)
397
+
398
+ img = (normal*0.5 + 0.5).astype(np.float32) # [0, 1]
399
+
400
+ if alpha.shape[-1] != 1:
401
+ alpha = alpha[:, :, None]
402
+
403
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
404
+
405
+ if return_type == "np":
406
+ pass
407
+ elif return_type == "pt":
408
+ img = torch.from_numpy(img)
409
+ else:
410
+ raise NotImplementedError
411
+
412
+ return img
413
+
414
+ def __len__(self):
415
+ return len(self.all_objects)
416
+
417
+ def __getitem_color__(self, index, debug_object=None):
418
+ if debug_object is not None:
419
+ object_name = debug_object #
420
+ set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
421
+ else:
422
+ object_name = self.all_objects[index % len(self.all_objects)]
423
+ set_idx = 0
424
+
425
+ if self.augment_data:
426
+ cond_view = random.sample(self.view_types, k=1)[0]
427
+ else:
428
+ cond_view = 'front'
429
+
430
+ assert self.pred_ortho or self.pred_persp
431
+ if self.pred_ortho and self.pred_persp:
432
+ if random.random() < 0.5:
433
+ load_dir = self.root_dir_ortho
434
+ load_cam_type = 'ortho'
435
+ else:
436
+ load_dir = self.root_dir_persp
437
+ load_cam_type = 'persp'
438
+ elif self.pred_ortho and not self.pred_persp:
439
+ load_dir = self.root_dir_ortho
440
+ load_cam_type = 'ortho'
441
+ elif self.pred_persp and not self.pred_ortho:
442
+ load_dir = self.root_dir_persp
443
+ load_cam_type = 'persp'
444
+
445
+ # ! if you would like predict depth; modify here
446
+
447
+ read_color, read_normal, read_depth = True, False, False
448
+
449
+
450
+ assert (read_color and (read_normal or read_depth)) is False
451
+
452
+ view_types = self.view_types
453
+
454
+ cond_w2c = self.fix_cam_poses[cond_view]
455
+
456
+ tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
457
+
458
+ elevations = []
459
+ azimuths = []
460
+
461
+ # get the bg color
462
+ bg_color = self.get_bg_color()
463
+
464
+ if self.read_mask:
465
+ cond_alpha = self.load_mask(os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
466
+ "mask_%03d_%s.%s" % (set_idx, cond_view, self.suffix)),
467
+ return_type='np')
468
+ else:
469
+ cond_alpha = None
470
+ img_tensors_in = [
471
+ self.load_image(os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
472
+ "rgb_%03d_%s.%s" % (set_idx, cond_view, self.suffix)),
473
+ bg_color, cond_alpha, return_type='pt', camera_type=load_cam_type).permute(2, 0, 1)
474
+ ] * self.num_views
475
+ img_tensors_out = []
476
+
477
+ for view, tgt_w2c in zip(view_types, tgt_w2cs):
478
+ img_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
479
+ "rgb_%03d_%s.%s" % (set_idx, view, self.suffix))
480
+ mask_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
481
+ "mask_%03d_%s.%s" % (set_idx, view, self.suffix))
482
+ normal_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
483
+ "normals_%03d_%s.%s" % (set_idx, view, self.suffix))
484
+ depth_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
485
+ "depth_%03d_%s.%s" % (set_idx, view, self.suffix))
486
+ if self.read_mask:
487
+ alpha = self.load_mask(mask_path, return_type='np')
488
+ else:
489
+ alpha = None
490
+
491
+ if read_color:
492
+ img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type)
493
+ img_tensor = img_tensor.permute(2, 0, 1)
494
+ img_tensors_out.append(img_tensor)
495
+
496
+ if read_normal:
497
+ normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c,
498
+ return_type="pt", camera_type=load_cam_type).permute(2, 0, 1)
499
+ img_tensors_out.append(normal_tensor)
500
+ if read_depth:
501
+ depth_tensor = self.load_depth(depth_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type).permute(2, 0, 1)
502
+ img_tensors_out.append(depth_tensor)
503
+
504
+ # evelations, azimuths
505
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
506
+ elevations.append(elevation)
507
+ azimuths.append(azimuth)
508
+
509
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
510
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
511
+
512
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
513
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
514
+ elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
515
+
516
+ if load_cam_type == 'ortho':
517
+ cam_type_emb = torch.tensor([0, 1]).expand(self.num_views, -1)
518
+ else:
519
+ cam_type_emb = torch.tensor([1, 0]).expand(self.num_views, -1)
520
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1)
521
+ # if self.pred_ortho and self.pred_persp:
522
+ if self.load_cam_type:
523
+ camera_embeddings = torch.cat((camera_embeddings, cam_type_emb), dim=-1) # (Nv, 5)
524
+
525
+ normal_class = torch.tensor([1, 0]).float()
526
+ normal_task_embeddings = torch.stack([normal_class] * self.num_views, dim=0) # (Nv, 2)
527
+ color_class = torch.tensor([0, 1]).float()
528
+ color_task_embeddings = torch.stack([color_class] * self.num_views, dim=0) # (Nv, 2)
529
+ if read_normal or read_depth:
530
+ task_embeddings = normal_task_embeddings
531
+ if read_color:
532
+ task_embeddings = color_task_embeddings
533
+ # print(elevations)
534
+ # print(azimuths)
535
+ return {
536
+ 'elevations_cond': elevations_cond,
537
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
538
+ 'elevations': elevations,
539
+ 'azimuths': azimuths,
540
+ 'elevations_deg': torch.rad2deg(elevations),
541
+ 'azimuths_deg': torch.rad2deg(azimuths),
542
+ 'imgs_in': img_tensors_in,
543
+ 'imgs_out': img_tensors_out,
544
+ 'camera_embeddings': camera_embeddings,
545
+ 'task_embeddings': task_embeddings
546
+ }
547
+
548
+ def __getitem_normal_depth__(self, index, debug_object=None):
549
+ if debug_object is not None:
550
+ object_name = debug_object #
551
+ set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
552
+ else:
553
+ object_name = self.all_objects[index%len(self.all_objects)]
554
+ set_idx = 0
555
+
556
+ if self.augment_data:
557
+ cond_view = random.sample(self.view_types, k=1)[0]
558
+ else:
559
+ cond_view = 'front'
560
+
561
+ assert self.pred_ortho or self.pred_persp
562
+ if self.pred_ortho and self.pred_persp:
563
+ if random.random() < 0.5:
564
+ load_dir = self.root_dir_ortho
565
+ load_cam_type = 'ortho'
566
+ else:
567
+ load_dir = self.root_dir_persp
568
+ load_cam_type = 'persp'
569
+ elif self.pred_ortho and not self.pred_persp:
570
+ load_dir = self.root_dir_ortho
571
+ load_cam_type = 'ortho'
572
+ elif self.pred_persp and not self.pred_ortho:
573
+ load_dir = self.root_dir_persp
574
+ load_cam_type = 'persp'
575
+
576
+ view_types = self.view_types
577
+
578
+ cond_w2c = self.fix_cam_poses[cond_view]
579
+
580
+ tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
581
+
582
+ elevations = []
583
+ azimuths = []
584
+
585
+ # get the bg color
586
+ bg_color = self.get_bg_color()
587
+
588
+ if self.read_mask:
589
+ cond_alpha = self.load_mask(os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, cond_view, self.suffix)), return_type='np')
590
+ else:
591
+ cond_alpha = None
592
+ # img_tensors_in = [
593
+ # self.load_image(os.path.join(self.root_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, cond_view, self.suffix)), bg_color, cond_alpha, return_type='pt').permute(2, 0, 1)
594
+ # ] * self.num_views
595
+ img_tensors_out = []
596
+ normal_tensors_out = []
597
+ depth_tensors_out = []
598
+ for view, tgt_w2c in zip(view_types, tgt_w2cs):
599
+ img_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, view, self.suffix))
600
+ mask_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, view, self.suffix))
601
+ depth_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "depth_%03d_%s.%s" % (set_idx, view, self.suffix))
602
+
603
+ if self.read_mask:
604
+ alpha = self.load_mask(mask_path, return_type='np')
605
+ else:
606
+ alpha = None
607
+
608
+ if self.read_color:
609
+ img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type)
610
+ img_tensor = img_tensor.permute(2, 0, 1)
611
+ img_tensors_out.append(img_tensor)
612
+
613
+ if self.read_normal:
614
+ normal_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "normals_%03d_%s.%s" % (set_idx, view, self.suffix))
615
+ normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt", camera_type=load_cam_type).permute(2, 0, 1)
616
+ normal_tensors_out.append(normal_tensor)
617
+
618
+ if self.read_depth:
619
+ if alpha is None:
620
+ alpha = self.load_mask_from_rgba(img_path, camera_type=load_cam_type)
621
+ depth_tensor = self.load_depth(depth_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type).permute(2, 0, 1)
622
+ depth_tensors_out.append(depth_tensor)
623
+
624
+
625
+ # evelations, azimuths
626
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
627
+ elevations.append(elevation)
628
+ azimuths.append(azimuth)
629
+
630
+ img_tensors_in = img_tensors_out
631
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
632
+ if self.read_color:
633
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
634
+ if self.read_normal:
635
+ normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float() # (Nv, 3, H, W)
636
+ if self.read_depth:
637
+ depth_tensors_out = torch.stack(depth_tensors_out, dim=0).float() # (Nv, 3, H, W)
638
+
639
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
640
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
641
+ elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
642
+
643
+ if load_cam_type == 'ortho':
644
+ cam_type_emb = torch.tensor([0, 1]).expand(self.num_views, -1)
645
+ else:
646
+ cam_type_emb = torch.tensor([1, 0]).expand(self.num_views, -1)
647
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1)
648
+ # if self.pred_ortho and self.pred_persp:
649
+ if self.load_cam_type:
650
+ camera_embeddings = torch.cat((camera_embeddings, cam_type_emb), dim=-1) # (Nv, 5)
651
+
652
+ normal_class = torch.tensor([1, 0]).float()
653
+ normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2)
654
+ color_class = torch.tensor([0, 1]).float()
655
+ depth_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2)
656
+
657
+ return {
658
+ 'elevations_cond': elevations_cond,
659
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
660
+ 'elevations': elevations,
661
+ 'azimuths': azimuths,
662
+ 'elevations_deg': torch.rad2deg(elevations),
663
+ 'azimuths_deg': torch.rad2deg(azimuths),
664
+ 'imgs_in': img_tensors_in,
665
+ 'imgs_out': img_tensors_out,
666
+ 'normals_out': normal_tensors_out,
667
+ 'depth_out': depth_tensors_out,
668
+ 'camera_embeddings': camera_embeddings,
669
+ 'normal_task_embeddings': normal_task_embeddings,
670
+ 'depth_task_embeddings': depth_task_embeddings
671
+ }
672
+
673
+ def __getitem_mixed_rgb_noraml_mask__(self, index, debug_object=None):
674
+ if debug_object is not None:
675
+ object_name = debug_object #
676
+ set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
677
+ else:
678
+ object_name = self.all_objects[index%len(self.all_objects)]
679
+ set_idx = 0
680
+
681
+ if self.augment_data:
682
+ cond_view = random.sample(self.view_types, k=1)[0]
683
+ else:
684
+ cond_view = 'front'
685
+
686
+ assert self.pred_ortho or self.pred_persp
687
+ if self.pred_ortho and self.pred_persp:
688
+ if random.random() < 0.5:
689
+ load_dir = self.root_dir_ortho
690
+ load_cam_type = 'ortho'
691
+ else:
692
+ load_dir = self.root_dir_persp
693
+ load_cam_type = 'persp'
694
+ elif self.pred_ortho and not self.pred_persp:
695
+ load_dir = self.root_dir_ortho
696
+ load_cam_type = 'ortho'
697
+ elif self.pred_persp and not self.pred_ortho:
698
+ load_dir = self.root_dir_persp
699
+ load_cam_type = 'persp'
700
+
701
+ view_types = self.view_types
702
+
703
+ cond_w2c = self.fix_cam_poses[cond_view]
704
+
705
+ tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
706
+
707
+ elevations = []
708
+ azimuths = []
709
+
710
+ # get the bg color
711
+ bg_color = self.get_bg_color()
712
+
713
+ if self.read_mask:
714
+ cond_alpha = self.load_mask(os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, cond_view, self.suffix)), return_type='np')
715
+ else:
716
+ cond_alpha = None
717
+
718
+ img_tensors_out = []
719
+ normal_tensors_out = []
720
+ depth_tensors_out = []
721
+
722
+ random_select = random.random()
723
+ read_color, read_normal, read_mask = [random_select < 1 / 3, 1 / 3 <= random_select <= 2 / 3,
724
+ random_select > 2 / 3]
725
+ # print(read_color, read_normal, read_depth)
726
+
727
+ assert sum([read_color, read_normal, read_mask]) == 1, "Only one variable should be True"
728
+
729
+ for view, tgt_w2c in zip(view_types, tgt_w2cs):
730
+ img_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, view, self.suffix))
731
+ mask_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, view, self.suffix))
732
+ depth_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "depth_%03d_%s.%s" % (set_idx, view, self.suffix))
733
+
734
+ if self.read_mask:
735
+ alpha = self.load_mask(mask_path, return_type='np')
736
+ else:
737
+ alpha = None
738
+
739
+ if read_color:
740
+ img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type, read_depth=False)
741
+ img_tensor = img_tensor.permute(2, 0, 1)
742
+ img_tensors_out.append(img_tensor)
743
+
744
+ if read_normal:
745
+ normal_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "normals_%03d_%s.%s" % (set_idx, view, self.suffix))
746
+ normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt", camera_type=load_cam_type).permute(2, 0, 1)
747
+ img_tensors_out.append(normal_tensor)
748
+
749
+ if read_mask:
750
+ if alpha is None:
751
+ alpha = self.load_mask_from_rgba(img_path, camera_type=load_cam_type)
752
+ mask_tensor = self.transform_mask_as_input(alpha, return_type='pt').permute(2, 0, 1)
753
+ img_tensors_out.append(mask_tensor)
754
+
755
+ # evelations, azimuths
756
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
757
+ elevations.append(elevation)
758
+ azimuths.append(azimuth)
759
+
760
+ if self.load_switcher: # rgb input, use domain switcher to control the output type
761
+ img_tensors_in = [
762
+ self.load_image(os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
763
+ "normals_%03d_%s.%s" % (set_idx, cond_view, self.suffix)),
764
+ bg_color, cond_alpha, RT_w2c=cond_w2c, RT_w2c_cond=cond_w2c, return_type='pt', camera_type=load_cam_type).permute(
765
+ 2, 0, 1)
766
+ ] * self.num_views
767
+ color_class = torch.tensor([0, 1]).float()
768
+ color_task_embeddings = torch.stack([color_class] * self.num_views, dim=0) # (Nv, 2)
769
+
770
+ normal_class = torch.tensor([1, 0]).float()
771
+ normal_task_embeddings = torch.stack([normal_class] * self.num_views, dim=0) # (Nv, 2)
772
+
773
+ mask_class = torch.tensor([1, 1]).float()
774
+ mask_task_embeddings = torch.stack([mask_class] * self.num_views, dim=0)
775
+
776
+ if read_color:
777
+ task_embeddings = color_task_embeddings
778
+ # img_tensors_out = depth_tensors_out
779
+ elif read_normal:
780
+ task_embeddings = normal_task_embeddings
781
+ # img_tensors_out = normal_tensors_out
782
+ elif read_mask:
783
+ task_embeddings = mask_task_embeddings
784
+ # img_tensors_out = depth_tensors_out
785
+
786
+ else: # for stage 1 training, the input and the output are in the same domain
787
+ img_tensors_in = [img_tensors_out[0]] * self.num_views
788
+
789
+ empty_class = torch.tensor([0, 0]).float() # empty task
790
+ empty_task_embeddings = torch.stack([empty_class] * self.num_views, dim=0)
791
+ task_embeddings = empty_task_embeddings
792
+
793
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
794
+
795
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
796
+
797
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
798
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
799
+ elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
800
+
801
+ if load_cam_type == 'ortho':
802
+ cam_type_emb = torch.tensor([0, 1]).expand(self.num_views, -1)
803
+ else:
804
+ cam_type_emb = torch.tensor([1, 0]).expand(self.num_views, -1)
805
+
806
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1)
807
+
808
+ if self.load_cam_type:
809
+ camera_embeddings = torch.cat((camera_embeddings, cam_type_emb), dim=-1) # (Nv, 5)
810
+
811
+ return {
812
+ 'elevations_cond': elevations_cond,
813
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
814
+ 'elevations': elevations,
815
+ 'azimuths': azimuths,
816
+ 'elevations_deg': torch.rad2deg(elevations),
817
+ 'azimuths_deg': torch.rad2deg(azimuths),
818
+ 'imgs_in': img_tensors_in,
819
+ 'imgs_out': img_tensors_out,
820
+ 'normals_out': normal_tensors_out,
821
+ 'depth_out': depth_tensors_out,
822
+ 'camera_embeddings': camera_embeddings,
823
+ 'task_embeddings': task_embeddings,
824
+ }
825
+
826
+ def __getitem_mixed__(self, index, debug_object=None):
827
+ if debug_object is not None:
828
+ object_name = debug_object #
829
+ set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
830
+ else:
831
+ object_name = self.all_objects[index%len(self.all_objects)]
832
+ set_idx = 0
833
+
834
+ if self.augment_data:
835
+ cond_view = random.sample(self.view_types, k=1)[0]
836
+ else:
837
+ cond_view = 'front'
838
+
839
+ assert self.pred_ortho or self.pred_persp
840
+ if self.pred_ortho and self.pred_persp:
841
+ if random.random() < 0.5:
842
+ load_dir = self.root_dir_ortho
843
+ load_cam_type = 'ortho'
844
+ else:
845
+ load_dir = self.root_dir_persp
846
+ load_cam_type = 'persp'
847
+ elif self.pred_ortho and not self.pred_persp:
848
+ load_dir = self.root_dir_ortho
849
+ load_cam_type = 'ortho'
850
+ elif self.pred_persp and not self.pred_ortho:
851
+ load_dir = self.root_dir_persp
852
+ load_cam_type = 'persp'
853
+
854
+ view_types = self.view_types
855
+
856
+ cond_w2c = self.fix_cam_poses[cond_view]
857
+
858
+ tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
859
+
860
+ elevations = []
861
+ azimuths = []
862
+
863
+ # get the bg color
864
+ bg_color = self.get_bg_color()
865
+
866
+ if self.read_mask:
867
+ cond_alpha = self.load_mask(os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, cond_view, self.suffix)), return_type='np')
868
+ else:
869
+ cond_alpha = None
870
+ # img_tensors_in = [
871
+ # self.load_image(os.path.join(self.root_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, cond_view, self.suffix)), bg_color, cond_alpha, return_type='pt').permute(2, 0, 1)
872
+ # ] * self.num_views
873
+ img_tensors_out = []
874
+ normal_tensors_out = []
875
+ depth_tensors_out = []
876
+
877
+ random_select = random.random()
878
+ read_color, read_normal, read_depth = [random_select < 1 / 3, 1 / 3 <= random_select <= 2 / 3,
879
+ random_select > 2 / 3]
880
+ # print(read_color, read_normal, read_depth)
881
+
882
+ assert sum([read_color, read_normal, read_depth]) == 1, "Only one variable should be True"
883
+
884
+ for view, tgt_w2c in zip(view_types, tgt_w2cs):
885
+ img_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, view, self.suffix))
886
+ mask_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, view, self.suffix))
887
+ depth_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "depth_%03d_%s.%s" % (set_idx, view, self.suffix))
888
+
889
+ if self.read_mask:
890
+ alpha = self.load_mask(mask_path, return_type='np')
891
+ else:
892
+ alpha = None
893
+
894
+ if read_color:
895
+ img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type, read_depth=read_depth)
896
+ img_tensor = img_tensor.permute(2, 0, 1)
897
+ img_tensors_out.append(img_tensor)
898
+
899
+ if read_normal:
900
+ normal_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "normals_%03d_%s.%s" % (set_idx, view, self.suffix))
901
+ normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt", camera_type=load_cam_type).permute(2, 0, 1)
902
+ img_tensors_out.append(normal_tensor)
903
+
904
+ if read_depth:
905
+ if alpha is None:
906
+ alpha = self.load_mask_from_rgba(img_path, camera_type=load_cam_type)
907
+ depth_tensor = self.load_depth(depth_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type).permute(2, 0, 1)
908
+ img_tensors_out.append(depth_tensor)
909
+
910
+
911
+ # evelations, azimuths
912
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
913
+ elevations.append(elevation)
914
+ azimuths.append(azimuth)
915
+
916
+ img_tensors_in = [
917
+ self.load_image(os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
918
+ "rgb_%03d_%s.%s" % (set_idx, cond_view, self.suffix)),
919
+ bg_color, cond_alpha, return_type='pt', camera_type=load_cam_type, read_depth=read_depth).permute(
920
+ 2, 0, 1)
921
+ ] * self.num_views
922
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
923
+ # if self.read_color:
924
+ # img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
925
+ # if self.read_normal:
926
+ # normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float() # (Nv, 3, H, W)
927
+ # if self.read_depth:
928
+ # depth_tensors_out = torch.stack(depth_tensors_out, dim=0).float() # (Nv, 3, H, W)
929
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
930
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
931
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
932
+ elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
933
+
934
+ if load_cam_type == 'ortho':
935
+ cam_type_emb = torch.tensor([0, 1]).expand(self.num_views, -1)
936
+ else:
937
+ cam_type_emb = torch.tensor([1, 0]).expand(self.num_views, -1)
938
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1)
939
+ # if self.pred_ortho and self.pred_persp:
940
+ if self.load_cam_type:
941
+ camera_embeddings = torch.cat((camera_embeddings, cam_type_emb), dim=-1) # (Nv, 5)
942
+
943
+ color_class = torch.tensor([0, 1]).float()
944
+ color_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2)
945
+
946
+ normal_class = torch.tensor([1, 0]).float()
947
+ normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2)
948
+
949
+ depth_class = torch.tensor([1, 1]).float()
950
+ depth_task_embeddings = torch.stack([depth_class]*self.num_views, dim=0)
951
+
952
+ if read_color:
953
+ task_embeddings = color_task_embeddings
954
+ # img_tensors_out = depth_tensors_out
955
+ elif read_normal:
956
+ task_embeddings = normal_task_embeddings
957
+ # img_tensors_out = normal_tensors_out
958
+ elif read_depth:
959
+ task_embeddings = depth_task_embeddings
960
+ # img_tensors_out = depth_tensors_out
961
+
962
+ return {
963
+ 'elevations_cond': elevations_cond,
964
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
965
+ 'elevations': elevations,
966
+ 'azimuths': azimuths,
967
+ 'elevations_deg': torch.rad2deg(elevations),
968
+ 'azimuths_deg': torch.rad2deg(azimuths),
969
+ 'imgs_in': img_tensors_in,
970
+ 'imgs_out': img_tensors_out,
971
+ 'normals_out': normal_tensors_out,
972
+ 'depth_out': depth_tensors_out,
973
+ 'camera_embeddings': camera_embeddings,
974
+ 'task_embeddings': task_embeddings,
975
+ }
976
+
977
+ def __getitem_image_normal_mixed__(self, index, debug_object=None):
978
+ if debug_object is not None:
979
+ object_name = debug_object #
980
+ set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
981
+ else:
982
+ object_name = self.all_objects[index%len(self.all_objects)]
983
+ set_idx = 0
984
+
985
+ if self.augment_data:
986
+ cond_view = random.sample(self.view_types, k=1)[0]
987
+ else:
988
+ cond_view = 'front'
989
+
990
+ assert self.pred_ortho or self.pred_persp
991
+ if self.pred_ortho and self.pred_persp:
992
+ if random.random() < 0.5:
993
+ load_dir = self.root_dir_ortho
994
+ load_cam_type = 'ortho'
995
+ else:
996
+ load_dir = self.root_dir_persp
997
+ load_cam_type = 'persp'
998
+ elif self.pred_ortho and not self.pred_persp:
999
+ load_dir = self.root_dir_ortho
1000
+ load_cam_type = 'ortho'
1001
+ elif self.pred_persp and not self.pred_ortho:
1002
+ load_dir = self.root_dir_persp
1003
+ load_cam_type = 'persp'
1004
+
1005
+ view_types = self.view_types
1006
+
1007
+ cond_w2c = self.fix_cam_poses[cond_view]
1008
+
1009
+ tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
1010
+
1011
+ elevations = []
1012
+ azimuths = []
1013
+
1014
+ # get the bg color
1015
+ bg_color = self.get_bg_color()
1016
+
1017
+ # get crop size for each mv instance:
1018
+ center_crop_size = 0
1019
+ for view in view_types:
1020
+ img_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, view, self.suffix))
1021
+
1022
+ img = Image.open(img_path)
1023
+ img = img.resize([512,512])
1024
+ img = np.array(img).astype(np.float32) / 255. # [0, 1]
1025
+
1026
+ max_w_h = self.cal_single_view_crop(img)
1027
+ center_crop_size = max(center_crop_size, max_w_h)
1028
+
1029
+ center_crop_size = center_crop_size * 4. / 3.
1030
+ center_crop_size = center_crop_size + (random.random()-0.5) * 10.
1031
+
1032
+ if self.read_mask:
1033
+ cond_alpha = self.load_mask(os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, cond_view, self.suffix)), return_type='np')
1034
+ else:
1035
+ cond_alpha = None
1036
+ # img_tensors_in = [
1037
+ # self.load_image(os.path.join(self.root_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, cond_view, self.suffix)), bg_color, cond_alpha, return_type='pt').permute(2, 0, 1)
1038
+ # ] * self.num_views
1039
+ img_tensors_out = []
1040
+ normal_tensors_out = []
1041
+ depth_tensors_out = []
1042
+
1043
+ random_select = random.random()
1044
+ read_color, read_normal = [random_select < 1 / 2, 1 / 2 <= random_select <= 1]
1045
+ # print(read_color, read_normal, read_depth)
1046
+
1047
+ assert sum([read_color, read_normal]) == 1, "Only one variable should be True"
1048
+
1049
+ for view, tgt_w2c in zip(view_types, tgt_w2cs):
1050
+ img_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, view, self.suffix))
1051
+ mask_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, view, self.suffix))
1052
+ depth_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "depth_%03d_%s.%s" % (set_idx, view, self.suffix))
1053
+
1054
+ if self.read_mask:
1055
+ alpha = self.load_mask(mask_path, return_type='np')
1056
+ else:
1057
+ alpha = None
1058
+
1059
+ if read_color:
1060
+ img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type, read_depth=False, center_crop_size=center_crop_size)
1061
+ img_tensor = img_tensor.permute(2, 0, 1)
1062
+ img_tensors_out.append(img_tensor)
1063
+
1064
+ if read_normal:
1065
+ normal_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "normals_%03d_%s.%s" % (set_idx, view, self.suffix))
1066
+ normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt", camera_type=load_cam_type, center_crop_size=center_crop_size).permute(2, 0, 1)
1067
+ img_tensors_out.append(normal_tensor)
1068
+
1069
+ # if read_depth:
1070
+ # if alpha is None:
1071
+ # alpha = self.load_mask_from_rgba(img_path, camera_type=load_cam_type)
1072
+ # depth_tensor = self.load_depth(depth_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type).permute(2, 0, 1)
1073
+ # img_tensors_out.append(depth_tensor)
1074
+
1075
+
1076
+ # evelations, azimuths
1077
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
1078
+ elevations.append(elevation)
1079
+ azimuths.append(azimuth)
1080
+
1081
+ img_tensors_in = [
1082
+ self.load_image(os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
1083
+ "rgb_%03d_%s.%s" % (set_idx, cond_view, self.suffix)),
1084
+ bg_color, cond_alpha, return_type='pt', camera_type=load_cam_type, read_depth=False, center_crop_size=center_crop_size).permute(
1085
+ 2, 0, 1)
1086
+ ] * self.num_views
1087
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
1088
+ # if self.read_color:
1089
+ # img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
1090
+ # if self.read_normal:
1091
+ # normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float() # (Nv, 3, H, W)
1092
+ # if self.read_depth:
1093
+ # depth_tensors_out = torch.stack(depth_tensors_out, dim=0).float() # (Nv, 3, H, W)
1094
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
1095
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
1096
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
1097
+ elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
1098
+
1099
+ if load_cam_type == 'ortho':
1100
+ cam_type_emb = torch.tensor([0, 1]).expand(self.num_views, -1)
1101
+ else:
1102
+ cam_type_emb = torch.tensor([1, 0]).expand(self.num_views, -1)
1103
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1)
1104
+ # if self.pred_ortho and self.pred_persp:
1105
+ if self.load_cam_type:
1106
+ camera_embeddings = torch.cat((camera_embeddings, cam_type_emb), dim=-1) # (Nv, 5)
1107
+
1108
+ color_class = torch.tensor([0, 1]).float()
1109
+ color_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2)
1110
+
1111
+ normal_class = torch.tensor([1, 0]).float()
1112
+ normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2)
1113
+
1114
+ # depth_class = torch.tensor([1, 1]).float()
1115
+ # depth_task_embeddings = torch.stack([depth_class]*self.num_views, dim=0)
1116
+
1117
+ if read_color:
1118
+ task_embeddings = color_task_embeddings
1119
+ # img_tensors_out = depth_tensors_out
1120
+ elif read_normal:
1121
+ task_embeddings = normal_task_embeddings
1122
+ # img_tensors_out = normal_tensors_out
1123
+ # elif read_depth:
1124
+ # task_embeddings = depth_task_embeddings
1125
+ # img_tensors_out = depth_tensors_out
1126
+
1127
+ return {
1128
+ 'elevations_cond': elevations_cond,
1129
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
1130
+ 'elevations': elevations,
1131
+ 'azimuths': azimuths,
1132
+ 'elevations_deg': torch.rad2deg(elevations),
1133
+ 'azimuths_deg': torch.rad2deg(azimuths),
1134
+ 'imgs_in': img_tensors_in,
1135
+ 'imgs_out': img_tensors_out,
1136
+ 'normals_out': normal_tensors_out,
1137
+ 'depth_out': depth_tensors_out,
1138
+ 'camera_embeddings': camera_embeddings,
1139
+ 'task_embeddings': task_embeddings,
1140
+ }
1141
+
1142
+ def cal_single_view_crop(self, image):
1143
+ assert np.shape(image)[-1] == 4 # RGBA
1144
+
1145
+ # Extract the alpha channel (transparency) and the object (RGB channels)
1146
+ alpha_channel = image[:, :, 3]
1147
+
1148
+ # Find the bounding box coordinates of the object
1149
+ coords = cv2.findNonZero(alpha_channel)
1150
+ x, y, width, height = cv2.boundingRect(coords)
1151
+
1152
+ return max(width, height)
1153
+
1154
+ def __getitem_joint_rgb_noraml__(self, index, debug_object=None):
1155
+ if debug_object is not None:
1156
+ object_name = debug_object #
1157
+ set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
1158
+ else:
1159
+ object_name = self.all_objects[index%len(self.all_objects)]
1160
+ set_idx = 0
1161
+
1162
+ if self.augment_data:
1163
+ cond_view = random.sample(self.view_types, k=1)[0]
1164
+ else:
1165
+ cond_view = 'front'
1166
+
1167
+ assert self.pred_ortho or self.pred_persp
1168
+ if self.pred_ortho and self.pred_persp:
1169
+ if random.random() < 0.5:
1170
+ load_dir = self.root_dir_ortho
1171
+ load_cam_type = 'ortho'
1172
+ else:
1173
+ load_dir = self.root_dir_persp
1174
+ load_cam_type = 'persp'
1175
+ elif self.pred_ortho and not self.pred_persp:
1176
+ load_dir = self.root_dir_ortho
1177
+ load_cam_type = 'ortho'
1178
+ elif self.pred_persp and not self.pred_ortho:
1179
+ load_dir = self.root_dir_persp
1180
+ load_cam_type = 'persp'
1181
+
1182
+ view_types = self.view_types
1183
+
1184
+ cond_w2c = self.fix_cam_poses[cond_view]
1185
+
1186
+ tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
1187
+
1188
+ elevations = []
1189
+ azimuths = []
1190
+
1191
+ # get the bg color
1192
+ bg_color = self.get_bg_color()
1193
+
1194
+ if self.read_mask:
1195
+ cond_alpha = self.load_mask(os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, cond_view, self.suffix)), return_type='np')
1196
+ else:
1197
+ cond_alpha = None
1198
+
1199
+ img_tensors_out = []
1200
+ normal_tensors_out = []
1201
+
1202
+
1203
+ read_color, read_normal = True, True
1204
+ # print(read_color, read_normal, read_depth)
1205
+
1206
+ # get crop size for each mv instance:
1207
+ center_crop_size = 0
1208
+ for view in view_types:
1209
+ img_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, view, self.suffix))
1210
+
1211
+ img = Image.open(img_path)
1212
+ img = img.resize([512,512])
1213
+ img = np.array(img).astype(np.float32) / 255. # [0, 1]
1214
+
1215
+ max_w_h = self.cal_single_view_crop(img)
1216
+ center_crop_size = max(center_crop_size, max_w_h)
1217
+
1218
+ center_crop_size = center_crop_size * 4. / 3.
1219
+ center_crop_size = center_crop_size + (random.random()-0.5) * 10.
1220
+
1221
+
1222
+
1223
+ for view, tgt_w2c in zip(view_types, tgt_w2cs):
1224
+ img_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, view, self.suffix))
1225
+ mask_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, view, self.suffix))
1226
+ depth_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "depth_%03d_%s.%s" % (set_idx, view, self.suffix))
1227
+
1228
+ if self.read_mask:
1229
+ alpha = self.load_mask(mask_path, return_type='np')
1230
+ else:
1231
+ alpha = None
1232
+
1233
+ if read_color:
1234
+ img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type, read_depth=False, center_crop_size=center_crop_size)
1235
+ img_tensor = img_tensor.permute(2, 0, 1)
1236
+ img_tensors_out.append(img_tensor)
1237
+
1238
+ if read_normal:
1239
+ normal_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "normals_%03d_%s.%s" % (set_idx, view, self.suffix))
1240
+ normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt", camera_type=load_cam_type, center_crop_size=center_crop_size).permute(2, 0, 1)
1241
+ normal_tensors_out.append(normal_tensor)
1242
+
1243
+ # evelations, azimuths
1244
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
1245
+ elevations.append(elevation)
1246
+ azimuths.append(azimuth)
1247
+
1248
+ if self.load_switcher: # rgb input, use domain switcher to control the output type
1249
+ img_tensors_in = [
1250
+ self.load_image(os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
1251
+ "rgb_%03d_%s.%s" % (set_idx, cond_view, self.suffix)),
1252
+ bg_color, cond_alpha, return_type='pt', camera_type=load_cam_type,
1253
+ read_depth=False, center_crop_size=center_crop_size).permute(
1254
+ 2, 0, 1)
1255
+ ] * self.num_views
1256
+
1257
+ color_class = torch.tensor([0, 1]).float()
1258
+ color_task_embeddings = torch.stack([color_class] * self.num_views, dim=0) # (Nv, 2)
1259
+
1260
+ normal_class = torch.tensor([1, 0]).float()
1261
+ normal_task_embeddings = torch.stack([normal_class] * self.num_views, dim=0) # (Nv, 2)
1262
+
1263
+
1264
+ if read_color:
1265
+ task_embeddings = color_task_embeddings
1266
+ # img_tensors_out = depth_tensors_out
1267
+ elif read_normal:
1268
+ task_embeddings = normal_task_embeddings
1269
+ # img_tensors_out = normal_tensors_out
1270
+
1271
+ else: # for stage 1 training, the input and the output are in the same domain
1272
+ img_tensors_in = [img_tensors_out[0]] * self.num_views
1273
+
1274
+ empty_class = torch.tensor([0, 0]).float() # empty task
1275
+ empty_task_embeddings = torch.stack([empty_class] * self.num_views, dim=0)
1276
+ task_embeddings = empty_task_embeddings
1277
+
1278
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
1279
+
1280
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
1281
+ normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float() # (Nv, 3, H, W)
1282
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
1283
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
1284
+ elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
1285
+
1286
+ if load_cam_type == 'ortho':
1287
+ cam_type_emb = torch.tensor([0, 1]).expand(self.num_views, -1)
1288
+ else:
1289
+ cam_type_emb = torch.tensor([1, 0]).expand(self.num_views, -1)
1290
+
1291
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1)
1292
+
1293
+ if self.load_cam_type:
1294
+ camera_embeddings = torch.cat((camera_embeddings, cam_type_emb), dim=-1) # (Nv, 5)
1295
+
1296
+ return {
1297
+ 'elevations_cond': elevations_cond,
1298
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
1299
+ 'elevations': elevations,
1300
+ 'azimuths': azimuths,
1301
+ 'elevations_deg': torch.rad2deg(elevations),
1302
+ 'azimuths_deg': torch.rad2deg(azimuths),
1303
+ 'imgs_in': img_tensors_in,
1304
+ 'imgs_out': img_tensors_out,
1305
+ 'normals_out': normal_tensors_out,
1306
+ 'camera_embeddings': camera_embeddings,
1307
+ 'color_task_embeddings': color_task_embeddings,
1308
+ 'normal_task_embeddings': normal_task_embeddings
1309
+ }
1310
+
1311
+ def __getitem__(self, index):
1312
+ try:
1313
+ if self.pred_type == 'color':
1314
+ data = self.backup_data = self.__getitem_color__(index)
1315
+ elif self.pred_type == 'normal_depth':
1316
+ data = self.backup_data = self.__getitem_normal_depth__(index)
1317
+ elif self.pred_type == 'mixed_rgb_normal_depth':
1318
+ data = self.backup_data = self.__getitem_mixed__(index)
1319
+ elif self.pred_type == 'mixed_color_normal':
1320
+ data = self.backup_data = self.__getitem_image_normal_mixed__(index)
1321
+ elif self.pred_type == 'mixed_rgb_noraml_mask':
1322
+ data = self.backup_data = self.__getitem_mixed_rgb_noraml_mask__(index)
1323
+ elif self.pred_type == 'joint_color_normal':
1324
+ data = self.backup_data = self.__getitem_joint_rgb_noraml__(index)
1325
+ return data
1326
+
1327
+ except:
1328
+ print("load error ", self.all_objects[index%len(self.all_objects)])
1329
+ return self.backup_data
1330
+
1331
+ class ConcatDataset(torch.utils.data.Dataset):
1332
+ def __init__(self, datasets, weights):
1333
+ self.datasets = datasets
1334
+ self.weights = weights
1335
+ self.num_datasets = len(datasets)
1336
+
1337
+ def __getitem__(self, i):
1338
+
1339
+ chosen = random.choices(self.datasets, self.weights, k=1)[0]
1340
+ return chosen[i]
1341
+
1342
+ def __len__(self):
1343
+ return max(len(d) for d in self.datasets)
1344
+
1345
+ if __name__ == "__main__":
1346
+ train_dataset = ObjaverseDataset(
1347
+ root_dir="/ghome/l5/xxlong/.objaverse/hf-objaverse-v1/renderings",
1348
+ size=(128, 128),
1349
+ ext="hdf5",
1350
+ default_trans=torch.zeros(3),
1351
+ return_paths=False,
1352
+ total_view=8,
1353
+ validation=False,
1354
+ object_list=None,
1355
+ views_mode='fourviews'
1356
+ )
1357
+ data0 = train_dataset[0]
1358
+ data1 = train_dataset[50]
1359
+ # print(data)
mv_diffusion_30/data/single_image_dataset.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ import numpy as np
3
+ from omegaconf import DictConfig, ListConfig
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from pathlib import Path
7
+ import json
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ from einops import rearrange
11
+ from typing import Literal, Tuple, Optional, Any
12
+ import cv2
13
+ import random
14
+
15
+ import json
16
+ import os, sys
17
+ import math
18
+
19
+ from glob import glob
20
+
21
+ import PIL.Image
22
+ from .normal_utils import trans_normal, normal2img, img2normal
23
+ import pdb
24
+ from rembg import remove
25
+
26
+ import cv2
27
+ import numpy as np
28
+
29
+
30
+ def add_margin(pil_img, color=0, size=256):
31
+ width, height = pil_img.size
32
+ result = Image.new(pil_img.mode, (size, size), color)
33
+ result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
34
+ return result
35
+
36
+
37
+ def scale_and_place_object(image, scale_factor):
38
+ assert np.shape(image)[-1] == 4 # RGBA
39
+
40
+ # Extract the alpha channel (transparency) and the object (RGB channels)
41
+ alpha_channel = image[:, :, 3]
42
+
43
+ # Find the bounding box coordinates of the object
44
+ coords = cv2.findNonZero(alpha_channel)
45
+ x, y, width, height = cv2.boundingRect(coords)
46
+
47
+ # Calculate the scale factor for resizing
48
+ original_height, original_width = image.shape[:2]
49
+
50
+ if width > height:
51
+ size = width
52
+ original_size = original_width
53
+ else:
54
+ size = height
55
+ original_size = original_height
56
+
57
+ scale_factor = min(scale_factor, size / (original_size + 0.0))
58
+
59
+ new_size = scale_factor * original_size
60
+ scale_factor = new_size / size
61
+
62
+ # Calculate the new size based on the scale factor
63
+ new_width = int(width * scale_factor)
64
+ new_height = int(height * scale_factor)
65
+
66
+ center_x = original_width // 2
67
+ center_y = original_height // 2
68
+
69
+ paste_x = center_x - (new_width // 2)
70
+ paste_y = center_y - (new_height // 2)
71
+
72
+ # Resize the object (RGB channels) to the new size
73
+ rescaled_object = cv2.resize(image[y:y + height, x:x + width], (new_width, new_height))
74
+
75
+ # Create a new RGBA image with the resized image
76
+ new_image = np.zeros((original_height, original_width, 4), dtype=np.uint8)
77
+
78
+ new_image[paste_y:paste_y + new_height, paste_x:paste_x + new_width] = rescaled_object
79
+
80
+ return new_image
81
+
82
+
83
+ class SingleImageDataset(Dataset):
84
+ def __init__(self,
85
+ root_dir: str = None,
86
+ num_views: int =6,
87
+ img_wh: Tuple[int, int] =[256,256],
88
+ bg_color: str ='white',
89
+ crop_size: int = 224,
90
+ single_image: Optional[PIL.Image.Image] = None,
91
+ num_validation_samples: Optional[int] = None,
92
+ filepaths: Optional[list] = None,
93
+ cam_types: Optional[list] = None,
94
+ cond_type: Optional[str] = None,
95
+ load_cam_type: Optional[bool] = True
96
+ ) -> None:
97
+ """Create a dataset from a folder of images.
98
+ If you pass in a root directory it will be searched for images
99
+ ending in ext (ext can be a list)
100
+ """
101
+ self.root_dir = root_dir
102
+ self.num_views = num_views
103
+ self.img_wh = img_wh
104
+ self.crop_size = crop_size
105
+ self.bg_color = bg_color
106
+ self.cond_type = cond_type
107
+ self.load_cam_type = load_cam_type
108
+ self.cam_types = cam_types
109
+
110
+ if self.num_views == 4:
111
+ self.view_types = ['front', 'right', 'back', 'left']
112
+ elif self.num_views == 5:
113
+ self.view_types = ['front', 'front_right', 'right', 'back', 'left']
114
+ elif self.num_views == 6:
115
+ self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
116
+
117
+ self.fix_cam_pose_dir = "./mv_diffusion_30/data/fixed_poses/nine_views"
118
+
119
+ self.fix_cam_poses = self.load_fixed_poses() # world2cam matrix
120
+
121
+ if single_image is None:
122
+ if filepaths is None:
123
+ # Get a list of all files in the directory
124
+ file_list = os.listdir(self.root_dir)
125
+ self.cam_types = ['ortho'] * len(file_list) + ['persp'] * len(file_list)
126
+ file_list = file_list * 2
127
+ else:
128
+ file_list = filepaths
129
+
130
+ # Filter the files that end with .png or .jpg
131
+ self.file_list = [file for file in file_list if file.endswith(('.png', '.jpg'))]
132
+ else:
133
+ self.file_list = None
134
+
135
+ # load all images
136
+ self.all_images = []
137
+ self.all_alphas = []
138
+ bg_color = self.get_bg_color()
139
+
140
+ if single_image is not None:
141
+ image, alpha = self.load_image(None, bg_color, return_type='pt', Imagefile=single_image)
142
+ self.all_images.append(image)
143
+ self.all_alphas.append(alpha)
144
+ else:
145
+ for file in self.file_list:
146
+ print(os.path.join(self.root_dir, file))
147
+ image, alpha = self.load_image(os.path.join(self.root_dir, file), bg_color, return_type='pt')
148
+ self.all_images.append(image)
149
+ self.all_alphas.append(alpha)
150
+ #
151
+ # assert len(self.file_list) == len(self.cam_types)
152
+ self.all_images = self.all_images[:num_validation_samples]
153
+ self.all_alphas = self.all_alphas[:num_validation_samples]
154
+
155
+ def __len__(self):
156
+ return len(self.all_images)
157
+
158
+ def load_fixed_poses(self):
159
+ poses = {}
160
+ for face in self.view_types:
161
+ RT = np.loadtxt(os.path.join(self.fix_cam_pose_dir, '%03d_%s_RT.txt' % (0, face)))
162
+ poses[face] = RT
163
+
164
+ return poses
165
+
166
+ def cartesian_to_spherical(self, xyz):
167
+ ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
168
+ xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
169
+ z = np.sqrt(xy + xyz[:, 2] ** 2)
170
+ theta = np.arctan2(np.sqrt(xy), xyz[:, 2]) # for elevation angle defined from Z-axis down
171
+ # ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
172
+ azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
173
+ return np.array([theta, azimuth, z])
174
+
175
+ def get_T(self, target_RT, cond_RT):
176
+ R, T = target_RT[:3, :3], target_RT[:, -1]
177
+ T_target = -R.T @ T # change to cam2world
178
+
179
+ R, T = cond_RT[:3, :3], cond_RT[:, -1]
180
+ T_cond = -R.T @ T
181
+
182
+ theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
183
+ theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
184
+
185
+ d_theta = theta_target - theta_cond
186
+ d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
187
+ d_z = z_target - z_cond
188
+
189
+ # d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
190
+ return d_theta, d_azimuth
191
+
192
+ def get_bg_color(self):
193
+ if self.bg_color == 'white':
194
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
195
+ elif self.bg_color == 'black':
196
+ bg_color = np.array([0., 0., 0.], dtype=np.float32)
197
+ elif self.bg_color == 'gray':
198
+ bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
199
+ elif self.bg_color == 'random':
200
+ bg_color = np.random.rand(3)
201
+ elif isinstance(self.bg_color, float):
202
+ bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
203
+ else:
204
+ raise NotImplementedError
205
+ return bg_color
206
+
207
+ def load_image(self, img_path, bg_color, return_type='np', Imagefile=None):
208
+ # pil always returns uint8
209
+ if Imagefile is None:
210
+ image_input = Image.open(img_path)
211
+ else:
212
+ image_input = Imagefile
213
+ image_size = self.img_wh[0]
214
+
215
+ image_input = image_input.resize((image_size, image_size))
216
+ img = np.array(image_input)
217
+ img = img.astype(np.float32) / 255. # [0, 1]
218
+ assert img.shape[-1] == 4 # RGBA
219
+
220
+ alpha = img[..., 3:4]
221
+
222
+ img = img[..., :3] * alpha + bg_color * (1 - alpha)
223
+ img = np.clip(img, 0, 1)
224
+
225
+ if return_type == "np":
226
+ pass
227
+ elif return_type == "pt":
228
+ img = torch.from_numpy(img)
229
+ alpha = torch.from_numpy(alpha)
230
+ else:
231
+ raise NotImplementedError
232
+
233
+ return img, alpha
234
+
235
+ def __len__(self):
236
+ return len(self.all_images)
237
+
238
+ def __getitem__(self, index):
239
+
240
+ image = self.all_images[index % len(self.all_images)]
241
+ alpha = self.all_alphas[index % len(self.all_images)]
242
+ if self.load_cam_type:
243
+ cam_type = self.cam_types[index % len(self.all_images)]
244
+ else:
245
+ cam_type = 'ortho'
246
+ if self.file_list is not None:
247
+ filename = self.file_list[index % len(self.all_images)].replace(".png", "")
248
+ else:
249
+ filename = 'null'
250
+
251
+ cond_w2c = self.fix_cam_poses['front']
252
+
253
+ tgt_w2cs = [self.fix_cam_poses[view] for view in self.view_types]
254
+
255
+ elevations = []
256
+ azimuths = []
257
+
258
+ img_tensors_in = [
259
+ image.permute(2, 0, 1)
260
+ ] * self.num_views
261
+
262
+ alpha_tensors_in = [
263
+ alpha.permute(2, 0, 1)
264
+ ] * self.num_views
265
+
266
+ for view, tgt_w2c in zip(self.view_types, tgt_w2cs):
267
+ # evelations, azimuths
268
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
269
+ elevations.append(elevation)
270
+ azimuths.append(azimuth)
271
+
272
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
273
+ alpha_tensors_in = torch.stack(alpha_tensors_in, dim=0).float() # (Nv, 3, H, W)
274
+
275
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
276
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
277
+ elevations_cond = torch.as_tensor([0] * self.num_views).float()
278
+
279
+ normal_class = torch.tensor([1, 0]).float()
280
+ normal_task_embeddings = torch.stack([normal_class] * self.num_views, dim=0) # (Nv, 2)
281
+ color_class = torch.tensor([0, 1]).float()
282
+ color_task_embeddings = torch.stack([color_class] * self.num_views, dim=0) # (Nv, 2)
283
+ depth_class = torch.tensor([1, 1]).float()
284
+ depth_task_embeddings = torch.stack([depth_class] * self.num_views, dim=0)
285
+
286
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1) # (Nv, 3)
287
+
288
+ print("camera type:", cam_type)
289
+ if cam_type == 'ortho':
290
+ cam_type_emb = torch.tensor([0, 1]).expand(self.num_views, -1)
291
+ else:
292
+ cam_type_emb = torch.tensor([1, 0]).expand(self.num_views, -1)
293
+
294
+ if self.load_cam_type:
295
+ camera_embeddings = torch.cat((camera_embeddings, cam_type_emb), dim=-1) # (Nv, 5)
296
+
297
+ out = {
298
+ 'elevations_cond': elevations_cond,
299
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
300
+ 'elevations': elevations,
301
+ 'azimuths': azimuths,
302
+ 'elevations_deg': torch.rad2deg(elevations),
303
+ 'azimuths_deg': torch.rad2deg(azimuths),
304
+ 'imgs_in': img_tensors_in,
305
+ 'alphas': alpha_tensors_in,
306
+ 'camera_embeddings': camera_embeddings,
307
+ 'normal_task_embeddings': normal_task_embeddings,
308
+ 'color_task_embeddings': color_task_embeddings,
309
+ 'depth_task_embeddings': depth_task_embeddings,
310
+ 'filename': filename,
311
+ 'cam_type': cam_type
312
+ }
313
+
314
+ return out
315
+
316
+
mv_diffusion_30/models/transformer_mv2d.py ADDED
@@ -0,0 +1,1093 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+ # from torch.nn.attention import SDPBackend, sdpa_kernel
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
24
+ from diffusers.utils import BaseOutput, deprecate
25
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
26
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
27
+ from diffusers.models.embeddings import PatchEmbed
28
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
29
+ from diffusers.models.modeling_utils import ModelMixin
30
+ from diffusers.utils.import_utils import is_xformers_available
31
+
32
+ from einops import rearrange, repeat
33
+ import pdb
34
+ import random
35
+
36
+
37
+ # if is_xformers_available():
38
+ # import xformers
39
+ # import xformers.ops
40
+ # else:
41
+ # xformers = None
42
+
43
+ def my_repeat(tensor, num_repeats):
44
+ """
45
+ Repeat a tensor along a given dimension
46
+ """
47
+ if len(tensor.shape) == 3:
48
+ return repeat(tensor, "b d c -> (b v) d c", v=num_repeats)
49
+ elif len(tensor.shape) == 4:
50
+ return repeat(tensor, "a b d c -> (a v) b d c", v=num_repeats)
51
+
52
+
53
+ @dataclass
54
+ class TransformerMV2DModelOutput(BaseOutput):
55
+ """
56
+ The output of [`Transformer2DModel`].
57
+
58
+ Args:
59
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
60
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
61
+ distributions for the unnoised latent pixels.
62
+ """
63
+
64
+ sample: torch.FloatTensor
65
+
66
+
67
+ class TransformerMV2DModel(ModelMixin, ConfigMixin):
68
+ """
69
+ A 2D Transformer model for image-like data.
70
+
71
+ Parameters:
72
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
73
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
74
+ in_channels (`int`, *optional*):
75
+ The number of channels in the input and output (specify if the input is **continuous**).
76
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
77
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
78
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
79
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
80
+ This is fixed during training since it is used to learn a number of position embeddings.
81
+ num_vector_embeds (`int`, *optional*):
82
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
83
+ Includes the class for the masked latent pixel.
84
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
85
+ num_embeds_ada_norm ( `int`, *optional*):
86
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
87
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
88
+ added to the hidden states.
89
+
90
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
91
+ attention_bias (`bool`, *optional*):
92
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
93
+ """
94
+
95
+ @register_to_config
96
+ def __init__(
97
+ self,
98
+ num_attention_heads: int = 16,
99
+ attention_head_dim: int = 88,
100
+ in_channels: Optional[int] = None,
101
+ out_channels: Optional[int] = None,
102
+ num_layers: int = 1,
103
+ dropout: float = 0.0,
104
+ norm_num_groups: int = 32,
105
+ cross_attention_dim: Optional[int] = None,
106
+ attention_bias: bool = False,
107
+ sample_size: Optional[int] = None,
108
+ num_vector_embeds: Optional[int] = None,
109
+ patch_size: Optional[int] = None,
110
+ activation_fn: str = "geglu",
111
+ num_embeds_ada_norm: Optional[int] = None,
112
+ use_linear_projection: bool = False,
113
+ only_cross_attention: bool = False,
114
+ upcast_attention: bool = False,
115
+ norm_type: str = "layer_norm",
116
+ norm_elementwise_affine: bool = True,
117
+ num_views: int = 1,
118
+ cd_attention_last: bool=False,
119
+ cd_attention_mid: bool=False,
120
+ multiview_attention: bool=True,
121
+ sparse_mv_attention: bool = False,
122
+ mvcd_attention: bool=False
123
+ ):
124
+ super().__init__()
125
+ self.use_linear_projection = use_linear_projection
126
+ self.num_attention_heads = num_attention_heads
127
+ self.attention_head_dim = attention_head_dim
128
+ inner_dim = num_attention_heads * attention_head_dim
129
+
130
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
131
+ # Define whether input is continuous or discrete depending on configuration
132
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
133
+ self.is_input_vectorized = num_vector_embeds is not None
134
+ self.is_input_patches = in_channels is not None and patch_size is not None
135
+
136
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
137
+ deprecation_message = (
138
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
139
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
140
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
141
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
142
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
143
+ )
144
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
145
+ norm_type = "ada_norm"
146
+
147
+ if self.is_input_continuous and self.is_input_vectorized:
148
+ raise ValueError(
149
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
150
+ " sure that either `in_channels` or `num_vector_embeds` is None."
151
+ )
152
+ elif self.is_input_vectorized and self.is_input_patches:
153
+ raise ValueError(
154
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
155
+ " sure that either `num_vector_embeds` or `num_patches` is None."
156
+ )
157
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
158
+ raise ValueError(
159
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
160
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
161
+ )
162
+
163
+ # 2. Define input layers
164
+ if self.is_input_continuous:
165
+ self.in_channels = in_channels
166
+
167
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
168
+ if use_linear_projection:
169
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
170
+ else:
171
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
172
+ elif self.is_input_vectorized:
173
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
174
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
175
+
176
+ self.height = sample_size
177
+ self.width = sample_size
178
+ self.num_vector_embeds = num_vector_embeds
179
+ self.num_latent_pixels = self.height * self.width
180
+
181
+ self.latent_image_embedding = ImagePositionalEmbeddings(
182
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
183
+ )
184
+ elif self.is_input_patches:
185
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
186
+
187
+ self.height = sample_size
188
+ self.width = sample_size
189
+
190
+ self.patch_size = patch_size
191
+ self.pos_embed = PatchEmbed(
192
+ height=sample_size,
193
+ width=sample_size,
194
+ patch_size=patch_size,
195
+ in_channels=in_channels,
196
+ embed_dim=inner_dim,
197
+ )
198
+
199
+ # 3. Define transformers blocks
200
+ self.transformer_blocks = nn.ModuleList(
201
+ [
202
+ BasicMVTransformerBlock(
203
+ inner_dim,
204
+ num_attention_heads,
205
+ attention_head_dim,
206
+ dropout=dropout,
207
+ cross_attention_dim=cross_attention_dim,
208
+ activation_fn=activation_fn,
209
+ num_embeds_ada_norm=num_embeds_ada_norm,
210
+ attention_bias=attention_bias,
211
+ only_cross_attention=only_cross_attention,
212
+ upcast_attention=upcast_attention,
213
+ norm_type=norm_type,
214
+ norm_elementwise_affine=norm_elementwise_affine,
215
+ num_views=num_views,
216
+ cd_attention_last=cd_attention_last,
217
+ cd_attention_mid=cd_attention_mid,
218
+ multiview_attention=multiview_attention,
219
+ sparse_mv_attention=sparse_mv_attention,
220
+ mvcd_attention=mvcd_attention
221
+ )
222
+ for d in range(num_layers)
223
+ ]
224
+ )
225
+
226
+ # 4. Define output layers
227
+ self.out_channels = in_channels if out_channels is None else out_channels
228
+ if self.is_input_continuous:
229
+ # TODO: should use out_channels for continuous projections
230
+ if use_linear_projection:
231
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
232
+ else:
233
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
234
+ elif self.is_input_vectorized:
235
+ self.norm_out = nn.LayerNorm(inner_dim)
236
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
237
+ elif self.is_input_patches:
238
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
239
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
240
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
241
+
242
+ def forward(
243
+ self,
244
+ hidden_states: torch.Tensor,
245
+ encoder_hidden_states: Optional[torch.Tensor] = None,
246
+ timestep: Optional[torch.LongTensor] = None,
247
+ class_labels: Optional[torch.LongTensor] = None,
248
+ cross_attention_kwargs: Dict[str, Any] = None,
249
+ attention_mask: Optional[torch.Tensor] = None,
250
+ encoder_attention_mask: Optional[torch.Tensor] = None,
251
+ return_dict: bool = True,
252
+ ):
253
+ """
254
+ The [`Transformer2DModel`] forward method.
255
+
256
+ Args:
257
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
258
+ Input `hidden_states`.
259
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
260
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
261
+ self-attention.
262
+ timestep ( `torch.LongTensor`, *optional*):
263
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
264
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
265
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
266
+ `AdaLayerZeroNorm`.
267
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
268
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
269
+
270
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
271
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
272
+
273
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
274
+ above. This bias will be added to the cross-attention scores.
275
+ return_dict (`bool`, *optional*, defaults to `True`):
276
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
277
+ tuple.
278
+
279
+ Returns:
280
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
281
+ `tuple` where the first element is the sample tensor.
282
+ """
283
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
284
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
285
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
286
+ # expects mask of shape:
287
+ # [batch, key_tokens]
288
+ # adds singleton query_tokens dimension:
289
+ # [batch, 1, key_tokens]
290
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
291
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
292
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
293
+ if attention_mask is not None and attention_mask.ndim == 2:
294
+ # assume that mask is expressed as:
295
+ # (1 = keep, 0 = discard)
296
+ # convert mask into a bias that can be added to attention scores:
297
+ # (keep = +0, discard = -10000.0)
298
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
299
+ attention_mask = attention_mask.unsqueeze(1)
300
+
301
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
302
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
303
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
304
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
305
+
306
+ # 1. Input
307
+ if self.is_input_continuous:
308
+ batch, _, height, width = hidden_states.shape
309
+ residual = hidden_states
310
+
311
+ hidden_states = self.norm(hidden_states)
312
+ if not self.use_linear_projection:
313
+ hidden_states = self.proj_in(hidden_states)
314
+ inner_dim = hidden_states.shape[1]
315
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
316
+ else:
317
+ inner_dim = hidden_states.shape[1]
318
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
319
+ hidden_states = self.proj_in(hidden_states)
320
+ elif self.is_input_vectorized:
321
+ hidden_states = self.latent_image_embedding(hidden_states)
322
+ elif self.is_input_patches:
323
+ hidden_states = self.pos_embed(hidden_states)
324
+
325
+ # 2. Blocks
326
+ for block in self.transformer_blocks:
327
+ hidden_states = block(
328
+ hidden_states,
329
+ attention_mask=attention_mask,
330
+ encoder_hidden_states=encoder_hidden_states,
331
+ encoder_attention_mask=encoder_attention_mask,
332
+ timestep=timestep,
333
+ cross_attention_kwargs=cross_attention_kwargs,
334
+ class_labels=class_labels,
335
+ )
336
+
337
+ # 3. Output
338
+ if self.is_input_continuous:
339
+ if not self.use_linear_projection:
340
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
341
+ hidden_states = self.proj_out(hidden_states)
342
+ else:
343
+ hidden_states = self.proj_out(hidden_states)
344
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
345
+
346
+ output = hidden_states + residual
347
+ elif self.is_input_vectorized:
348
+ hidden_states = self.norm_out(hidden_states)
349
+ logits = self.out(hidden_states)
350
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
351
+ logits = logits.permute(0, 2, 1)
352
+
353
+ # log(p(x_0))
354
+ output = F.log_softmax(logits.double(), dim=1).float()
355
+ elif self.is_input_patches:
356
+ # TODO: cleanup!
357
+ conditioning = self.transformer_blocks[0].norm1.emb(
358
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
359
+ )
360
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
361
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
362
+ hidden_states = self.proj_out_2(hidden_states)
363
+
364
+ # unpatchify
365
+ height = width = int(hidden_states.shape[1] ** 0.5)
366
+ hidden_states = hidden_states.reshape(
367
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
368
+ )
369
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
370
+ output = hidden_states.reshape(
371
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
372
+ )
373
+
374
+ if not return_dict:
375
+ return (output,)
376
+
377
+ return TransformerMV2DModelOutput(sample=output)
378
+
379
+
380
+ @maybe_allow_in_graph
381
+ class BasicMVTransformerBlock(nn.Module):
382
+ r"""
383
+ A basic Transformer block.
384
+
385
+ Parameters:
386
+ dim (`int`): The number of channels in the input and output.
387
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
388
+ attention_head_dim (`int`): The number of channels in each head.
389
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
390
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
391
+ only_cross_attention (`bool`, *optional*):
392
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
393
+ double_self_attention (`bool`, *optional*):
394
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
395
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
396
+ num_embeds_ada_norm (:
397
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
398
+ attention_bias (:
399
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
400
+ """
401
+
402
+ def __init__(
403
+ self,
404
+ dim: int,
405
+ num_attention_heads: int,
406
+ attention_head_dim: int,
407
+ dropout=0.0,
408
+ cross_attention_dim: Optional[int] = None,
409
+ activation_fn: str = "geglu",
410
+ num_embeds_ada_norm: Optional[int] = None,
411
+ attention_bias: bool = False,
412
+ only_cross_attention: bool = False,
413
+ double_self_attention: bool = False,
414
+ upcast_attention: bool = False,
415
+ norm_elementwise_affine: bool = True,
416
+ norm_type: str = "layer_norm",
417
+ final_dropout: bool = False,
418
+ num_views: int = 1,
419
+ cd_attention_last: bool = False,
420
+ cd_attention_mid: bool = False,
421
+ multiview_attention: bool = True,
422
+ sparse_mv_attention: bool = False,
423
+ mvcd_attention: bool = False
424
+ ):
425
+ super().__init__()
426
+ self.only_cross_attention = only_cross_attention
427
+
428
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
429
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
430
+
431
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
432
+ raise ValueError(
433
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
434
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
435
+ )
436
+
437
+ # Define 3 blocks. Each block has its own normalization layer.
438
+ # 1. Self-Attn
439
+ if self.use_ada_layer_norm:
440
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
441
+ elif self.use_ada_layer_norm_zero:
442
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
443
+ else:
444
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
445
+
446
+ self.multiview_attention = multiview_attention
447
+ self.sparse_mv_attention = sparse_mv_attention
448
+ self.mvcd_attention = mvcd_attention
449
+
450
+ self.attn1 = CustomAttention(
451
+ query_dim=dim,
452
+ heads=num_attention_heads,
453
+ dim_head=attention_head_dim,
454
+ dropout=dropout,
455
+ bias=attention_bias,
456
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
457
+ upcast_attention=upcast_attention,
458
+ processor=MVAttnProcessor()
459
+ )
460
+
461
+ # 2. Cross-Attn
462
+ if cross_attention_dim is not None or double_self_attention:
463
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
464
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
465
+ # the second cross attention block.
466
+ self.norm2 = (
467
+ AdaLayerNorm(dim, num_embeds_ada_norm)
468
+ if self.use_ada_layer_norm
469
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
470
+ )
471
+ self.attn2 = Attention(
472
+ query_dim=dim,
473
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
474
+ heads=num_attention_heads,
475
+ dim_head=attention_head_dim,
476
+ dropout=dropout,
477
+ bias=attention_bias,
478
+ upcast_attention=upcast_attention,
479
+ # processor=CrossAttnProcessor()
480
+ ) # is self-attn if encoder_hidden_states is none
481
+ else:
482
+ self.norm2 = None
483
+ self.attn2 = None
484
+
485
+ # 3. Feed-forward
486
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
487
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
488
+
489
+ # let chunk size default to None
490
+ self._chunk_size = None
491
+ self._chunk_dim = 0
492
+
493
+ self.num_views = num_views
494
+
495
+ self.cd_attention_last = cd_attention_last
496
+
497
+ if self.cd_attention_last:
498
+ # Joint task -Attn
499
+ self.attn_joint_last = Attention(
500
+ query_dim=dim,
501
+ heads=num_attention_heads,
502
+ dim_head=attention_head_dim,
503
+ dropout=dropout,
504
+ bias=attention_bias,
505
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
506
+ upcast_attention=upcast_attention,
507
+ processor=JointAttnProcessor()
508
+ )
509
+ nn.init.zeros_(self.attn_joint_last.to_out[0].weight.data)
510
+ self.norm_joint_last = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
511
+
512
+
513
+ self.cd_attention_mid = cd_attention_mid
514
+
515
+ if self.cd_attention_mid:
516
+ # print("cross-domain attn in the middle")
517
+ # Joint task -Attn
518
+ self.attn_joint_mid = Attention(
519
+ query_dim=dim,
520
+ heads=num_attention_heads,
521
+ dim_head=attention_head_dim,
522
+ dropout=dropout,
523
+ bias=attention_bias,
524
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
525
+ upcast_attention=upcast_attention,
526
+ processor=JointAttnProcessor()
527
+ )
528
+ nn.init.zeros_(self.attn_joint_mid.to_out[0].weight.data)
529
+ self.norm_joint_mid = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
530
+
531
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
532
+ # Sets chunk feed-forward
533
+ self._chunk_size = chunk_size
534
+ self._chunk_dim = dim
535
+
536
+ def forward(
537
+ self,
538
+ hidden_states: torch.FloatTensor,
539
+ attention_mask: Optional[torch.FloatTensor] = None,
540
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
541
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
542
+ timestep: Optional[torch.LongTensor] = None,
543
+ cross_attention_kwargs: Dict[str, Any] = None,
544
+ class_labels: Optional[torch.LongTensor] = None,
545
+ ):
546
+ """
547
+
548
+ :type attention_mask: object
549
+ """
550
+ assert attention_mask is None # not supported yet
551
+ # Notice that normalization is always applied before the real computation in the following blocks.
552
+ # 1. Self-Attention
553
+ if self.use_ada_layer_norm:
554
+ norm_hidden_states = self.norm1(hidden_states, timestep)
555
+ elif self.use_ada_layer_norm_zero:
556
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
557
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
558
+ )
559
+ else:
560
+ norm_hidden_states = self.norm1(hidden_states)
561
+
562
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
563
+
564
+ attn_output = self.attn1(norm_hidden_states,
565
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
566
+ attention_mask=attention_mask,
567
+ num_views=self.num_views,
568
+ multiview_attention=self.multiview_attention,
569
+ sparse_mv_attention=self.sparse_mv_attention,
570
+ mvcd_attention=self.mvcd_attention,
571
+ **cross_attention_kwargs,
572
+ )
573
+
574
+
575
+ if self.use_ada_layer_norm_zero:
576
+ attn_output = gate_msa.unsqueeze(1) * attn_output
577
+ hidden_states = attn_output + hidden_states
578
+
579
+ # joint attention twice
580
+ if self.cd_attention_mid:
581
+ norm_hidden_states = (
582
+ self.norm_joint_mid(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_mid(hidden_states)
583
+ )
584
+ hidden_states = self.attn_joint_mid(norm_hidden_states) + hidden_states
585
+
586
+ # 2. Cross-Attention
587
+ if self.attn2 is not None:
588
+ norm_hidden_states = (
589
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
590
+ )
591
+
592
+ attn_output = self.attn2(
593
+ norm_hidden_states,
594
+ encoder_hidden_states=encoder_hidden_states,
595
+ attention_mask=encoder_attention_mask,
596
+ **cross_attention_kwargs,
597
+ )
598
+ hidden_states = attn_output + hidden_states
599
+
600
+ # 3. Feed-forward
601
+ norm_hidden_states = self.norm3(hidden_states)
602
+
603
+ if self.use_ada_layer_norm_zero:
604
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
605
+
606
+ if self._chunk_size is not None:
607
+ # "feed_forward_chunk_size" can be used to save memory
608
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
609
+ raise ValueError(
610
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
611
+ )
612
+
613
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
614
+ ff_output = torch.cat(
615
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
616
+ dim=self._chunk_dim,
617
+ )
618
+ else:
619
+ ff_output = self.ff(norm_hidden_states)
620
+
621
+ if self.use_ada_layer_norm_zero:
622
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
623
+
624
+ hidden_states = ff_output + hidden_states
625
+
626
+ if self.cd_attention_last:
627
+ norm_hidden_states = (
628
+ self.norm_joint_last(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_last(hidden_states)
629
+ )
630
+ hidden_states = self.attn_joint_last(norm_hidden_states) + hidden_states
631
+
632
+ return hidden_states
633
+
634
+
635
+ class CustomAttention(Attention):
636
+ def set_use_memory_efficient_attention_xformers(
637
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
638
+ ):
639
+ processor = XFormersMVAttnProcessor()
640
+ self.set_processor(processor)
641
+ # print("using xformers attention processor")
642
+
643
+
644
+ class CustomJointAttention(Attention):
645
+ def set_use_memory_efficient_attention_xformers(
646
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
647
+ ):
648
+ processor = XFormersJointAttnProcessor()
649
+ self.set_processor(processor)
650
+ # print("using xformers attention processor")
651
+
652
+ class MVAttnProcessor:
653
+ r"""
654
+ Default processor for performing attention-related computations.
655
+ """
656
+
657
+ def __call__(
658
+ self,
659
+ attn: Attention,
660
+ hidden_states,
661
+ encoder_hidden_states=None,
662
+ attention_mask=None,
663
+ temb=None,
664
+ num_views=1,
665
+ multiview_attention=True,
666
+ sparse_mv_attention=False,
667
+ mvcd_attention=False,
668
+ ):
669
+ residual = hidden_states
670
+
671
+ if attn.spatial_norm is not None:
672
+ hidden_states = attn.spatial_norm(hidden_states, temb)
673
+
674
+ input_ndim = hidden_states.ndim
675
+
676
+ if input_ndim == 4:
677
+ batch_size, channel, height, width = hidden_states.shape
678
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
679
+
680
+ batch_size, sequence_length, input_dim = (
681
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
682
+ )
683
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
684
+
685
+ if attn.group_norm is not None:
686
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
687
+
688
+ query = attn.to_q(hidden_states)
689
+
690
+ if encoder_hidden_states is None:
691
+ encoder_hidden_states = hidden_states
692
+ elif attn.norm_cross:
693
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
694
+
695
+ key = attn.to_k(encoder_hidden_states)
696
+ value = attn.to_v(encoder_hidden_states)
697
+
698
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
699
+ #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
700
+ # pdb.set_trace()
701
+ # multi-view self-attention
702
+ if multiview_attention:
703
+ key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
704
+ value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
705
+
706
+ # batch, n_heads, n_tokens, channel
707
+ query = attn.head_to_batch_dim(query, out_dim=4).contiguous()
708
+ key = attn.head_to_batch_dim(key, out_dim=4).contiguous()
709
+ value = attn.head_to_batch_dim(value, out_dim=4).contiguous()
710
+
711
+ with torch.backends.cuda.sdp_kernel(
712
+ enable_flash=True,
713
+ enable_math=False,
714
+ enable_mem_efficient=True
715
+ ):
716
+ hidden_states = F.scaled_dot_product_attention(query, key, value)
717
+
718
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, sequence_length, input_dim)
719
+
720
+ # linear proj
721
+ hidden_states = attn.to_out[0](hidden_states)
722
+ # dropout
723
+ hidden_states = attn.to_out[1](hidden_states)
724
+
725
+ if input_ndim == 4:
726
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
727
+
728
+ if attn.residual_connection:
729
+ hidden_states = hidden_states + residual
730
+
731
+ hidden_states = hidden_states / attn.rescale_output_factor
732
+
733
+ return hidden_states
734
+
735
+
736
+ class XFormersMVAttnProcessor:
737
+ r"""
738
+ Default processor for performing attention-related computations.
739
+ """
740
+
741
+ def __call__(
742
+ self,
743
+ attn: Attention,
744
+ hidden_states,
745
+ encoder_hidden_states=None,
746
+ attention_mask=None,
747
+ temb=None,
748
+ num_views=1.,
749
+ multiview_attention=True,
750
+ sparse_mv_attention=False,
751
+ mvcd_attention=False,
752
+ ):
753
+ residual = hidden_states
754
+
755
+ if attn.spatial_norm is not None:
756
+ hidden_states = attn.spatial_norm(hidden_states, temb)
757
+
758
+ input_ndim = hidden_states.ndim
759
+
760
+ if input_ndim == 4:
761
+ batch_size, channel, height, width = hidden_states.shape
762
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
763
+
764
+ batch_size, sequence_length, _ = (
765
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
766
+ )
767
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
768
+
769
+ # from yuancheng; here attention_mask is None
770
+ if attention_mask is not None:
771
+ # expand our mask's singleton query_tokens dimension:
772
+ # [batch*heads, 1, key_tokens] ->
773
+ # [batch*heads, query_tokens, key_tokens]
774
+ # so that it can be added as a bias onto the attention scores that xformers computes:
775
+ # [batch*heads, query_tokens, key_tokens]
776
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
777
+ _, query_tokens, _ = hidden_states.shape
778
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
779
+
780
+ if attn.group_norm is not None:
781
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
782
+
783
+ query = attn.to_q(hidden_states)
784
+
785
+ if encoder_hidden_states is None:
786
+ encoder_hidden_states = hidden_states
787
+ elif attn.norm_cross:
788
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
789
+
790
+ key_raw = attn.to_k(encoder_hidden_states)
791
+ value_raw = attn.to_v(encoder_hidden_states)
792
+
793
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
794
+ #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
795
+ # pdb.set_trace()
796
+ # multi-view self-attention
797
+ if multiview_attention:
798
+ if not sparse_mv_attention:
799
+ key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
800
+ value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
801
+ else:
802
+ key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) # [(b t), d, c]
803
+ value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views)
804
+ key = torch.cat([key_front, key_raw], dim=1) # shape (b t) (2 d) c
805
+ value = torch.cat([value_front, value_raw], dim=1)
806
+
807
+ else:
808
+ # print("don't use multiview attention.")
809
+ key = key_raw
810
+ value = value_raw
811
+
812
+ query = attn.head_to_batch_dim(query)
813
+ key = attn.head_to_batch_dim(key)
814
+ value = attn.head_to_batch_dim(value)
815
+
816
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
817
+ # for flash attention implementation
818
+ # with torch.backends.cuda.sdp_kernel(enable_math=False):
819
+ # hidden_states = F.scaled_dot_product_attention(query, key, value, attn_bias=attention_mask)
820
+ # hidden_states = attn.batch_to_head_dim(hidden_states)
821
+
822
+ # linear proj
823
+ hidden_states = attn.to_out[0](hidden_states)
824
+ # dropout
825
+ hidden_states = attn.to_out[1](hidden_states)
826
+
827
+ if input_ndim == 4:
828
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
829
+
830
+ if attn.residual_connection:
831
+ hidden_states = hidden_states + residual
832
+
833
+ hidden_states = hidden_states / attn.rescale_output_factor
834
+
835
+ return hidden_states
836
+
837
+
838
+
839
+ class XFormersJointAttnProcessor:
840
+ r"""
841
+ Default processor for performing attention-related computations.
842
+ """
843
+
844
+ def __call__(
845
+ self,
846
+ attn: Attention,
847
+ hidden_states,
848
+ encoder_hidden_states=None,
849
+ attention_mask=None,
850
+ temb=None,
851
+ num_tasks=2
852
+ ):
853
+
854
+ residual = hidden_states
855
+
856
+ if attn.spatial_norm is not None:
857
+ hidden_states = attn.spatial_norm(hidden_states, temb)
858
+
859
+ input_ndim = hidden_states.ndim
860
+
861
+ if input_ndim == 4:
862
+ batch_size, channel, height, width = hidden_states.shape
863
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
864
+
865
+ batch_size, sequence_length, _ = (
866
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
867
+ )
868
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
869
+
870
+ # from yuancheng; here attention_mask is None
871
+ if attention_mask is not None:
872
+ # expand our mask's singleton query_tokens dimension:
873
+ # [batch*heads, 1, key_tokens] ->
874
+ # [batch*heads, query_tokens, key_tokens]
875
+ # so that it can be added as a bias onto the attention scores that xformers computes:
876
+ # [batch*heads, query_tokens, key_tokens]
877
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
878
+ _, query_tokens, _ = hidden_states.shape
879
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
880
+
881
+ if attn.group_norm is not None:
882
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
883
+
884
+ query = attn.to_q(hidden_states)
885
+
886
+ if encoder_hidden_states is None:
887
+ encoder_hidden_states = hidden_states
888
+ elif attn.norm_cross:
889
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
890
+
891
+ key = attn.to_k(encoder_hidden_states)
892
+ value = attn.to_v(encoder_hidden_states)
893
+
894
+ assert num_tasks == 2 # only support two tasks now
895
+
896
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
897
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
898
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
899
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
900
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
901
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
902
+
903
+
904
+ query = attn.head_to_batch_dim(query).contiguous()
905
+ key = attn.head_to_batch_dim(key).contiguous()
906
+ value = attn.head_to_batch_dim(value).contiguous()
907
+
908
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
909
+ # for flash attention implementation
910
+ # with torch.backends.cuda.sdp_kernel(enable_math=False):
911
+ # hidden_states = F.scaled_dot_product_attention(query, key, value, attn_bias=attention_mask)
912
+ # hidden_states = attn.batch_to_head_dim(hidden_states)
913
+
914
+ # linear proj
915
+ hidden_states = attn.to_out[0](hidden_states)
916
+ # dropout
917
+ hidden_states = attn.to_out[1](hidden_states)
918
+
919
+ if input_ndim == 4:
920
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
921
+
922
+ if attn.residual_connection:
923
+ hidden_states = hidden_states + residual
924
+
925
+ hidden_states = hidden_states / attn.rescale_output_factor
926
+
927
+ return hidden_states
928
+
929
+
930
+ # class JointAttnProcessor:
931
+ # r"""
932
+ # Default processor for performing attention-related computations.
933
+ # """
934
+ #
935
+ # def __call__(
936
+ # self,
937
+ # attn: Attention,
938
+ # hidden_states,
939
+ # encoder_hidden_states=None,
940
+ # attention_mask=None,
941
+ # temb=None,
942
+ # num_tasks=2
943
+ # ):
944
+ #
945
+ # residual = hidden_states
946
+ #
947
+ # if attn.spatial_norm is not None:
948
+ # hidden_states = attn.spatial_norm(hidden_states, temb)
949
+ #
950
+ # input_ndim = hidden_states.ndim
951
+ #
952
+ # if input_ndim == 4:
953
+ # batch_size, channel, height, width = hidden_states.shape
954
+ # hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
955
+ #
956
+ # batch_size, sequence_length, input_dim = (
957
+ # hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
958
+ # )
959
+ # attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
960
+ #
961
+ #
962
+ # if attn.group_norm is not None:
963
+ # hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
964
+ #
965
+ # query = attn.to_q(hidden_states)
966
+ #
967
+ # if encoder_hidden_states is None:
968
+ # encoder_hidden_states = hidden_states
969
+ # elif attn.norm_cross:
970
+ # encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
971
+ #
972
+ # key = attn.to_k(encoder_hidden_states)
973
+ # value = attn.to_v(encoder_hidden_states)
974
+ #
975
+ # assert num_tasks == 2 # only support two tasks now
976
+ #
977
+ # key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
978
+ # value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
979
+ # key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
980
+ # value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
981
+ # key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
982
+ # value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
983
+ #
984
+ #
985
+ # # batch, n_heads, n_tokens, channel
986
+ # query = attn.head_to_batch_dim(query, out_dim=4).contiguous()
987
+ # key = attn.head_to_batch_dim(key, out_dim=4).contiguous()
988
+ # value = attn.head_to_batch_dim(value, out_dim=4).contiguous()
989
+ #
990
+ # # attention_probs = attn.get_attention_scores(query, key, attention_mask)
991
+ # # hidden_states = torch.bmm(attention_probs, value)
992
+ # # hidden_states = attn.batch_to_head_dim(hidden_states)
993
+ #
994
+ # # for flash attention implementation
995
+ # with torch.backends.cuda.sdp_kernel(
996
+ # enable_flash=True,
997
+ # enable_math=False,
998
+ # enable_mem_efficient=True
999
+ # ):
1000
+ # hidden_states = F.scaled_dot_product_attention(query, key, value)
1001
+ #
1002
+ # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, sequence_length, input_dim)
1003
+ #
1004
+ # # linear proj
1005
+ # hidden_states = attn.to_out[0](hidden_states)
1006
+ # # dropout
1007
+ # hidden_states = attn.to_out[1](hidden_states)
1008
+ #
1009
+ # if input_ndim == 4:
1010
+ # hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1011
+ #
1012
+ # if attn.residual_connection:
1013
+ # hidden_states = hidden_states + residual
1014
+ #
1015
+ # hidden_states = hidden_states / attn.rescale_output_factor
1016
+ #
1017
+ # return hidden_states
1018
+
1019
+ class JointAttnProcessor:
1020
+ r"""
1021
+ Default processor for performing attention-related computations.
1022
+ """
1023
+
1024
+ def __call__(
1025
+ self,
1026
+ attn: Attention,
1027
+ hidden_states,
1028
+ encoder_hidden_states=None,
1029
+ attention_mask=None,
1030
+ temb=None,
1031
+ num_tasks=2
1032
+ ):
1033
+
1034
+ residual = hidden_states
1035
+
1036
+ if attn.spatial_norm is not None:
1037
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1038
+
1039
+ input_ndim = hidden_states.ndim
1040
+
1041
+ if input_ndim == 4:
1042
+ batch_size, channel, height, width = hidden_states.shape
1043
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1044
+
1045
+ batch_size, sequence_length, _ = (
1046
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1047
+ )
1048
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1049
+
1050
+ if attn.group_norm is not None:
1051
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1052
+
1053
+ query = attn.to_q(hidden_states)
1054
+
1055
+ if encoder_hidden_states is None:
1056
+ encoder_hidden_states = hidden_states
1057
+ elif attn.norm_cross:
1058
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1059
+
1060
+ key = attn.to_k(encoder_hidden_states)
1061
+ value = attn.to_v(encoder_hidden_states)
1062
+
1063
+ assert num_tasks == 2 # only support two tasks now
1064
+
1065
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
1066
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
1067
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
1068
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
1069
+ key = torch.cat([key] * 2, dim=0) # ( 2 b t) 2d c
1070
+ value = torch.cat([value] * 2, dim=0) # (2 b t) 2d c
1071
+
1072
+ query = attn.head_to_batch_dim(query).contiguous()
1073
+ key = attn.head_to_batch_dim(key).contiguous()
1074
+ value = attn.head_to_batch_dim(value).contiguous()
1075
+
1076
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
1077
+ hidden_states = torch.bmm(attention_probs, value)
1078
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1079
+
1080
+ # linear proj
1081
+ hidden_states = attn.to_out[0](hidden_states)
1082
+ # dropout
1083
+ hidden_states = attn.to_out[1](hidden_states)
1084
+
1085
+ if input_ndim == 4:
1086
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1087
+
1088
+ if attn.residual_connection:
1089
+ hidden_states = hidden_states + residual
1090
+
1091
+ hidden_states = hidden_states / attn.rescale_output_factor
1092
+
1093
+ return hidden_states
mv_diffusion_30/models/unet_mv2d_blocks.py ADDED
@@ -0,0 +1,922 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import is_torch_version, logging
22
+ # from diffusers.models.normalization import AdaGroupNorm
23
+ # from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
24
+ # from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel
25
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
26
+ from mv_diffusion_30.models.transformer_mv2d import TransformerMV2DModel
27
+
28
+ from diffusers.models.unets.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D
29
+ from diffusers.models.unets.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D
30
+
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+
35
+ def get_down_block(
36
+ down_block_type,
37
+ num_layers,
38
+ in_channels,
39
+ out_channels,
40
+ temb_channels,
41
+ add_downsample,
42
+ resnet_eps,
43
+ resnet_act_fn,
44
+ transformer_layers_per_block=1,
45
+ num_attention_heads=None,
46
+ resnet_groups=None,
47
+ cross_attention_dim=None,
48
+ downsample_padding=None,
49
+ dual_cross_attention=False,
50
+ use_linear_projection=False,
51
+ only_cross_attention=False,
52
+ upcast_attention=False,
53
+ resnet_time_scale_shift="default",
54
+ resnet_skip_time_act=False,
55
+ resnet_out_scale_factor=1.0,
56
+ cross_attention_norm=None,
57
+ attention_head_dim=None,
58
+ downsample_type=None,
59
+ num_views=1,
60
+ cd_attention_last: bool = False,
61
+ cd_attention_mid: bool = False,
62
+ multiview_attention: bool = True,
63
+ sparse_mv_attention: bool = False,
64
+ mvcd_attention: bool=False
65
+ ):
66
+ # If attn head dim is not defined, we default it to the number of heads
67
+ if attention_head_dim is None:
68
+ logger.warn(
69
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
70
+ )
71
+ attention_head_dim = num_attention_heads
72
+
73
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
74
+ if down_block_type == "DownBlock2D":
75
+ return DownBlock2D(
76
+ num_layers=num_layers,
77
+ in_channels=in_channels,
78
+ out_channels=out_channels,
79
+ temb_channels=temb_channels,
80
+ add_downsample=add_downsample,
81
+ resnet_eps=resnet_eps,
82
+ resnet_act_fn=resnet_act_fn,
83
+ resnet_groups=resnet_groups,
84
+ downsample_padding=downsample_padding,
85
+ resnet_time_scale_shift=resnet_time_scale_shift,
86
+ )
87
+ elif down_block_type == "ResnetDownsampleBlock2D":
88
+ return ResnetDownsampleBlock2D(
89
+ num_layers=num_layers,
90
+ in_channels=in_channels,
91
+ out_channels=out_channels,
92
+ temb_channels=temb_channels,
93
+ add_downsample=add_downsample,
94
+ resnet_eps=resnet_eps,
95
+ resnet_act_fn=resnet_act_fn,
96
+ resnet_groups=resnet_groups,
97
+ resnet_time_scale_shift=resnet_time_scale_shift,
98
+ skip_time_act=resnet_skip_time_act,
99
+ output_scale_factor=resnet_out_scale_factor,
100
+ )
101
+ elif down_block_type == "AttnDownBlock2D":
102
+ if add_downsample is False:
103
+ downsample_type = None
104
+ else:
105
+ downsample_type = downsample_type or "conv" # default to 'conv'
106
+ return AttnDownBlock2D(
107
+ num_layers=num_layers,
108
+ in_channels=in_channels,
109
+ out_channels=out_channels,
110
+ temb_channels=temb_channels,
111
+ resnet_eps=resnet_eps,
112
+ resnet_act_fn=resnet_act_fn,
113
+ resnet_groups=resnet_groups,
114
+ downsample_padding=downsample_padding,
115
+ attention_head_dim=attention_head_dim,
116
+ resnet_time_scale_shift=resnet_time_scale_shift,
117
+ downsample_type=downsample_type,
118
+ )
119
+ elif down_block_type == "CrossAttnDownBlock2D":
120
+ if cross_attention_dim is None:
121
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
122
+ return CrossAttnDownBlock2D(
123
+ num_layers=num_layers,
124
+ transformer_layers_per_block=transformer_layers_per_block,
125
+ in_channels=in_channels,
126
+ out_channels=out_channels,
127
+ temb_channels=temb_channels,
128
+ add_downsample=add_downsample,
129
+ resnet_eps=resnet_eps,
130
+ resnet_act_fn=resnet_act_fn,
131
+ resnet_groups=resnet_groups,
132
+ downsample_padding=downsample_padding,
133
+ cross_attention_dim=cross_attention_dim,
134
+ num_attention_heads=num_attention_heads,
135
+ dual_cross_attention=dual_cross_attention,
136
+ use_linear_projection=use_linear_projection,
137
+ only_cross_attention=only_cross_attention,
138
+ upcast_attention=upcast_attention,
139
+ resnet_time_scale_shift=resnet_time_scale_shift,
140
+ )
141
+ # custom MV2D attention block
142
+ elif down_block_type == "CrossAttnDownBlockMV2D":
143
+ if cross_attention_dim is None:
144
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D")
145
+ return CrossAttnDownBlockMV2D(
146
+ num_layers=num_layers,
147
+ transformer_layers_per_block=transformer_layers_per_block,
148
+ in_channels=in_channels,
149
+ out_channels=out_channels,
150
+ temb_channels=temb_channels,
151
+ add_downsample=add_downsample,
152
+ resnet_eps=resnet_eps,
153
+ resnet_act_fn=resnet_act_fn,
154
+ resnet_groups=resnet_groups,
155
+ downsample_padding=downsample_padding,
156
+ cross_attention_dim=cross_attention_dim,
157
+ num_attention_heads=num_attention_heads,
158
+ dual_cross_attention=dual_cross_attention,
159
+ use_linear_projection=use_linear_projection,
160
+ only_cross_attention=only_cross_attention,
161
+ upcast_attention=upcast_attention,
162
+ resnet_time_scale_shift=resnet_time_scale_shift,
163
+ num_views=num_views,
164
+ cd_attention_last=cd_attention_last,
165
+ cd_attention_mid=cd_attention_mid,
166
+ multiview_attention=multiview_attention,
167
+ sparse_mv_attention=sparse_mv_attention,
168
+ mvcd_attention=mvcd_attention
169
+ )
170
+ elif down_block_type == "SimpleCrossAttnDownBlock2D":
171
+ if cross_attention_dim is None:
172
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
173
+ return SimpleCrossAttnDownBlock2D(
174
+ num_layers=num_layers,
175
+ in_channels=in_channels,
176
+ out_channels=out_channels,
177
+ temb_channels=temb_channels,
178
+ add_downsample=add_downsample,
179
+ resnet_eps=resnet_eps,
180
+ resnet_act_fn=resnet_act_fn,
181
+ resnet_groups=resnet_groups,
182
+ cross_attention_dim=cross_attention_dim,
183
+ attention_head_dim=attention_head_dim,
184
+ resnet_time_scale_shift=resnet_time_scale_shift,
185
+ skip_time_act=resnet_skip_time_act,
186
+ output_scale_factor=resnet_out_scale_factor,
187
+ only_cross_attention=only_cross_attention,
188
+ cross_attention_norm=cross_attention_norm,
189
+ )
190
+ elif down_block_type == "SkipDownBlock2D":
191
+ return SkipDownBlock2D(
192
+ num_layers=num_layers,
193
+ in_channels=in_channels,
194
+ out_channels=out_channels,
195
+ temb_channels=temb_channels,
196
+ add_downsample=add_downsample,
197
+ resnet_eps=resnet_eps,
198
+ resnet_act_fn=resnet_act_fn,
199
+ downsample_padding=downsample_padding,
200
+ resnet_time_scale_shift=resnet_time_scale_shift,
201
+ )
202
+ elif down_block_type == "AttnSkipDownBlock2D":
203
+ return AttnSkipDownBlock2D(
204
+ num_layers=num_layers,
205
+ in_channels=in_channels,
206
+ out_channels=out_channels,
207
+ temb_channels=temb_channels,
208
+ add_downsample=add_downsample,
209
+ resnet_eps=resnet_eps,
210
+ resnet_act_fn=resnet_act_fn,
211
+ attention_head_dim=attention_head_dim,
212
+ resnet_time_scale_shift=resnet_time_scale_shift,
213
+ )
214
+ elif down_block_type == "DownEncoderBlock2D":
215
+ return DownEncoderBlock2D(
216
+ num_layers=num_layers,
217
+ in_channels=in_channels,
218
+ out_channels=out_channels,
219
+ add_downsample=add_downsample,
220
+ resnet_eps=resnet_eps,
221
+ resnet_act_fn=resnet_act_fn,
222
+ resnet_groups=resnet_groups,
223
+ downsample_padding=downsample_padding,
224
+ resnet_time_scale_shift=resnet_time_scale_shift,
225
+ )
226
+ elif down_block_type == "AttnDownEncoderBlock2D":
227
+ return AttnDownEncoderBlock2D(
228
+ num_layers=num_layers,
229
+ in_channels=in_channels,
230
+ out_channels=out_channels,
231
+ add_downsample=add_downsample,
232
+ resnet_eps=resnet_eps,
233
+ resnet_act_fn=resnet_act_fn,
234
+ resnet_groups=resnet_groups,
235
+ downsample_padding=downsample_padding,
236
+ attention_head_dim=attention_head_dim,
237
+ resnet_time_scale_shift=resnet_time_scale_shift,
238
+ )
239
+ elif down_block_type == "KDownBlock2D":
240
+ return KDownBlock2D(
241
+ num_layers=num_layers,
242
+ in_channels=in_channels,
243
+ out_channels=out_channels,
244
+ temb_channels=temb_channels,
245
+ add_downsample=add_downsample,
246
+ resnet_eps=resnet_eps,
247
+ resnet_act_fn=resnet_act_fn,
248
+ )
249
+ elif down_block_type == "KCrossAttnDownBlock2D":
250
+ return KCrossAttnDownBlock2D(
251
+ num_layers=num_layers,
252
+ in_channels=in_channels,
253
+ out_channels=out_channels,
254
+ temb_channels=temb_channels,
255
+ add_downsample=add_downsample,
256
+ resnet_eps=resnet_eps,
257
+ resnet_act_fn=resnet_act_fn,
258
+ cross_attention_dim=cross_attention_dim,
259
+ attention_head_dim=attention_head_dim,
260
+ add_self_attention=True if not add_downsample else False,
261
+ )
262
+ raise ValueError(f"{down_block_type} does not exist.")
263
+
264
+
265
+ def get_up_block(
266
+ up_block_type,
267
+ num_layers,
268
+ in_channels,
269
+ out_channels,
270
+ prev_output_channel,
271
+ temb_channels,
272
+ add_upsample,
273
+ resnet_eps,
274
+ resnet_act_fn,
275
+ transformer_layers_per_block=1,
276
+ num_attention_heads=None,
277
+ resnet_groups=None,
278
+ cross_attention_dim=None,
279
+ dual_cross_attention=False,
280
+ use_linear_projection=False,
281
+ only_cross_attention=False,
282
+ upcast_attention=False,
283
+ resnet_time_scale_shift="default",
284
+ resnet_skip_time_act=False,
285
+ resnet_out_scale_factor=1.0,
286
+ cross_attention_norm=None,
287
+ attention_head_dim=None,
288
+ upsample_type=None,
289
+ num_views=1,
290
+ cd_attention_last: bool = False,
291
+ cd_attention_mid: bool = False,
292
+ multiview_attention: bool = True,
293
+ sparse_mv_attention: bool = False,
294
+ mvcd_attention: bool=False
295
+ ):
296
+ # If attn head dim is not defined, we default it to the number of heads
297
+ if attention_head_dim is None:
298
+ logger.warn(
299
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
300
+ )
301
+ attention_head_dim = num_attention_heads
302
+
303
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
304
+ if up_block_type == "UpBlock2D":
305
+ return UpBlock2D(
306
+ num_layers=num_layers,
307
+ in_channels=in_channels,
308
+ out_channels=out_channels,
309
+ prev_output_channel=prev_output_channel,
310
+ temb_channels=temb_channels,
311
+ add_upsample=add_upsample,
312
+ resnet_eps=resnet_eps,
313
+ resnet_act_fn=resnet_act_fn,
314
+ resnet_groups=resnet_groups,
315
+ resnet_time_scale_shift=resnet_time_scale_shift,
316
+ )
317
+ elif up_block_type == "ResnetUpsampleBlock2D":
318
+ return ResnetUpsampleBlock2D(
319
+ num_layers=num_layers,
320
+ in_channels=in_channels,
321
+ out_channels=out_channels,
322
+ prev_output_channel=prev_output_channel,
323
+ temb_channels=temb_channels,
324
+ add_upsample=add_upsample,
325
+ resnet_eps=resnet_eps,
326
+ resnet_act_fn=resnet_act_fn,
327
+ resnet_groups=resnet_groups,
328
+ resnet_time_scale_shift=resnet_time_scale_shift,
329
+ skip_time_act=resnet_skip_time_act,
330
+ output_scale_factor=resnet_out_scale_factor,
331
+ )
332
+ elif up_block_type == "CrossAttnUpBlock2D":
333
+ if cross_attention_dim is None:
334
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
335
+ return CrossAttnUpBlock2D(
336
+ num_layers=num_layers,
337
+ transformer_layers_per_block=transformer_layers_per_block,
338
+ in_channels=in_channels,
339
+ out_channels=out_channels,
340
+ prev_output_channel=prev_output_channel,
341
+ temb_channels=temb_channels,
342
+ add_upsample=add_upsample,
343
+ resnet_eps=resnet_eps,
344
+ resnet_act_fn=resnet_act_fn,
345
+ resnet_groups=resnet_groups,
346
+ cross_attention_dim=cross_attention_dim,
347
+ num_attention_heads=num_attention_heads,
348
+ dual_cross_attention=dual_cross_attention,
349
+ use_linear_projection=use_linear_projection,
350
+ only_cross_attention=only_cross_attention,
351
+ upcast_attention=upcast_attention,
352
+ resnet_time_scale_shift=resnet_time_scale_shift,
353
+ )
354
+ # custom MV2D attention block
355
+ elif up_block_type == "CrossAttnUpBlockMV2D":
356
+ if cross_attention_dim is None:
357
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D")
358
+ return CrossAttnUpBlockMV2D(
359
+ num_layers=num_layers,
360
+ transformer_layers_per_block=transformer_layers_per_block,
361
+ in_channels=in_channels,
362
+ out_channels=out_channels,
363
+ prev_output_channel=prev_output_channel,
364
+ temb_channels=temb_channels,
365
+ add_upsample=add_upsample,
366
+ resnet_eps=resnet_eps,
367
+ resnet_act_fn=resnet_act_fn,
368
+ resnet_groups=resnet_groups,
369
+ cross_attention_dim=cross_attention_dim,
370
+ num_attention_heads=num_attention_heads,
371
+ dual_cross_attention=dual_cross_attention,
372
+ use_linear_projection=use_linear_projection,
373
+ only_cross_attention=only_cross_attention,
374
+ upcast_attention=upcast_attention,
375
+ resnet_time_scale_shift=resnet_time_scale_shift,
376
+ num_views=num_views,
377
+ cd_attention_last=cd_attention_last,
378
+ cd_attention_mid=cd_attention_mid,
379
+ multiview_attention=multiview_attention,
380
+ sparse_mv_attention=sparse_mv_attention,
381
+ mvcd_attention=mvcd_attention
382
+ )
383
+ elif up_block_type == "SimpleCrossAttnUpBlock2D":
384
+ if cross_attention_dim is None:
385
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
386
+ return SimpleCrossAttnUpBlock2D(
387
+ num_layers=num_layers,
388
+ in_channels=in_channels,
389
+ out_channels=out_channels,
390
+ prev_output_channel=prev_output_channel,
391
+ temb_channels=temb_channels,
392
+ add_upsample=add_upsample,
393
+ resnet_eps=resnet_eps,
394
+ resnet_act_fn=resnet_act_fn,
395
+ resnet_groups=resnet_groups,
396
+ cross_attention_dim=cross_attention_dim,
397
+ attention_head_dim=attention_head_dim,
398
+ resnet_time_scale_shift=resnet_time_scale_shift,
399
+ skip_time_act=resnet_skip_time_act,
400
+ output_scale_factor=resnet_out_scale_factor,
401
+ only_cross_attention=only_cross_attention,
402
+ cross_attention_norm=cross_attention_norm,
403
+ )
404
+ elif up_block_type == "AttnUpBlock2D":
405
+ if add_upsample is False:
406
+ upsample_type = None
407
+ else:
408
+ upsample_type = upsample_type or "conv" # default to 'conv'
409
+
410
+ return AttnUpBlock2D(
411
+ num_layers=num_layers,
412
+ in_channels=in_channels,
413
+ out_channels=out_channels,
414
+ prev_output_channel=prev_output_channel,
415
+ temb_channels=temb_channels,
416
+ resnet_eps=resnet_eps,
417
+ resnet_act_fn=resnet_act_fn,
418
+ resnet_groups=resnet_groups,
419
+ attention_head_dim=attention_head_dim,
420
+ resnet_time_scale_shift=resnet_time_scale_shift,
421
+ upsample_type=upsample_type,
422
+ )
423
+ elif up_block_type == "SkipUpBlock2D":
424
+ return SkipUpBlock2D(
425
+ num_layers=num_layers,
426
+ in_channels=in_channels,
427
+ out_channels=out_channels,
428
+ prev_output_channel=prev_output_channel,
429
+ temb_channels=temb_channels,
430
+ add_upsample=add_upsample,
431
+ resnet_eps=resnet_eps,
432
+ resnet_act_fn=resnet_act_fn,
433
+ resnet_time_scale_shift=resnet_time_scale_shift,
434
+ )
435
+ elif up_block_type == "AttnSkipUpBlock2D":
436
+ return AttnSkipUpBlock2D(
437
+ num_layers=num_layers,
438
+ in_channels=in_channels,
439
+ out_channels=out_channels,
440
+ prev_output_channel=prev_output_channel,
441
+ temb_channels=temb_channels,
442
+ add_upsample=add_upsample,
443
+ resnet_eps=resnet_eps,
444
+ resnet_act_fn=resnet_act_fn,
445
+ attention_head_dim=attention_head_dim,
446
+ resnet_time_scale_shift=resnet_time_scale_shift,
447
+ )
448
+ elif up_block_type == "UpDecoderBlock2D":
449
+ return UpDecoderBlock2D(
450
+ num_layers=num_layers,
451
+ in_channels=in_channels,
452
+ out_channels=out_channels,
453
+ add_upsample=add_upsample,
454
+ resnet_eps=resnet_eps,
455
+ resnet_act_fn=resnet_act_fn,
456
+ resnet_groups=resnet_groups,
457
+ resnet_time_scale_shift=resnet_time_scale_shift,
458
+ temb_channels=temb_channels,
459
+ )
460
+ elif up_block_type == "AttnUpDecoderBlock2D":
461
+ return AttnUpDecoderBlock2D(
462
+ num_layers=num_layers,
463
+ in_channels=in_channels,
464
+ out_channels=out_channels,
465
+ add_upsample=add_upsample,
466
+ resnet_eps=resnet_eps,
467
+ resnet_act_fn=resnet_act_fn,
468
+ resnet_groups=resnet_groups,
469
+ attention_head_dim=attention_head_dim,
470
+ resnet_time_scale_shift=resnet_time_scale_shift,
471
+ temb_channels=temb_channels,
472
+ )
473
+ elif up_block_type == "KUpBlock2D":
474
+ return KUpBlock2D(
475
+ num_layers=num_layers,
476
+ in_channels=in_channels,
477
+ out_channels=out_channels,
478
+ temb_channels=temb_channels,
479
+ add_upsample=add_upsample,
480
+ resnet_eps=resnet_eps,
481
+ resnet_act_fn=resnet_act_fn,
482
+ )
483
+ elif up_block_type == "KCrossAttnUpBlock2D":
484
+ return KCrossAttnUpBlock2D(
485
+ num_layers=num_layers,
486
+ in_channels=in_channels,
487
+ out_channels=out_channels,
488
+ temb_channels=temb_channels,
489
+ add_upsample=add_upsample,
490
+ resnet_eps=resnet_eps,
491
+ resnet_act_fn=resnet_act_fn,
492
+ cross_attention_dim=cross_attention_dim,
493
+ attention_head_dim=attention_head_dim,
494
+ )
495
+
496
+ raise ValueError(f"{up_block_type} does not exist.")
497
+
498
+
499
+ class UNetMidBlockMV2DCrossAttn(nn.Module):
500
+ def __init__(
501
+ self,
502
+ in_channels: int,
503
+ temb_channels: int,
504
+ dropout: float = 0.0,
505
+ num_layers: int = 1,
506
+ transformer_layers_per_block: int = 1,
507
+ resnet_eps: float = 1e-6,
508
+ resnet_time_scale_shift: str = "default",
509
+ resnet_act_fn: str = "swish",
510
+ resnet_groups: int = 32,
511
+ resnet_pre_norm: bool = True,
512
+ num_attention_heads=1,
513
+ output_scale_factor=1.0,
514
+ cross_attention_dim=1280,
515
+ dual_cross_attention=False,
516
+ use_linear_projection=False,
517
+ upcast_attention=False,
518
+ num_views: int = 1,
519
+ cd_attention_last: bool = False,
520
+ cd_attention_mid: bool = False,
521
+ multiview_attention: bool = True,
522
+ sparse_mv_attention: bool = False,
523
+ mvcd_attention: bool=False
524
+ ):
525
+ super().__init__()
526
+
527
+ self.has_cross_attention = True
528
+ self.num_attention_heads = num_attention_heads
529
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
530
+
531
+ # there is always at least one resnet
532
+ resnets = [
533
+ ResnetBlock2D(
534
+ in_channels=in_channels,
535
+ out_channels=in_channels,
536
+ temb_channels=temb_channels,
537
+ eps=resnet_eps,
538
+ groups=resnet_groups,
539
+ dropout=dropout,
540
+ time_embedding_norm=resnet_time_scale_shift,
541
+ non_linearity=resnet_act_fn,
542
+ output_scale_factor=output_scale_factor,
543
+ pre_norm=resnet_pre_norm,
544
+ )
545
+ ]
546
+ attentions = []
547
+
548
+ for _ in range(num_layers):
549
+ if not dual_cross_attention:
550
+ attentions.append(
551
+ TransformerMV2DModel(
552
+ num_attention_heads,
553
+ in_channels // num_attention_heads,
554
+ in_channels=in_channels,
555
+ num_layers=transformer_layers_per_block,
556
+ cross_attention_dim=cross_attention_dim,
557
+ norm_num_groups=resnet_groups,
558
+ use_linear_projection=use_linear_projection,
559
+ upcast_attention=upcast_attention,
560
+ num_views=num_views,
561
+ cd_attention_last=cd_attention_last,
562
+ cd_attention_mid=cd_attention_mid,
563
+ multiview_attention=multiview_attention,
564
+ sparse_mv_attention=sparse_mv_attention,
565
+ mvcd_attention=mvcd_attention
566
+ )
567
+ )
568
+ else:
569
+ raise NotImplementedError
570
+ resnets.append(
571
+ ResnetBlock2D(
572
+ in_channels=in_channels,
573
+ out_channels=in_channels,
574
+ temb_channels=temb_channels,
575
+ eps=resnet_eps,
576
+ groups=resnet_groups,
577
+ dropout=dropout,
578
+ time_embedding_norm=resnet_time_scale_shift,
579
+ non_linearity=resnet_act_fn,
580
+ output_scale_factor=output_scale_factor,
581
+ pre_norm=resnet_pre_norm,
582
+ )
583
+ )
584
+
585
+ self.attentions = nn.ModuleList(attentions)
586
+ self.resnets = nn.ModuleList(resnets)
587
+
588
+ def forward(
589
+ self,
590
+ hidden_states: torch.FloatTensor,
591
+ temb: Optional[torch.FloatTensor] = None,
592
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
593
+ attention_mask: Optional[torch.FloatTensor] = None,
594
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
595
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
596
+ ) -> torch.FloatTensor:
597
+ hidden_states = self.resnets[0](hidden_states, temb)
598
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
599
+ hidden_states = attn(
600
+ hidden_states,
601
+ encoder_hidden_states=encoder_hidden_states,
602
+ cross_attention_kwargs=cross_attention_kwargs,
603
+ attention_mask=attention_mask,
604
+ encoder_attention_mask=encoder_attention_mask,
605
+ return_dict=False,
606
+ )[0]
607
+ hidden_states = resnet(hidden_states, temb)
608
+
609
+ return hidden_states
610
+
611
+
612
+ class CrossAttnUpBlockMV2D(nn.Module):
613
+ def __init__(
614
+ self,
615
+ in_channels: int,
616
+ out_channels: int,
617
+ prev_output_channel: int,
618
+ temb_channels: int,
619
+ dropout: float = 0.0,
620
+ num_layers: int = 1,
621
+ transformer_layers_per_block: int = 1,
622
+ resnet_eps: float = 1e-6,
623
+ resnet_time_scale_shift: str = "default",
624
+ resnet_act_fn: str = "swish",
625
+ resnet_groups: int = 32,
626
+ resnet_pre_norm: bool = True,
627
+ num_attention_heads=1,
628
+ cross_attention_dim=1280,
629
+ output_scale_factor=1.0,
630
+ add_upsample=True,
631
+ dual_cross_attention=False,
632
+ use_linear_projection=False,
633
+ only_cross_attention=False,
634
+ upcast_attention=False,
635
+ num_views: int = 1,
636
+ cd_attention_last: bool = False,
637
+ cd_attention_mid: bool = False,
638
+ multiview_attention: bool = True,
639
+ sparse_mv_attention: bool = False,
640
+ mvcd_attention: bool=False
641
+ ):
642
+ super().__init__()
643
+ resnets = []
644
+ attentions = []
645
+
646
+ self.has_cross_attention = True
647
+ self.num_attention_heads = num_attention_heads
648
+
649
+ for i in range(num_layers):
650
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
651
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
652
+
653
+ resnets.append(
654
+ ResnetBlock2D(
655
+ in_channels=resnet_in_channels + res_skip_channels,
656
+ out_channels=out_channels,
657
+ temb_channels=temb_channels,
658
+ eps=resnet_eps,
659
+ groups=resnet_groups,
660
+ dropout=dropout,
661
+ time_embedding_norm=resnet_time_scale_shift,
662
+ non_linearity=resnet_act_fn,
663
+ output_scale_factor=output_scale_factor,
664
+ pre_norm=resnet_pre_norm,
665
+ )
666
+ )
667
+ if not dual_cross_attention:
668
+ attentions.append(
669
+ TransformerMV2DModel(
670
+ num_attention_heads,
671
+ out_channels // num_attention_heads,
672
+ in_channels=out_channels,
673
+ num_layers=transformer_layers_per_block,
674
+ cross_attention_dim=cross_attention_dim,
675
+ norm_num_groups=resnet_groups,
676
+ use_linear_projection=use_linear_projection,
677
+ only_cross_attention=only_cross_attention,
678
+ upcast_attention=upcast_attention,
679
+ num_views=num_views,
680
+ cd_attention_last=cd_attention_last,
681
+ cd_attention_mid=cd_attention_mid,
682
+ multiview_attention=multiview_attention,
683
+ sparse_mv_attention=sparse_mv_attention,
684
+ mvcd_attention=mvcd_attention
685
+ )
686
+ )
687
+ else:
688
+ raise NotImplementedError
689
+ self.attentions = nn.ModuleList(attentions)
690
+ self.resnets = nn.ModuleList(resnets)
691
+
692
+ if add_upsample:
693
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
694
+ else:
695
+ self.upsamplers = None
696
+
697
+ self.gradient_checkpointing = False
698
+
699
+ def forward(
700
+ self,
701
+ hidden_states: torch.FloatTensor,
702
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
703
+ temb: Optional[torch.FloatTensor] = None,
704
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
705
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
706
+ upsample_size: Optional[int] = None,
707
+ attention_mask: Optional[torch.FloatTensor] = None,
708
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
709
+ ):
710
+ for resnet, attn in zip(self.resnets, self.attentions):
711
+ # pop res hidden states
712
+ res_hidden_states = res_hidden_states_tuple[-1]
713
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
714
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
715
+
716
+ if self.training and self.gradient_checkpointing:
717
+
718
+ def create_custom_forward(module, return_dict=None):
719
+ def custom_forward(*inputs):
720
+ if return_dict is not None:
721
+ return module(*inputs, return_dict=return_dict)
722
+ else:
723
+ return module(*inputs)
724
+
725
+ return custom_forward
726
+
727
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
728
+ hidden_states = torch.utils.checkpoint.checkpoint(
729
+ create_custom_forward(resnet),
730
+ hidden_states,
731
+ temb,
732
+ **ckpt_kwargs,
733
+ )
734
+ hidden_states = torch.utils.checkpoint.checkpoint(
735
+ create_custom_forward(attn, return_dict=False),
736
+ hidden_states,
737
+ encoder_hidden_states,
738
+ None, # timestep
739
+ None, # class_labels
740
+ cross_attention_kwargs,
741
+ attention_mask,
742
+ encoder_attention_mask,
743
+ **ckpt_kwargs,
744
+ )[0]
745
+ else:
746
+ hidden_states = resnet(hidden_states, temb)
747
+ hidden_states = attn(
748
+ hidden_states,
749
+ encoder_hidden_states=encoder_hidden_states,
750
+ cross_attention_kwargs=cross_attention_kwargs,
751
+ attention_mask=attention_mask,
752
+ encoder_attention_mask=encoder_attention_mask,
753
+ return_dict=False,
754
+ )[0]
755
+
756
+ if self.upsamplers is not None:
757
+ for upsampler in self.upsamplers:
758
+ hidden_states = upsampler(hidden_states, upsample_size)
759
+
760
+ return hidden_states
761
+
762
+
763
+ class CrossAttnDownBlockMV2D(nn.Module):
764
+ def __init__(
765
+ self,
766
+ in_channels: int,
767
+ out_channels: int,
768
+ temb_channels: int,
769
+ dropout: float = 0.0,
770
+ num_layers: int = 1,
771
+ transformer_layers_per_block: int = 1,
772
+ resnet_eps: float = 1e-6,
773
+ resnet_time_scale_shift: str = "default",
774
+ resnet_act_fn: str = "swish",
775
+ resnet_groups: int = 32,
776
+ resnet_pre_norm: bool = True,
777
+ num_attention_heads=1,
778
+ cross_attention_dim=1280,
779
+ output_scale_factor=1.0,
780
+ downsample_padding=1,
781
+ add_downsample=True,
782
+ dual_cross_attention=False,
783
+ use_linear_projection=False,
784
+ only_cross_attention=False,
785
+ upcast_attention=False,
786
+ num_views: int = 1,
787
+ cd_attention_last: bool = False,
788
+ cd_attention_mid: bool = False,
789
+ multiview_attention: bool = True,
790
+ sparse_mv_attention: bool = False,
791
+ mvcd_attention: bool=False
792
+ ):
793
+ super().__init__()
794
+ resnets = []
795
+ attentions = []
796
+
797
+ self.has_cross_attention = True
798
+ self.num_attention_heads = num_attention_heads
799
+
800
+ for i in range(num_layers):
801
+ in_channels = in_channels if i == 0 else out_channels
802
+ resnets.append(
803
+ ResnetBlock2D(
804
+ in_channels=in_channels,
805
+ out_channels=out_channels,
806
+ temb_channels=temb_channels,
807
+ eps=resnet_eps,
808
+ groups=resnet_groups,
809
+ dropout=dropout,
810
+ time_embedding_norm=resnet_time_scale_shift,
811
+ non_linearity=resnet_act_fn,
812
+ output_scale_factor=output_scale_factor,
813
+ pre_norm=resnet_pre_norm,
814
+ )
815
+ )
816
+ if not dual_cross_attention:
817
+ attentions.append(
818
+ TransformerMV2DModel(
819
+ num_attention_heads,
820
+ out_channels // num_attention_heads,
821
+ in_channels=out_channels,
822
+ num_layers=transformer_layers_per_block,
823
+ cross_attention_dim=cross_attention_dim,
824
+ norm_num_groups=resnet_groups,
825
+ use_linear_projection=use_linear_projection,
826
+ only_cross_attention=only_cross_attention,
827
+ upcast_attention=upcast_attention,
828
+ num_views=num_views,
829
+ cd_attention_last=cd_attention_last,
830
+ cd_attention_mid=cd_attention_mid,
831
+ multiview_attention=multiview_attention,
832
+ sparse_mv_attention=sparse_mv_attention,
833
+ mvcd_attention=mvcd_attention
834
+ )
835
+ )
836
+ else:
837
+ raise NotImplementedError
838
+ self.attentions = nn.ModuleList(attentions)
839
+ self.resnets = nn.ModuleList(resnets)
840
+
841
+ if add_downsample:
842
+ self.downsamplers = nn.ModuleList(
843
+ [
844
+ Downsample2D(
845
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
846
+ )
847
+ ]
848
+ )
849
+ else:
850
+ self.downsamplers = None
851
+
852
+ self.gradient_checkpointing = False
853
+
854
+ def forward(
855
+ self,
856
+ hidden_states: torch.FloatTensor,
857
+ temb: Optional[torch.FloatTensor] = None,
858
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
859
+ attention_mask: Optional[torch.FloatTensor] = None,
860
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
861
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
862
+ additional_residuals=None,
863
+ ):
864
+ output_states = ()
865
+
866
+ blocks = list(zip(self.resnets, self.attentions))
867
+
868
+ for i, (resnet, attn) in enumerate(blocks):
869
+ if self.training and self.gradient_checkpointing:
870
+
871
+ def create_custom_forward(module, return_dict=None):
872
+ def custom_forward(*inputs):
873
+ if return_dict is not None:
874
+ return module(*inputs, return_dict=return_dict)
875
+ else:
876
+ return module(*inputs)
877
+
878
+ return custom_forward
879
+
880
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
881
+ hidden_states = torch.utils.checkpoint.checkpoint(
882
+ create_custom_forward(resnet),
883
+ hidden_states,
884
+ temb,
885
+ **ckpt_kwargs,
886
+ )
887
+ hidden_states = torch.utils.checkpoint.checkpoint(
888
+ create_custom_forward(attn, return_dict=False),
889
+ hidden_states,
890
+ encoder_hidden_states,
891
+ None, # timestep
892
+ None, # class_labels
893
+ cross_attention_kwargs,
894
+ attention_mask,
895
+ encoder_attention_mask,
896
+ **ckpt_kwargs,
897
+ )[0]
898
+ else:
899
+ hidden_states = resnet(hidden_states, temb)
900
+ hidden_states = attn(
901
+ hidden_states,
902
+ encoder_hidden_states=encoder_hidden_states,
903
+ cross_attention_kwargs=cross_attention_kwargs,
904
+ attention_mask=attention_mask,
905
+ encoder_attention_mask=encoder_attention_mask,
906
+ return_dict=False,
907
+ )[0]
908
+
909
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
910
+ if i == len(blocks) - 1 and additional_residuals is not None:
911
+ hidden_states = hidden_states + additional_residuals
912
+
913
+ output_states = output_states + (hidden_states,)
914
+
915
+ if self.downsamplers is not None:
916
+ for downsampler in self.downsamplers:
917
+ hidden_states = downsampler(hidden_states)
918
+
919
+ output_states = output_states + (hidden_states,)
920
+
921
+ return hidden_states, output_states
922
+
mv_diffusion_30/models/unet_mv2d_condition.py ADDED
@@ -0,0 +1,1491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+ import os
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import UNet2DConditionLoadersMixin
24
+ from diffusers.utils import BaseOutput, logging
25
+ from diffusers.models.activations import get_activation
26
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
27
+ from diffusers.models.embeddings import (
28
+ GaussianFourierProjection,
29
+ ImageHintTimeEmbedding,
30
+ ImageProjection,
31
+ ImageTimeEmbedding,
32
+ TextImageProjection,
33
+ TextImageTimeEmbedding,
34
+ TextTimeEmbedding,
35
+ TimestepEmbedding,
36
+ Timesteps,
37
+ )
38
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
39
+ from diffusers.models.unets.unet_2d_blocks import (
40
+ CrossAttnDownBlock2D,
41
+ CrossAttnUpBlock2D,
42
+ DownBlock2D,
43
+ UNetMidBlock2DCrossAttn,
44
+ UNetMidBlock2DSimpleCrossAttn,
45
+ UpBlock2D,
46
+ )
47
+ from diffusers.utils import (
48
+ CONFIG_NAME,
49
+ HF_MODULES_CACHE,
50
+ FLAX_WEIGHTS_NAME,
51
+ SAFETENSORS_WEIGHTS_NAME,
52
+ WEIGHTS_NAME,
53
+ _add_variant,
54
+ _get_model_file,
55
+ deprecate,
56
+ is_accelerate_available,
57
+ is_safetensors_available,
58
+ is_torch_version,
59
+ logging,
60
+ )
61
+ from diffusers import __version__
62
+ from mv_diffusion_30.models.unet_mv2d_blocks import (
63
+ CrossAttnDownBlockMV2D,
64
+ CrossAttnUpBlockMV2D,
65
+ UNetMidBlockMV2DCrossAttn,
66
+ get_down_block,
67
+ get_up_block,
68
+ )
69
+
70
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
71
+
72
+
73
+ @dataclass
74
+ class UNetMV2DConditionOutput(BaseOutput):
75
+ """
76
+ The output of [`UNet2DConditionModel`].
77
+
78
+ Args:
79
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
80
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
81
+ """
82
+
83
+ sample: torch.FloatTensor = None
84
+
85
+
86
+ class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
87
+ r"""
88
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
89
+ shaped output.
90
+
91
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
92
+ for all models (such as downloading or saving).
93
+
94
+ Parameters:
95
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
96
+ Height and width of input/output sample.
97
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
98
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
99
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
100
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
101
+ Whether to flip the sin to cos in the time embedding.
102
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
103
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
104
+ The tuple of downsample blocks to use.
105
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
106
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
107
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
108
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
109
+ The tuple of upsample blocks to use.
110
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
111
+ Whether to include self-attention in the basic transformer blocks, see
112
+ [`~models.attention.BasicTransformerBlock`].
113
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
114
+ The tuple of output channels for each block.
115
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
116
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
117
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
118
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
119
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
120
+ If `None`, normalization and activation layers is skipped in post-processing.
121
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
122
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
123
+ The dimension of the cross attention features.
124
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
125
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
126
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
127
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
128
+ encoder_hid_dim (`int`, *optional*, defaults to None):
129
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
130
+ dimension to `cross_attention_dim`.
131
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
132
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
133
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
134
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
135
+ num_attention_heads (`int`, *optional*):
136
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
137
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
138
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
139
+ class_embed_type (`str`, *optional*, defaults to `None`):
140
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
141
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
142
+ addition_embed_type (`str`, *optional*, defaults to `None`):
143
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
144
+ "text". "text" will use the `TextTimeEmbedding` layer.
145
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
146
+ Dimension for the timestep embeddings.
147
+ num_class_embeds (`int`, *optional*, defaults to `None`):
148
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
149
+ class conditioning with `class_embed_type` equal to `None`.
150
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
151
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
152
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
153
+ An optional override for the dimension of the projected time embedding.
154
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
155
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
156
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
157
+ timestep_post_act (`str`, *optional*, defaults to `None`):
158
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
159
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
160
+ The dimension of `cond_proj` layer in the timestep embedding.
161
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
162
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
163
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
164
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
165
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
166
+ embeddings with the class embeddings.
167
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
168
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
169
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
170
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
171
+ otherwise.
172
+ """
173
+
174
+ _supports_gradient_checkpointing = True
175
+
176
+ @register_to_config
177
+ def __init__(
178
+ self,
179
+ sample_size: Optional[int] = None,
180
+ in_channels: int = 4,
181
+ out_channels: int = 4,
182
+ center_input_sample: bool = False,
183
+ flip_sin_to_cos: bool = True,
184
+ freq_shift: int = 0,
185
+ down_block_types: Tuple[str] = (
186
+ "CrossAttnDownBlockMV2D",
187
+ "CrossAttnDownBlockMV2D",
188
+ "CrossAttnDownBlockMV2D",
189
+ "DownBlock2D",
190
+ ),
191
+ mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
192
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
193
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
194
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
195
+ layers_per_block: Union[int, Tuple[int]] = 2,
196
+ downsample_padding: int = 1,
197
+ mid_block_scale_factor: float = 1,
198
+ act_fn: str = "silu",
199
+ norm_num_groups: Optional[int] = 32,
200
+ norm_eps: float = 1e-5,
201
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
202
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
203
+ encoder_hid_dim: Optional[int] = None,
204
+ encoder_hid_dim_type: Optional[str] = None,
205
+ attention_head_dim: Union[int, Tuple[int]] = 8,
206
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
207
+ dual_cross_attention: bool = False,
208
+ use_linear_projection: bool = False,
209
+ class_embed_type: Optional[str] = None,
210
+ addition_embed_type: Optional[str] = None,
211
+ addition_time_embed_dim: Optional[int] = None,
212
+ num_class_embeds: Optional[int] = None,
213
+ upcast_attention: bool = False,
214
+ resnet_time_scale_shift: str = "default",
215
+ resnet_skip_time_act: bool = False,
216
+ resnet_out_scale_factor: int = 1.0,
217
+ time_embedding_type: str = "positional",
218
+ time_embedding_dim: Optional[int] = None,
219
+ time_embedding_act_fn: Optional[str] = None,
220
+ timestep_post_act: Optional[str] = None,
221
+ time_cond_proj_dim: Optional[int] = None,
222
+ conv_in_kernel: int = 3,
223
+ conv_out_kernel: int = 3,
224
+ projection_class_embeddings_input_dim: Optional[int] = None,
225
+ class_embeddings_concat: bool = False,
226
+ mid_block_only_cross_attention: Optional[bool] = None,
227
+ cross_attention_norm: Optional[str] = None,
228
+ addition_embed_type_num_heads=64,
229
+ num_views: int = 1,
230
+ cd_attention_last: bool = False,
231
+ cd_attention_mid: bool = False,
232
+ multiview_attention: bool = True,
233
+ sparse_mv_attention: bool = False,
234
+ mvcd_attention: bool = False
235
+ ):
236
+ super().__init__()
237
+
238
+ self.sample_size = sample_size
239
+
240
+ if num_attention_heads is not None:
241
+ raise ValueError(
242
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
243
+ )
244
+
245
+ # If `num_attention_heads` is not defined (which is the case for most models)
246
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
247
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
248
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
249
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
250
+ # which is why we correct for the naming here.
251
+ num_attention_heads = num_attention_heads or attention_head_dim
252
+
253
+ # Check inputs
254
+ if len(down_block_types) != len(up_block_types):
255
+ raise ValueError(
256
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
257
+ )
258
+
259
+ if len(block_out_channels) != len(down_block_types):
260
+ raise ValueError(
261
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
262
+ )
263
+
264
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
265
+ raise ValueError(
266
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
267
+ )
268
+
269
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
270
+ raise ValueError(
271
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
272
+ )
273
+
274
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
275
+ raise ValueError(
276
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
277
+ )
278
+
279
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
280
+ raise ValueError(
281
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
282
+ )
283
+
284
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
285
+ raise ValueError(
286
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
287
+ )
288
+
289
+ # input
290
+ conv_in_padding = (conv_in_kernel - 1) // 2
291
+ self.conv_in = nn.Conv2d(
292
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
293
+ )
294
+
295
+ # time
296
+ if time_embedding_type == "fourier":
297
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
298
+ if time_embed_dim % 2 != 0:
299
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
300
+ self.time_proj = GaussianFourierProjection(
301
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
302
+ )
303
+ timestep_input_dim = time_embed_dim
304
+ elif time_embedding_type == "positional":
305
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
306
+
307
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
308
+ timestep_input_dim = block_out_channels[0]
309
+ else:
310
+ raise ValueError(
311
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
312
+ )
313
+
314
+ self.time_embedding = TimestepEmbedding(
315
+ timestep_input_dim,
316
+ time_embed_dim,
317
+ act_fn=act_fn,
318
+ post_act_fn=timestep_post_act,
319
+ cond_proj_dim=time_cond_proj_dim,
320
+ )
321
+
322
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
323
+ encoder_hid_dim_type = "text_proj"
324
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
325
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
326
+
327
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
328
+ raise ValueError(
329
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
330
+ )
331
+
332
+ if encoder_hid_dim_type == "text_proj":
333
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
334
+ elif encoder_hid_dim_type == "text_image_proj":
335
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
336
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
337
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
338
+ self.encoder_hid_proj = TextImageProjection(
339
+ text_embed_dim=encoder_hid_dim,
340
+ image_embed_dim=cross_attention_dim,
341
+ cross_attention_dim=cross_attention_dim,
342
+ )
343
+ elif encoder_hid_dim_type == "image_proj":
344
+ # Kandinsky 2.2
345
+ self.encoder_hid_proj = ImageProjection(
346
+ image_embed_dim=encoder_hid_dim,
347
+ cross_attention_dim=cross_attention_dim,
348
+ )
349
+ elif encoder_hid_dim_type is not None:
350
+ raise ValueError(
351
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
352
+ )
353
+ else:
354
+ self.encoder_hid_proj = None
355
+
356
+ # class embedding
357
+ if class_embed_type is None and num_class_embeds is not None:
358
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
359
+ elif class_embed_type == "timestep":
360
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
361
+ elif class_embed_type == "identity":
362
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
363
+ elif class_embed_type == "projection":
364
+ if projection_class_embeddings_input_dim is None:
365
+ raise ValueError(
366
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
367
+ )
368
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
369
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
370
+ # 2. it projects from an arbitrary input dimension.
371
+ #
372
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
373
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
374
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
375
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
376
+ elif class_embed_type == "simple_projection":
377
+ if projection_class_embeddings_input_dim is None:
378
+ raise ValueError(
379
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
380
+ )
381
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
382
+ else:
383
+ self.class_embedding = None
384
+
385
+ if addition_embed_type == "text":
386
+ if encoder_hid_dim is not None:
387
+ text_time_embedding_from_dim = encoder_hid_dim
388
+ else:
389
+ text_time_embedding_from_dim = cross_attention_dim
390
+
391
+ self.add_embedding = TextTimeEmbedding(
392
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
393
+ )
394
+ elif addition_embed_type == "text_image":
395
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
396
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
397
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
398
+ self.add_embedding = TextImageTimeEmbedding(
399
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
400
+ )
401
+ elif addition_embed_type == "text_time":
402
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
403
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
404
+ elif addition_embed_type == "image":
405
+ # Kandinsky 2.2
406
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
407
+ elif addition_embed_type == "image_hint":
408
+ # Kandinsky 2.2 ControlNet
409
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
410
+ elif addition_embed_type is not None:
411
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
412
+
413
+ if time_embedding_act_fn is None:
414
+ self.time_embed_act = None
415
+ else:
416
+ self.time_embed_act = get_activation(time_embedding_act_fn)
417
+
418
+ self.down_blocks = nn.ModuleList([])
419
+ self.up_blocks = nn.ModuleList([])
420
+
421
+ if isinstance(only_cross_attention, bool):
422
+ if mid_block_only_cross_attention is None:
423
+ mid_block_only_cross_attention = only_cross_attention
424
+
425
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
426
+
427
+ if mid_block_only_cross_attention is None:
428
+ mid_block_only_cross_attention = False
429
+
430
+ if isinstance(num_attention_heads, int):
431
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
432
+
433
+ if isinstance(attention_head_dim, int):
434
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
435
+
436
+ if isinstance(cross_attention_dim, int):
437
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
438
+
439
+ if isinstance(layers_per_block, int):
440
+ layers_per_block = [layers_per_block] * len(down_block_types)
441
+
442
+ if isinstance(transformer_layers_per_block, int):
443
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
444
+
445
+ if class_embeddings_concat:
446
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
447
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
448
+ # regular time embeddings
449
+ blocks_time_embed_dim = time_embed_dim * 2
450
+ else:
451
+ blocks_time_embed_dim = time_embed_dim
452
+
453
+ # down
454
+ output_channel = block_out_channels[0]
455
+ for i, down_block_type in enumerate(down_block_types):
456
+ input_channel = output_channel
457
+ output_channel = block_out_channels[i]
458
+ is_final_block = i == len(block_out_channels) - 1
459
+
460
+ down_block = get_down_block(
461
+ down_block_type,
462
+ num_layers=layers_per_block[i],
463
+ transformer_layers_per_block=transformer_layers_per_block[i],
464
+ in_channels=input_channel,
465
+ out_channels=output_channel,
466
+ temb_channels=blocks_time_embed_dim,
467
+ add_downsample=not is_final_block,
468
+ resnet_eps=norm_eps,
469
+ resnet_act_fn=act_fn,
470
+ resnet_groups=norm_num_groups,
471
+ cross_attention_dim=cross_attention_dim[i],
472
+ num_attention_heads=num_attention_heads[i],
473
+ downsample_padding=downsample_padding,
474
+ dual_cross_attention=dual_cross_attention,
475
+ use_linear_projection=use_linear_projection,
476
+ only_cross_attention=only_cross_attention[i],
477
+ upcast_attention=upcast_attention,
478
+ resnet_time_scale_shift=resnet_time_scale_shift,
479
+ resnet_skip_time_act=resnet_skip_time_act,
480
+ resnet_out_scale_factor=resnet_out_scale_factor,
481
+ cross_attention_norm=cross_attention_norm,
482
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
483
+ num_views=num_views,
484
+ cd_attention_last=cd_attention_last,
485
+ cd_attention_mid=cd_attention_mid,
486
+ multiview_attention=multiview_attention,
487
+ sparse_mv_attention=sparse_mv_attention,
488
+ mvcd_attention=mvcd_attention
489
+ )
490
+ self.down_blocks.append(down_block)
491
+
492
+ # mid
493
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
494
+ self.mid_block = UNetMidBlock2DCrossAttn(
495
+ transformer_layers_per_block=transformer_layers_per_block[-1],
496
+ in_channels=block_out_channels[-1],
497
+ temb_channels=blocks_time_embed_dim,
498
+ resnet_eps=norm_eps,
499
+ resnet_act_fn=act_fn,
500
+ output_scale_factor=mid_block_scale_factor,
501
+ resnet_time_scale_shift=resnet_time_scale_shift,
502
+ cross_attention_dim=cross_attention_dim[-1],
503
+ num_attention_heads=num_attention_heads[-1],
504
+ resnet_groups=norm_num_groups,
505
+ dual_cross_attention=dual_cross_attention,
506
+ use_linear_projection=use_linear_projection,
507
+ upcast_attention=upcast_attention,
508
+ )
509
+ # custom MV2D attention block
510
+ elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
511
+ self.mid_block = UNetMidBlockMV2DCrossAttn(
512
+ transformer_layers_per_block=transformer_layers_per_block[-1],
513
+ in_channels=block_out_channels[-1],
514
+ temb_channels=blocks_time_embed_dim,
515
+ resnet_eps=norm_eps,
516
+ resnet_act_fn=act_fn,
517
+ output_scale_factor=mid_block_scale_factor,
518
+ resnet_time_scale_shift=resnet_time_scale_shift,
519
+ cross_attention_dim=cross_attention_dim[-1],
520
+ num_attention_heads=num_attention_heads[-1],
521
+ resnet_groups=norm_num_groups,
522
+ dual_cross_attention=dual_cross_attention,
523
+ use_linear_projection=use_linear_projection,
524
+ upcast_attention=upcast_attention,
525
+ num_views=num_views,
526
+ cd_attention_last=cd_attention_last,
527
+ cd_attention_mid=cd_attention_mid,
528
+ multiview_attention=multiview_attention,
529
+ sparse_mv_attention=sparse_mv_attention,
530
+ mvcd_attention=mvcd_attention
531
+ )
532
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
533
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
534
+ in_channels=block_out_channels[-1],
535
+ temb_channels=blocks_time_embed_dim,
536
+ resnet_eps=norm_eps,
537
+ resnet_act_fn=act_fn,
538
+ output_scale_factor=mid_block_scale_factor,
539
+ cross_attention_dim=cross_attention_dim[-1],
540
+ attention_head_dim=attention_head_dim[-1],
541
+ resnet_groups=norm_num_groups,
542
+ resnet_time_scale_shift=resnet_time_scale_shift,
543
+ skip_time_act=resnet_skip_time_act,
544
+ only_cross_attention=mid_block_only_cross_attention,
545
+ cross_attention_norm=cross_attention_norm,
546
+ )
547
+ elif mid_block_type is None:
548
+ self.mid_block = None
549
+ else:
550
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
551
+
552
+ # count how many layers upsample the images
553
+ self.num_upsamplers = 0
554
+
555
+ # up
556
+ reversed_block_out_channels = list(reversed(block_out_channels))
557
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
558
+ reversed_layers_per_block = list(reversed(layers_per_block))
559
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
560
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
561
+ only_cross_attention = list(reversed(only_cross_attention))
562
+
563
+ output_channel = reversed_block_out_channels[0]
564
+ for i, up_block_type in enumerate(up_block_types):
565
+ is_final_block = i == len(block_out_channels) - 1
566
+
567
+ prev_output_channel = output_channel
568
+ output_channel = reversed_block_out_channels[i]
569
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
570
+
571
+ # add upsample block for all BUT final layer
572
+ if not is_final_block:
573
+ add_upsample = True
574
+ self.num_upsamplers += 1
575
+ else:
576
+ add_upsample = False
577
+
578
+ up_block = get_up_block(
579
+ up_block_type,
580
+ num_layers=reversed_layers_per_block[i] + 1,
581
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
582
+ in_channels=input_channel,
583
+ out_channels=output_channel,
584
+ prev_output_channel=prev_output_channel,
585
+ temb_channels=blocks_time_embed_dim,
586
+ add_upsample=add_upsample,
587
+ resnet_eps=norm_eps,
588
+ resnet_act_fn=act_fn,
589
+ resnet_groups=norm_num_groups,
590
+ cross_attention_dim=reversed_cross_attention_dim[i],
591
+ num_attention_heads=reversed_num_attention_heads[i],
592
+ dual_cross_attention=dual_cross_attention,
593
+ use_linear_projection=use_linear_projection,
594
+ only_cross_attention=only_cross_attention[i],
595
+ upcast_attention=upcast_attention,
596
+ resnet_time_scale_shift=resnet_time_scale_shift,
597
+ resnet_skip_time_act=resnet_skip_time_act,
598
+ resnet_out_scale_factor=resnet_out_scale_factor,
599
+ cross_attention_norm=cross_attention_norm,
600
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
601
+ num_views=num_views,
602
+ cd_attention_last=cd_attention_last,
603
+ cd_attention_mid=cd_attention_mid,
604
+ multiview_attention=multiview_attention,
605
+ sparse_mv_attention=sparse_mv_attention,
606
+ mvcd_attention=mvcd_attention
607
+ )
608
+ self.up_blocks.append(up_block)
609
+ prev_output_channel = output_channel
610
+
611
+ # out
612
+ if norm_num_groups is not None:
613
+ self.conv_norm_out = nn.GroupNorm(
614
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
615
+ )
616
+
617
+ self.conv_act = get_activation(act_fn)
618
+
619
+ else:
620
+ self.conv_norm_out = None
621
+ self.conv_act = None
622
+
623
+ conv_out_padding = (conv_out_kernel - 1) // 2
624
+ self.conv_out = nn.Conv2d(
625
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
626
+ )
627
+
628
+ @property
629
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
630
+ r"""
631
+ Returns:
632
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
633
+ indexed by its weight name.
634
+ """
635
+ # set recursively
636
+ processors = {}
637
+
638
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
639
+ if hasattr(module, "set_processor"):
640
+ processors[f"{name}.processor"] = module.processor
641
+
642
+ for sub_name, child in module.named_children():
643
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
644
+
645
+ return processors
646
+
647
+ for name, module in self.named_children():
648
+ fn_recursive_add_processors(name, module, processors)
649
+
650
+ return processors
651
+
652
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
653
+ r"""
654
+ Sets the attention processor to use to compute attention.
655
+
656
+ Parameters:
657
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
658
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
659
+ for **all** `Attention` layers.
660
+
661
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
662
+ processor. This is strongly recommended when setting trainable attention processors.
663
+
664
+ """
665
+ count = len(self.attn_processors.keys())
666
+
667
+ if isinstance(processor, dict) and len(processor) != count:
668
+ raise ValueError(
669
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
670
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
671
+ )
672
+
673
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
674
+ if hasattr(module, "set_processor"):
675
+ if not isinstance(processor, dict):
676
+ module.set_processor(processor)
677
+ else:
678
+ module.set_processor(processor.pop(f"{name}.processor"))
679
+
680
+ for sub_name, child in module.named_children():
681
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
682
+
683
+ for name, module in self.named_children():
684
+ fn_recursive_attn_processor(name, module, processor)
685
+
686
+ def set_default_attn_processor(self):
687
+ """
688
+ Disables custom attention processors and sets the default attention implementation.
689
+ """
690
+ self.set_attn_processor(AttnProcessor())
691
+
692
+ def set_attention_slice(self, slice_size):
693
+ r"""
694
+ Enable sliced attention computation.
695
+
696
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
697
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
698
+
699
+ Args:
700
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
701
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
702
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
703
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
704
+ must be a multiple of `slice_size`.
705
+ """
706
+ sliceable_head_dims = []
707
+
708
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
709
+ if hasattr(module, "set_attention_slice"):
710
+ sliceable_head_dims.append(module.sliceable_head_dim)
711
+
712
+ for child in module.children():
713
+ fn_recursive_retrieve_sliceable_dims(child)
714
+
715
+ # retrieve number of attention layers
716
+ for module in self.children():
717
+ fn_recursive_retrieve_sliceable_dims(module)
718
+
719
+ num_sliceable_layers = len(sliceable_head_dims)
720
+
721
+ if slice_size == "auto":
722
+ # half the attention head size is usually a good trade-off between
723
+ # speed and memory
724
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
725
+ elif slice_size == "max":
726
+ # make smallest slice possible
727
+ slice_size = num_sliceable_layers * [1]
728
+
729
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
730
+
731
+ if len(slice_size) != len(sliceable_head_dims):
732
+ raise ValueError(
733
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
734
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
735
+ )
736
+
737
+ for i in range(len(slice_size)):
738
+ size = slice_size[i]
739
+ dim = sliceable_head_dims[i]
740
+ if size is not None and size > dim:
741
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
742
+
743
+ # Recursively walk through all the children.
744
+ # Any children which exposes the set_attention_slice method
745
+ # gets the message
746
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
747
+ if hasattr(module, "set_attention_slice"):
748
+ module.set_attention_slice(slice_size.pop())
749
+
750
+ for child in module.children():
751
+ fn_recursive_set_attention_slice(child, slice_size)
752
+
753
+ reversed_slice_size = list(reversed(slice_size))
754
+ for module in self.children():
755
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
756
+
757
+ def _set_gradient_checkpointing(self, module, value=False):
758
+ if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):
759
+ module.gradient_checkpointing = value
760
+
761
+ def forward(
762
+ self,
763
+ sample: torch.FloatTensor,
764
+ timestep: Union[torch.Tensor, float, int],
765
+ encoder_hidden_states: torch.Tensor,
766
+ class_labels: Optional[torch.Tensor] = None,
767
+ timestep_cond: Optional[torch.Tensor] = None,
768
+ attention_mask: Optional[torch.Tensor] = None,
769
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
770
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
771
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
772
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
773
+ encoder_attention_mask: Optional[torch.Tensor] = None,
774
+ return_dict: bool = True,
775
+ ) -> Union[UNetMV2DConditionOutput, Tuple]:
776
+ r"""
777
+ The [`UNet2DConditionModel`] forward method.
778
+
779
+ Args:
780
+ sample (`torch.FloatTensor`):
781
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
782
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
783
+ encoder_hidden_states (`torch.FloatTensor`):
784
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
785
+ encoder_attention_mask (`torch.Tensor`):
786
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
787
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
788
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
789
+ return_dict (`bool`, *optional*, defaults to `True`):
790
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
791
+ tuple.
792
+ cross_attention_kwargs (`dict`, *optional*):
793
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
794
+ added_cond_kwargs: (`dict`, *optional*):
795
+ A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
796
+ are passed along to the UNet blocks.
797
+
798
+ Returns:
799
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
800
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
801
+ a `tuple` is returned where the first element is the sample tensor.
802
+ """
803
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
804
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
805
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
806
+ # on the fly if necessary.
807
+ default_overall_up_factor = 2**self.num_upsamplers
808
+
809
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
810
+ forward_upsample_size = False
811
+ upsample_size = None
812
+
813
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
814
+ logger.info("Forward upsample size to force interpolation output size.")
815
+ forward_upsample_size = True
816
+
817
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
818
+ # expects mask of shape:
819
+ # [batch, key_tokens]
820
+ # adds singleton query_tokens dimension:
821
+ # [batch, 1, key_tokens]
822
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
823
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
824
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
825
+ if attention_mask is not None:
826
+ # assume that mask is expressed as:
827
+ # (1 = keep, 0 = discard)
828
+ # convert mask into a bias that can be added to attention scores:
829
+ # (keep = +0, discard = -10000.0)
830
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
831
+ attention_mask = attention_mask.unsqueeze(1)
832
+
833
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
834
+ if encoder_attention_mask is not None:
835
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
836
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
837
+
838
+ # 0. center input if necessary
839
+ if self.config.center_input_sample:
840
+ sample = 2 * sample - 1.0
841
+
842
+ # 1. time
843
+ timesteps = timestep
844
+ if not torch.is_tensor(timesteps):
845
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
846
+ # This would be a good case for the `match` statement (Python 3.10+)
847
+ is_mps = sample.device.type == "mps"
848
+ if isinstance(timestep, float):
849
+ dtype = torch.float32 if is_mps else torch.float64
850
+ else:
851
+ dtype = torch.int32 if is_mps else torch.int64
852
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
853
+ elif len(timesteps.shape) == 0:
854
+ timesteps = timesteps[None].to(sample.device)
855
+
856
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
857
+ timesteps = timesteps.expand(sample.shape[0])
858
+
859
+ t_emb = self.time_proj(timesteps)
860
+
861
+ # `Timesteps` does not contain any weights and will always return f32 tensors
862
+ # but time_embedding might actually be running in fp16. so we need to cast here.
863
+ # there might be better ways to encapsulate this.
864
+ t_emb = t_emb.to(dtype=sample.dtype)
865
+
866
+ # self.time_embedding.to(dtype=t_emb.dtype)
867
+ emb = self.time_embedding(t_emb, timestep_cond)
868
+ aug_emb = None
869
+
870
+ if self.class_embedding is not None:
871
+ if class_labels is None:
872
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
873
+
874
+ if self.config.class_embed_type == "timestep":
875
+ class_labels = self.time_proj(class_labels)
876
+
877
+ # `Timesteps` does not contain any weights and will always return f32 tensors
878
+ # there might be better ways to encapsulate this.
879
+ class_labels = class_labels.to(dtype=sample.dtype)
880
+
881
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
882
+
883
+ if self.config.class_embeddings_concat:
884
+ emb = torch.cat([emb, class_emb], dim=-1)
885
+ else:
886
+ emb = emb + class_emb
887
+
888
+ if self.config.addition_embed_type == "text":
889
+ aug_emb = self.add_embedding(encoder_hidden_states)
890
+ elif self.config.addition_embed_type == "text_image":
891
+ # Kandinsky 2.1 - style
892
+ if "image_embeds" not in added_cond_kwargs:
893
+ raise ValueError(
894
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
895
+ )
896
+
897
+ image_embs = added_cond_kwargs.get("image_embeds")
898
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
899
+ aug_emb = self.add_embedding(text_embs, image_embs)
900
+ elif self.config.addition_embed_type == "text_time":
901
+ # SDXL - style
902
+ if "text_embeds" not in added_cond_kwargs:
903
+ raise ValueError(
904
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
905
+ )
906
+ text_embeds = added_cond_kwargs.get("text_embeds")
907
+ if "time_ids" not in added_cond_kwargs:
908
+ raise ValueError(
909
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
910
+ )
911
+ time_ids = added_cond_kwargs.get("time_ids")
912
+ time_embeds = self.add_time_proj(time_ids.flatten())
913
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
914
+
915
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
916
+ add_embeds = add_embeds.to(emb.dtype)
917
+ aug_emb = self.add_embedding(add_embeds)
918
+ elif self.config.addition_embed_type == "image":
919
+ # Kandinsky 2.2 - style
920
+ if "image_embeds" not in added_cond_kwargs:
921
+ raise ValueError(
922
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
923
+ )
924
+ image_embs = added_cond_kwargs.get("image_embeds")
925
+ aug_emb = self.add_embedding(image_embs)
926
+ elif self.config.addition_embed_type == "image_hint":
927
+ # Kandinsky 2.2 - style
928
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
929
+ raise ValueError(
930
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
931
+ )
932
+ image_embs = added_cond_kwargs.get("image_embeds")
933
+ hint = added_cond_kwargs.get("hint")
934
+ aug_emb, hint = self.add_embedding(image_embs, hint)
935
+ sample = torch.cat([sample, hint], dim=1)
936
+
937
+ emb = emb + aug_emb if aug_emb is not None else emb
938
+
939
+ if self.time_embed_act is not None:
940
+ emb = self.time_embed_act(emb)
941
+
942
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
943
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
944
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
945
+ # Kadinsky 2.1 - style
946
+ if "image_embeds" not in added_cond_kwargs:
947
+ raise ValueError(
948
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
949
+ )
950
+
951
+ image_embeds = added_cond_kwargs.get("image_embeds")
952
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
953
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
954
+ # Kandinsky 2.2 - style
955
+ if "image_embeds" not in added_cond_kwargs:
956
+ raise ValueError(
957
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
958
+ )
959
+ image_embeds = added_cond_kwargs.get("image_embeds")
960
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
961
+ # 2. pre-process
962
+ sample = self.conv_in(sample)
963
+
964
+ # 3. down
965
+
966
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
967
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
968
+
969
+ down_block_res_samples = (sample,)
970
+ for downsample_block in self.down_blocks:
971
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
972
+ # For t2i-adapter CrossAttnDownBlock2D
973
+ additional_residuals = {}
974
+ if is_adapter and len(down_block_additional_residuals) > 0:
975
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
976
+
977
+ sample, res_samples = downsample_block(
978
+ hidden_states=sample,
979
+ temb=emb,
980
+ encoder_hidden_states=encoder_hidden_states,
981
+ attention_mask=attention_mask,
982
+ cross_attention_kwargs=cross_attention_kwargs,
983
+ encoder_attention_mask=encoder_attention_mask,
984
+ **additional_residuals,
985
+ )
986
+ else:
987
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
988
+
989
+ if is_adapter and len(down_block_additional_residuals) > 0:
990
+ sample += down_block_additional_residuals.pop(0)
991
+
992
+ down_block_res_samples += res_samples
993
+
994
+ if is_controlnet:
995
+ new_down_block_res_samples = ()
996
+
997
+ for down_block_res_sample, down_block_additional_residual in zip(
998
+ down_block_res_samples, down_block_additional_residuals
999
+ ):
1000
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1001
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1002
+
1003
+ down_block_res_samples = new_down_block_res_samples
1004
+
1005
+ # 4. mid
1006
+ if self.mid_block is not None:
1007
+ sample = self.mid_block(
1008
+ sample,
1009
+ emb,
1010
+ encoder_hidden_states=encoder_hidden_states,
1011
+ attention_mask=attention_mask,
1012
+ cross_attention_kwargs=cross_attention_kwargs,
1013
+ encoder_attention_mask=encoder_attention_mask,
1014
+ )
1015
+
1016
+ if is_controlnet:
1017
+ sample = sample + mid_block_additional_residual
1018
+
1019
+ # 5. up
1020
+ for i, upsample_block in enumerate(self.up_blocks):
1021
+ is_final_block = i == len(self.up_blocks) - 1
1022
+
1023
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1024
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1025
+
1026
+ # if we have not reached the final block and need to forward the
1027
+ # upsample size, we do it here
1028
+ if not is_final_block and forward_upsample_size:
1029
+ upsample_size = down_block_res_samples[-1].shape[2:]
1030
+
1031
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1032
+ sample = upsample_block(
1033
+ hidden_states=sample,
1034
+ temb=emb,
1035
+ res_hidden_states_tuple=res_samples,
1036
+ encoder_hidden_states=encoder_hidden_states,
1037
+ cross_attention_kwargs=cross_attention_kwargs,
1038
+ upsample_size=upsample_size,
1039
+ attention_mask=attention_mask,
1040
+ encoder_attention_mask=encoder_attention_mask,
1041
+ )
1042
+ else:
1043
+ sample = upsample_block(
1044
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1045
+ )
1046
+
1047
+ # 6. post-process
1048
+ if self.conv_norm_out:
1049
+ sample = self.conv_norm_out(sample)
1050
+ sample = self.conv_act(sample)
1051
+ sample = self.conv_out(sample)
1052
+
1053
+ if not return_dict:
1054
+ return (sample,)
1055
+
1056
+ return UNetMV2DConditionOutput(sample=sample)
1057
+
1058
+ @classmethod
1059
+ def from_pretrained_2d(
1060
+ cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
1061
+ camera_embedding_type: str, num_views: int, sample_size: int,
1062
+ zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
1063
+ projection_class_embeddings_input_dim: int=6, cd_attention_last: bool = False,
1064
+ cd_attention_mid: bool = False, multiview_attention: bool = True,
1065
+ sparse_mv_attention: bool = False, mvcd_attention: bool = False,
1066
+ in_channels: int = 8, out_channels: int = 4,
1067
+ **kwargs
1068
+ ):
1069
+ r"""
1070
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
1071
+
1072
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
1073
+ train the model, set it back in training mode with `model.train()`.
1074
+
1075
+ Parameters:
1076
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
1077
+ Can be either:
1078
+
1079
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1080
+ the Hub.
1081
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1082
+ with [`~ModelMixin.save_pretrained`].
1083
+
1084
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1085
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1086
+ is not used.
1087
+ torch_dtype (`str` or `torch.dtype`, *optional*):
1088
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
1089
+ dtype is automatically derived from the model's weights.
1090
+ force_download (`bool`, *optional*, defaults to `False`):
1091
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1092
+ cached versions if they exist.
1093
+ resume_download (`bool`, *optional*, defaults to `False`):
1094
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
1095
+ incompletely downloaded files are deleted.
1096
+ proxies (`Dict[str, str]`, *optional*):
1097
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1098
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1099
+ output_loading_info (`bool`, *optional*, defaults to `False`):
1100
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
1101
+ local_files_only(`bool`, *optional*, defaults to `False`):
1102
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
1103
+ won't be downloaded from the Hub.
1104
+ use_auth_token (`str` or *bool*, *optional*):
1105
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1106
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1107
+ revision (`str`, *optional*, defaults to `"main"`):
1108
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1109
+ allowed by Git.
1110
+ from_flax (`bool`, *optional*, defaults to `False`):
1111
+ Load the model weights from a Flax checkpoint save file.
1112
+ subfolder (`str`, *optional*, defaults to `""`):
1113
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
1114
+ mirror (`str`, *optional*):
1115
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
1116
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
1117
+ information.
1118
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
1119
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
1120
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
1121
+ same device.
1122
+
1123
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
1124
+ more information about each option see [designing a device
1125
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
1126
+ max_memory (`Dict`, *optional*):
1127
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
1128
+ each GPU and the available CPU RAM if unset.
1129
+ offload_folder (`str` or `os.PathLike`, *optional*):
1130
+ The path to offload weights if `device_map` contains the value `"disk"`.
1131
+ offload_state_dict (`bool`, *optional*):
1132
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
1133
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
1134
+ when there is some disk offload.
1135
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
1136
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
1137
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
1138
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
1139
+ argument to `True` will raise an error.
1140
+ variant (`str`, *optional*):
1141
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
1142
+ loading `from_flax`.
1143
+ use_safetensors (`bool`, *optional*, defaults to `None`):
1144
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
1145
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
1146
+ weights. If set to `False`, `safetensors` weights are not loaded.
1147
+
1148
+ <Tip>
1149
+
1150
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
1151
+ `huggingface-cli login`. You can also activate the special
1152
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
1153
+ firewalled environment.
1154
+
1155
+ </Tip>
1156
+
1157
+ Example:
1158
+
1159
+ ```py
1160
+ from diffusers import UNet2DConditionModel
1161
+
1162
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
1163
+ ```
1164
+
1165
+ If you get the error message below, you need to finetune the weights for your downstream task:
1166
+
1167
+ ```bash
1168
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
1169
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
1170
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
1171
+ ```
1172
+ """
1173
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1174
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
1175
+ force_download = kwargs.pop("force_download", False)
1176
+ from_flax = kwargs.pop("from_flax", False)
1177
+ resume_download = kwargs.pop("resume_download", False)
1178
+ proxies = kwargs.pop("proxies", None)
1179
+ output_loading_info = kwargs.pop("output_loading_info", False)
1180
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1181
+ use_auth_token = kwargs.pop("use_auth_token", None)
1182
+ revision = kwargs.pop("revision", None)
1183
+ torch_dtype = kwargs.pop("torch_dtype", None)
1184
+ subfolder = kwargs.pop("subfolder", None)
1185
+ device_map = kwargs.pop("device_map", None)
1186
+ max_memory = kwargs.pop("max_memory", None)
1187
+ offload_folder = kwargs.pop("offload_folder", None)
1188
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
1189
+ variant = kwargs.pop("variant", None)
1190
+ use_safetensors = kwargs.pop("use_safetensors", None)
1191
+
1192
+ if use_safetensors and not is_safetensors_available():
1193
+ raise ValueError(
1194
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
1195
+ )
1196
+
1197
+ allow_pickle = False
1198
+ if use_safetensors is None:
1199
+ use_safetensors = is_safetensors_available()
1200
+ allow_pickle = True
1201
+
1202
+ if device_map is not None and not is_accelerate_available():
1203
+ raise NotImplementedError(
1204
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
1205
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
1206
+ )
1207
+
1208
+ # Check if we can handle device_map and dispatching the weights
1209
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
1210
+ raise NotImplementedError(
1211
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1212
+ " `device_map=None`."
1213
+ )
1214
+
1215
+ # Load config if we don't provide a configuration
1216
+ config_path = pretrained_model_name_or_path
1217
+
1218
+ user_agent = {
1219
+ "diffusers": __version__,
1220
+ "file_type": "model",
1221
+ "framework": "pytorch",
1222
+ }
1223
+
1224
+ # load config
1225
+ config, unused_kwargs, commit_hash = cls.load_config(
1226
+ config_path,
1227
+ cache_dir=cache_dir,
1228
+ return_unused_kwargs=True,
1229
+ return_commit_hash=True,
1230
+ force_download=force_download,
1231
+ resume_download=resume_download,
1232
+ proxies=proxies,
1233
+ local_files_only=local_files_only,
1234
+ use_auth_token=use_auth_token,
1235
+ revision=revision,
1236
+ subfolder=subfolder,
1237
+ device_map=device_map,
1238
+ max_memory=max_memory,
1239
+ offload_folder=offload_folder,
1240
+ offload_state_dict=offload_state_dict,
1241
+ user_agent=user_agent,
1242
+ **kwargs,
1243
+ )
1244
+
1245
+ # modify config
1246
+ config["_class_name"] = cls.__name__
1247
+ config['in_channels'] = in_channels
1248
+ config['out_channels'] = out_channels
1249
+ config['sample_size'] = sample_size # training resolution
1250
+ config['num_views'] = num_views
1251
+ config['cd_attention_last'] = cd_attention_last
1252
+ config['cd_attention_mid'] = cd_attention_mid
1253
+ config['multiview_attention'] = multiview_attention
1254
+ config['sparse_mv_attention'] = sparse_mv_attention
1255
+ config['mvcd_attention'] = mvcd_attention
1256
+ config["down_block_types"] = [
1257
+ "CrossAttnDownBlockMV2D",
1258
+ "CrossAttnDownBlockMV2D",
1259
+ "CrossAttnDownBlockMV2D",
1260
+ "DownBlock2D"
1261
+ ]
1262
+ config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
1263
+ config["up_block_types"] = [
1264
+ "UpBlock2D",
1265
+ "CrossAttnUpBlockMV2D",
1266
+ "CrossAttnUpBlockMV2D",
1267
+ "CrossAttnUpBlockMV2D"
1268
+ ]
1269
+ config['class_embed_type'] = 'projection'
1270
+ if camera_embedding_type == 'e_de_da_sincos':
1271
+ config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6
1272
+ else:
1273
+ raise NotImplementedError
1274
+
1275
+ # load model
1276
+ model_file = None
1277
+ if from_flax:
1278
+ raise NotImplementedError
1279
+ else:
1280
+ if use_safetensors:
1281
+ try:
1282
+ model_file = _get_model_file(
1283
+ pretrained_model_name_or_path,
1284
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
1285
+ cache_dir=cache_dir,
1286
+ force_download=force_download,
1287
+ resume_download=resume_download,
1288
+ proxies=proxies,
1289
+ local_files_only=local_files_only,
1290
+ use_auth_token=use_auth_token,
1291
+ revision=revision,
1292
+ subfolder=subfolder,
1293
+ user_agent=user_agent,
1294
+ commit_hash=commit_hash,
1295
+ )
1296
+ except IOError as e:
1297
+ if not allow_pickle:
1298
+ raise e
1299
+ pass
1300
+ if model_file is None:
1301
+ model_file = _get_model_file(
1302
+ pretrained_model_name_or_path,
1303
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
1304
+ cache_dir=cache_dir,
1305
+ force_download=force_download,
1306
+ resume_download=resume_download,
1307
+ proxies=proxies,
1308
+ local_files_only=local_files_only,
1309
+ use_auth_token=use_auth_token,
1310
+ revision=revision,
1311
+ subfolder=subfolder,
1312
+ user_agent=user_agent,
1313
+ commit_hash=commit_hash,
1314
+ )
1315
+
1316
+ model = cls.from_config(config, **unused_kwargs)
1317
+ import copy
1318
+ state_dict_v0 = load_state_dict(model_file, variant=variant)
1319
+ state_dict = copy.deepcopy(state_dict_v0)
1320
+ # attn_joint -> attn_joint_last; norm_joint -> norm_joint_last
1321
+ # attn_joint_twice -> attn_joint_mid; norm_joint_twice -> norm_joint_mid
1322
+ for key in state_dict_v0:
1323
+ if 'attn_joint.' in key:
1324
+ tmp = copy.deepcopy(key)
1325
+ state_dict[key.replace("attn_joint.", "attn_joint_last.")] = state_dict.pop(tmp)
1326
+ if 'norm_joint.' in key:
1327
+ tmp = copy.deepcopy(key)
1328
+ state_dict[key.replace("norm_joint.", "norm_joint_last.")] = state_dict.pop(tmp)
1329
+ if 'attn_joint_twice.' in key:
1330
+ tmp = copy.deepcopy(key)
1331
+ state_dict[key.replace("attn_joint_twice.", "attn_joint_mid.")] = state_dict.pop(tmp)
1332
+ if 'norm_joint_twice.' in key:
1333
+ tmp = copy.deepcopy(key)
1334
+ state_dict[key.replace("norm_joint_twice.", "norm_joint_mid.")] = state_dict.pop(tmp)
1335
+
1336
+ model._convert_deprecated_attention_blocks(state_dict)
1337
+
1338
+ conv_in_weight = state_dict['conv_in.weight']
1339
+ conv_out_weight = state_dict['conv_out.weight']
1340
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
1341
+ model,
1342
+ state_dict,
1343
+ model_file,
1344
+ pretrained_model_name_or_path,
1345
+ ignore_mismatched_sizes=True,
1346
+ )
1347
+ if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):
1348
+ # initialize from the original SD structure
1349
+ model.conv_in.weight.data[:,:4] = conv_in_weight
1350
+
1351
+ # whether to place all zero to new layers?
1352
+ if zero_init_conv_in:
1353
+ model.conv_in.weight.data[:,4:] = 0.
1354
+
1355
+ if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):
1356
+ # initialize from the original SD structure
1357
+ model.conv_out.weight.data[:,:4] = conv_out_weight
1358
+ if out_channels == 8: # copy for the last 4 channels
1359
+ model.conv_out.weight.data[:, 4:] = conv_out_weight
1360
+
1361
+ # if zero_init_camera_projection:
1362
+ # for p in model.class_embedding.parameters():
1363
+ # torch.nn.init.zeros_(p)
1364
+
1365
+ loading_info = {
1366
+ "missing_keys": missing_keys,
1367
+ "unexpected_keys": unexpected_keys,
1368
+ "mismatched_keys": mismatched_keys,
1369
+ "error_msgs": error_msgs,
1370
+ }
1371
+
1372
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
1373
+ raise ValueError(
1374
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
1375
+ )
1376
+ elif torch_dtype is not None:
1377
+ model = model.to(torch_dtype)
1378
+
1379
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1380
+
1381
+ # Set model in evaluation mode to deactivate DropOut modules by default
1382
+ model.eval()
1383
+ if output_loading_info:
1384
+ return model, loading_info
1385
+
1386
+ return model
1387
+
1388
+ @classmethod
1389
+ def _load_pretrained_model_2d(
1390
+ cls,
1391
+ model,
1392
+ state_dict,
1393
+ resolved_archive_file,
1394
+ pretrained_model_name_or_path,
1395
+ ignore_mismatched_sizes=False,
1396
+ ):
1397
+ # Retrieve missing & unexpected_keys
1398
+ model_state_dict = model.state_dict()
1399
+ loaded_keys = list(state_dict.keys())
1400
+
1401
+ expected_keys = list(model_state_dict.keys())
1402
+
1403
+ original_loaded_keys = loaded_keys
1404
+
1405
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
1406
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
1407
+
1408
+ # Make sure we are able to load base models as well as derived models (with heads)
1409
+ model_to_load = model
1410
+
1411
+ def _find_mismatched_keys(
1412
+ state_dict,
1413
+ model_state_dict,
1414
+ loaded_keys,
1415
+ ignore_mismatched_sizes,
1416
+ ):
1417
+ mismatched_keys = []
1418
+ if ignore_mismatched_sizes:
1419
+ for checkpoint_key in loaded_keys:
1420
+ model_key = checkpoint_key
1421
+
1422
+ if (
1423
+ model_key in model_state_dict
1424
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
1425
+ ):
1426
+ mismatched_keys.append(
1427
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
1428
+ )
1429
+ del state_dict[checkpoint_key]
1430
+ return mismatched_keys
1431
+
1432
+ if state_dict is not None:
1433
+ # Whole checkpoint
1434
+ mismatched_keys = _find_mismatched_keys(
1435
+ state_dict,
1436
+ model_state_dict,
1437
+ original_loaded_keys,
1438
+ ignore_mismatched_sizes,
1439
+ )
1440
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
1441
+
1442
+ if len(error_msgs) > 0:
1443
+ error_msg = "\n\t".join(error_msgs)
1444
+ if "size mismatch" in error_msg:
1445
+ error_msg += (
1446
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
1447
+ )
1448
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
1449
+
1450
+ if len(unexpected_keys) > 0:
1451
+ logger.warning(
1452
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
1453
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
1454
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
1455
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
1456
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
1457
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
1458
+ " identical (initializing a BertForSequenceClassification model from a"
1459
+ " BertForSequenceClassification model)."
1460
+ )
1461
+ else:
1462
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
1463
+ if len(missing_keys) > 0:
1464
+ logger.warning(
1465
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1466
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
1467
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1468
+ )
1469
+ elif len(mismatched_keys) == 0:
1470
+ logger.info(
1471
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
1472
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
1473
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
1474
+ " without further training."
1475
+ )
1476
+ if len(mismatched_keys) > 0:
1477
+ mismatched_warning = "\n".join(
1478
+ [
1479
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
1480
+ for key, shape1, shape2 in mismatched_keys
1481
+ ]
1482
+ )
1483
+ logger.warning(
1484
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1485
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
1486
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
1487
+ " able to use it for predictions and inference."
1488
+ )
1489
+
1490
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1491
+
mv_diffusion_30/pipelines/pipeline_mvdiffusion_image.py ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import warnings
17
+ from typing import Callable, List, Optional, Union
18
+
19
+ import PIL
20
+ import torch
21
+ import torchvision.transforms.functional as TF
22
+ from packaging import version
23
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
24
+
25
+ from diffusers.configuration_utils import FrozenDict
26
+ from diffusers.image_processor import VaeImageProcessor
27
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
28
+ from diffusers.schedulers import KarrasDiffusionSchedulers
29
+ from diffusers.utils.torch_utils import logging, randn_tensor
30
+ from diffusers.utils.deprecation_utils import deprecate
31
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
32
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
33
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
34
+ from einops import rearrange, repeat
35
+
36
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
+
38
+
39
+ class MVDiffusionImagePipeline(DiffusionPipeline):
40
+ r"""
41
+ Pipeline to generate image variations from an input image using Stable Diffusion.
42
+
43
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
44
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
45
+
46
+ Args:
47
+ vae ([`AutoencoderKL`]):
48
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
49
+ image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
50
+ Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
51
+ text_encoder ([`~transformers.CLIPTextModel`]):
52
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
53
+ tokenizer ([`~transformers.CLIPTokenizer`]):
54
+ A `CLIPTokenizer` to tokenize text.
55
+ unet ([`UNet2DConditionModel`]):
56
+ A `UNet2DConditionModel` to denoise the encoded image latents.
57
+ scheduler ([`SchedulerMixin`]):
58
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
59
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
60
+ safety_checker ([`StableDiffusionSafetyChecker`]):
61
+ Classification module that estimates whether generated images could be considered offensive or harmful.
62
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
63
+ about a model's potential harms.
64
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
65
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
66
+ """
67
+ # TODO: feature_extractor is required to encode images (if they are in PIL format),
68
+ # we should give a descriptive message if the pipeline doesn't have one.
69
+ _optional_components = ["safety_checker"]
70
+
71
+ def __init__(
72
+ self,
73
+ vae: AutoencoderKL,
74
+ image_encoder: CLIPVisionModelWithProjection,
75
+ unet: UNet2DConditionModel,
76
+ scheduler: KarrasDiffusionSchedulers,
77
+ safety_checker: StableDiffusionSafetyChecker,
78
+ feature_extractor: CLIPImageProcessor,
79
+ requires_safety_checker: bool = True,
80
+ camera_embedding_type: str = 'e_de_da_sincos',
81
+ num_views: int = 6,
82
+ pred_type: str = 'color',
83
+ ):
84
+ super().__init__()
85
+
86
+ if safety_checker is None and requires_safety_checker:
87
+ logger.warn(
88
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
89
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
90
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
91
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
92
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
93
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
94
+ )
95
+
96
+ if safety_checker is not None and feature_extractor is None:
97
+ raise ValueError(
98
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
99
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
100
+ )
101
+
102
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
103
+ version.parse(unet.config._diffusers_version).base_version
104
+ ) < version.parse("0.9.0.dev0")
105
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
106
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
107
+ deprecation_message = (
108
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
109
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
110
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
111
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
112
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
113
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
114
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
115
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
116
+ " the `unet/config.json` file"
117
+ )
118
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
119
+ new_config = dict(unet.config)
120
+ new_config["sample_size"] = 64
121
+ unet._internal_dict = FrozenDict(new_config)
122
+
123
+ self.register_modules(
124
+ vae=vae,
125
+ image_encoder=image_encoder,
126
+ unet=unet,
127
+ scheduler=scheduler,
128
+ safety_checker=safety_checker,
129
+ feature_extractor=feature_extractor,
130
+ )
131
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
132
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
133
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
134
+
135
+ self.camera_embedding_type: str = camera_embedding_type
136
+ self.num_views: int = num_views
137
+ self.pred_type = pred_type
138
+
139
+ self.camera_embedding = torch.tensor(
140
+ [[ 0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
141
+ [ 0.0000, -0.2362, 0.8125, 1.0000, 0.0000],
142
+ [ 0.0000, -0.1686, 1.6934, 1.0000, 0.0000],
143
+ [ 0.0000, 0.5220, 3.1406, 1.0000, 0.0000],
144
+ [ 0.0000, 0.6904, 4.8359, 1.0000, 0.0000],
145
+ [ 0.0000, 0.3733, 5.5859, 1.0000, 0.0000],
146
+ [ 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
147
+ [ 0.0000, -0.2362, 0.8125, 0.0000, 1.0000],
148
+ [ 0.0000, -0.1686, 1.6934, 0.0000, 1.0000],
149
+ [ 0.0000, 0.5220, 3.1406, 0.0000, 1.0000],
150
+ [ 0.0000, 0.6904, 4.8359, 0.0000, 1.0000],
151
+ [ 0.0000, 0.3733, 5.5859, 0.0000, 1.0000]], dtype=torch.float16)
152
+
153
+ def _encode_image(self, image_pil, device, num_images_per_prompt, do_classifier_free_guidance):
154
+ dtype = next(self.image_encoder.parameters()).dtype
155
+
156
+ image_pt = self.feature_extractor(images=image_pil, return_tensors="pt").pixel_values
157
+ image_pt = image_pt.to(device=device, dtype=dtype)
158
+ image_embeddings = self.image_encoder(image_pt).image_embeds
159
+ image_embeddings = image_embeddings.unsqueeze(1)
160
+
161
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
162
+ # Note: repeat differently from official pipelines
163
+ # B1B2B3B4 -> B1B2B3B4B1B2B3B4
164
+ bs_embed, seq_len, _ = image_embeddings.shape
165
+ image_embeddings = image_embeddings.repeat(num_images_per_prompt, 1, 1)
166
+
167
+ if do_classifier_free_guidance:
168
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
169
+
170
+ # For classifier free guidance, we need to do two forward passes.
171
+ # Here we concatenate the unconditional and text embeddings into a single batch
172
+ # to avoid doing two forward passes
173
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
174
+
175
+ image_pt = torch.stack([TF.to_tensor(img) for img in image_pil], dim=0).to(device).to(dtype)
176
+ image_pt = image_pt * 2.0 - 1.0
177
+ image_latents = self.vae.encode(image_pt).latent_dist.mode() * self.vae.config.scaling_factor
178
+ # Note: repeat differently from official pipelines
179
+ # B1B2B3B4 -> B1B2B3B4B1B2B3B4
180
+ image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1)
181
+
182
+ if do_classifier_free_guidance:
183
+ image_latents = torch.cat([torch.zeros_like(image_latents), image_latents])
184
+
185
+ return image_embeddings, image_latents
186
+
187
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
188
+ def run_safety_checker(self, image, device, dtype):
189
+ if self.safety_checker is None:
190
+ has_nsfw_concept = None
191
+ else:
192
+ if torch.is_tensor(image):
193
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
194
+ else:
195
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
196
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
197
+ image, has_nsfw_concept = self.safety_checker(
198
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
199
+ )
200
+ return image, has_nsfw_concept
201
+
202
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
203
+ def decode_latents(self, latents):
204
+ warnings.warn(
205
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
206
+ " use VaeImageProcessor instead",
207
+ FutureWarning,
208
+ )
209
+ latents = 1 / self.vae.config.scaling_factor * latents
210
+ image = self.vae.decode(latents, return_dict=False)[0]
211
+ image = (image / 2 + 0.5).clamp(0, 1)
212
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
213
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
214
+ return image
215
+
216
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
217
+ def prepare_extra_step_kwargs(self, generator, eta):
218
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
219
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
220
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
221
+ # and should be between [0, 1]
222
+
223
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
224
+ extra_step_kwargs = {}
225
+ if accepts_eta:
226
+ extra_step_kwargs["eta"] = eta
227
+
228
+ # check if the scheduler accepts generator
229
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
230
+ if accepts_generator:
231
+ extra_step_kwargs["generator"] = generator
232
+ return extra_step_kwargs
233
+
234
+ def check_inputs(self, image, height, width, callback_steps):
235
+ if (
236
+ not isinstance(image, torch.Tensor)
237
+ and not isinstance(image, PIL.Image.Image)
238
+ and not isinstance(image, list)
239
+ ):
240
+ raise ValueError(
241
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
242
+ f" {type(image)}"
243
+ )
244
+
245
+ if height % 8 != 0 or width % 8 != 0:
246
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
247
+
248
+ if (callback_steps is None) or (
249
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
250
+ ):
251
+ raise ValueError(
252
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
253
+ f" {type(callback_steps)}."
254
+ )
255
+
256
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
257
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None, cross_domain_latnte=True):
258
+ if cross_domain_latnte:
259
+ # generate cross-domain initial latents
260
+ # for cross-domain task, make sure the two domain are start from a same initial latents
261
+ assert batch_size % 2 == 0
262
+ batch_size = batch_size // 2
263
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
264
+ if isinstance(generator, list) and len(generator) != batch_size:
265
+ raise ValueError(
266
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
267
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
268
+ )
269
+
270
+ if latents is None:
271
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
272
+ else:
273
+ latents = latents.to(device)
274
+
275
+ # scale the initial noise by the standard deviation required by the scheduler
276
+ latents = latents * self.scheduler.init_noise_sigma
277
+ if cross_domain_latnte:
278
+ latents = torch.cat([latents] * 2)
279
+ return latents
280
+
281
+ def prepare_camera_embedding(self, camera_embedding: Union[float, torch.Tensor], do_classifier_free_guidance, num_images_per_prompt=1):
282
+ # (B, 3)
283
+ camera_embedding = camera_embedding.to(dtype=self.unet.dtype, device=self.unet.device)
284
+
285
+ if self.camera_embedding_type == 'e_de_da_sincos':
286
+ # (B, 6)
287
+ camera_embedding = torch.cat([
288
+ torch.sin(camera_embedding),
289
+ torch.cos(camera_embedding)
290
+ ], dim=-1)
291
+ assert self.unet.config.class_embed_type == 'projection'
292
+ assert self.unet.config.projection_class_embeddings_input_dim == 14 or self.unet.config.projection_class_embeddings_input_dim == 10
293
+ else:
294
+ raise NotImplementedError
295
+
296
+ # Note: repeat differently from official pipelines
297
+ # B1B2B3B4 -> B1B2B3B4B1B2B3B4
298
+ camera_embedding = camera_embedding.repeat(num_images_per_prompt, 1)
299
+
300
+ if do_classifier_free_guidance:
301
+ camera_embedding = torch.cat([
302
+ camera_embedding,
303
+ camera_embedding
304
+ ], dim=0)
305
+
306
+ return camera_embedding
307
+
308
+ def reshape_to_cd_input(self, input):
309
+ # reshape input for cross-domain attention
310
+ input_norm_uc, input_rgb_uc, input_norm_cond, input_rgb_cond = torch.chunk(
311
+ input, dim=0, chunks=4)
312
+ input = torch.cat(
313
+ [input_norm_uc, input_norm_cond, input_rgb_uc, input_rgb_cond], dim=0)
314
+ return input
315
+
316
+ def reshape_to_cfg_output(self, output):
317
+ # reshape input for cfg
318
+ output_norm_uc, output_norm_cond, output_rgb_uc, output_rgb_cond = torch.chunk(
319
+ output, dim=0, chunks=4)
320
+ output = torch.cat(
321
+ [output_norm_uc, output_rgb_uc, output_norm_cond, output_rgb_cond],
322
+ dim=0)
323
+ return output
324
+
325
+ @torch.no_grad()
326
+ def __call__(
327
+ self,
328
+ image: Union[List[PIL.Image.Image], torch.FloatTensor],
329
+ # elevation_cond: torch.FloatTensor,
330
+ # elevation: torch.FloatTensor,
331
+ # azimuth: torch.FloatTensor,
332
+ camera_embedding: Optional[torch.FloatTensor]=None,
333
+ height: Optional[int] = None,
334
+ width: Optional[int] = None,
335
+ num_inference_steps: int = 50,
336
+ guidance_scale: float = 7.5,
337
+ num_images_per_prompt: Optional[int] = 1,
338
+ eta: float = 0.0,
339
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
340
+ latents: Optional[torch.FloatTensor] = None,
341
+ output_type: Optional[str] = "pil",
342
+ return_dict: bool = True,
343
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
344
+ callback_steps: int = 1,
345
+ normal_cond: Optional[Union[List[PIL.Image.Image], torch.FloatTensor]] = None,
346
+ ):
347
+ r"""
348
+ The call function to the pipeline for generation.
349
+
350
+ Args:
351
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
352
+ Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
353
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
354
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
355
+ The height in pixels of the generated image.
356
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
357
+ The width in pixels of the generated image.
358
+ num_inference_steps (`int`, *optional*, defaults to 50):
359
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
360
+ expense of slower inference. This parameter is modulated by `strength`.
361
+ guidance_scale (`float`, *optional*, defaults to 7.5):
362
+ A higher guidance scale value encourages the model to generate images closely linked to the text
363
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
364
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
365
+ The number of images to generate per prompt.
366
+ eta (`float`, *optional*, defaults to 0.0):
367
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
368
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
369
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
370
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
371
+ generation deterministic.
372
+ latents (`torch.FloatTensor`, *optional*):
373
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
374
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
375
+ tensor is generated by sampling using the supplied random `generator`.
376
+ output_type (`str`, *optional*, defaults to `"pil"`):
377
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
378
+ return_dict (`bool`, *optional*, defaults to `True`):
379
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
380
+ plain tuple.
381
+ callback (`Callable`, *optional*):
382
+ A function that calls every `callback_steps` steps during inference. The function is called with the
383
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
384
+ callback_steps (`int`, *optional*, defaults to 1):
385
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
386
+ every step.
387
+
388
+ Returns:
389
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
390
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
391
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
392
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
393
+ "not-safe-for-work" (nsfw) content.
394
+
395
+ Examples:
396
+
397
+ ```py
398
+ from diffusers import StableDiffusionImageVariationPipeline
399
+ from PIL import Image
400
+ from io import BytesIO
401
+ import requests
402
+
403
+ pipe = StableDiffusionImageVariationPipeline.from_pretrained(
404
+ "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
405
+ )
406
+ pipe = pipe.to("cuda")
407
+
408
+ url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"
409
+
410
+ response = requests.get(url)
411
+ image = Image.open(BytesIO(response.content)).convert("RGB")
412
+
413
+ out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
414
+ out["images"][0].save("result.jpg")
415
+ ```
416
+ """
417
+ # 0. Default height and width to unet
418
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
419
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
420
+
421
+ # 1. Check inputs. Raise error if not correct
422
+ self.check_inputs(image, height, width, callback_steps)
423
+
424
+
425
+ # 2. Define call parameters
426
+ if isinstance(image, list):
427
+ batch_size = len(image)
428
+ elif isinstance(image, torch.Tensor):
429
+ batch_size = image.shape[0]
430
+ assert batch_size >= self.num_views and batch_size % self.num_views == 0
431
+ elif isinstance(image, PIL.Image.Image):
432
+ image = [image]*self.num_views*2
433
+ batch_size = self.num_views*2
434
+
435
+ device = self._execution_device
436
+ dtype = self.vae.dtype
437
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
438
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
439
+ # corresponds to doing no classifier free guidance.
440
+ do_classifier_free_guidance = guidance_scale != 1.0
441
+
442
+ # 3. Encode input image
443
+ if isinstance(image, list):
444
+ image_pil = image
445
+ elif isinstance(image, torch.Tensor):
446
+ image_pil = [TF.to_pil_image(image[i]) for i in range(image.shape[0])]
447
+ image_embeddings, image_latents = self._encode_image(image_pil, device, num_images_per_prompt, do_classifier_free_guidance)
448
+
449
+ if normal_cond is not None:
450
+ if isinstance(normal_cond, list):
451
+ normal_cond_pil = normal_cond
452
+ elif isinstance(normal_cond, torch.Tensor):
453
+ normal_cond_pil = [TF.to_pil_image(normal_cond[i]) for i in range(normal_cond.shape[0])]
454
+ _, image_latents = self._encode_image(normal_cond_pil, device, num_images_per_prompt, do_classifier_free_guidance)
455
+
456
+
457
+ # assert len(elevation_cond) == batch_size and len(elevation) == batch_size and len(azimuth) == batch_size
458
+ # camera_embeddings = self.prepare_camera_condition(elevation_cond, elevation, azimuth, do_classifier_free_guidance=do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt)
459
+
460
+ if camera_embedding is not None:
461
+ assert len(camera_embedding) == batch_size
462
+ else:
463
+ camera_embedding = self.camera_embedding.to(dtype)
464
+ camera_embedding = repeat(camera_embedding, "Nv Nce -> (B Nv) Nce", B=batch_size//len(camera_embedding))
465
+ camera_embeddings = self.prepare_camera_embedding(camera_embedding, do_classifier_free_guidance=do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt)
466
+
467
+ # 4. Prepare timesteps
468
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
469
+ timesteps = self.scheduler.timesteps
470
+
471
+ # 5. Prepare latent variables
472
+ num_channels_latents = self.unet.config.out_channels
473
+ latents = self.prepare_latents(
474
+ batch_size * num_images_per_prompt,
475
+ num_channels_latents,
476
+ height,
477
+ width,
478
+ image_embeddings.dtype,
479
+ device,
480
+ generator,
481
+ latents,
482
+ cross_domain_latnte=True
483
+ )
484
+
485
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
486
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
487
+
488
+ # 7. Denoising loop
489
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
490
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
491
+ if do_classifier_free_guidance and self.pred_type == 'joint_color_normal':
492
+ print("reshape the input to cross-domain format")
493
+ image_embeddings = self.reshape_to_cd_input(image_embeddings)
494
+ camera_embeddings = self.reshape_to_cd_input(camera_embeddings)
495
+ image_latents = self.reshape_to_cd_input(image_latents)
496
+ for i, t in enumerate(timesteps):
497
+ # expand the latents if we are doing classifier free guidance
498
+ # latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
499
+ if do_classifier_free_guidance and self.pred_type == 'joint_color_normal':
500
+ latent_model_input = torch.cat([latents] * 2)
501
+ latent_model_input = self.reshape_to_cd_input(latent_model_input)
502
+ elif do_classifier_free_guidance and self.pred_type != 'joint_color_normal':
503
+ latent_model_input = torch.cat([latents] * 2)
504
+ else:
505
+ latent_model_input = latents
506
+
507
+ latent_model_input = torch.cat([
508
+ latent_model_input, image_latents
509
+ ], dim=1)
510
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
511
+
512
+ # predict the noise residual
513
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings,
514
+ class_labels=camera_embeddings).sample
515
+
516
+ # perform guidance
517
+ if do_classifier_free_guidance and self.pred_type != 'joint_color_normal':
518
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
519
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
520
+ elif do_classifier_free_guidance and self.pred_type == 'joint_color_normal':
521
+ noise_pred = self.reshape_to_cfg_output(noise_pred)
522
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
523
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
524
+
525
+ # compute the previous noisy sample x_t -> x_t-1
526
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
527
+
528
+ # call the callback, if provided
529
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
530
+ progress_bar.update()
531
+ if callback is not None and i % callback_steps == 0:
532
+ callback(i, t, latents)
533
+
534
+ if not output_type == "latent":
535
+ if num_channels_latents == 8:
536
+ latents = torch.cat([latents[:, :4], latents[:, 4:]], dim=0)
537
+
538
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
539
+ image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
540
+ else:
541
+ image = latents
542
+ has_nsfw_concept = None
543
+
544
+ if has_nsfw_concept is None:
545
+ do_denormalize = [True] * image.shape[0]
546
+ else:
547
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
548
+
549
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
550
+
551
+ if not return_dict:
552
+ return (image, has_nsfw_concept)
553
+
554
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
555
+