qitaoz committed
Commit 4f54ccd · verified · 1 Parent(s): ee4a9d9

init commit

sparseags/cam_utils.py ADDED
@@ -0,0 +1,472 @@
import numpy as np
from scipy.spatial.transform import Rotation as R

# import ipdb
import math
import torch
import torch.nn.functional as F
from pytorch3d.transforms import Rotate, Translate


def intersect_skew_line_groups(p, r, mask):
    # p, r both of shape (B, N, n_intersected_lines, 3)
    # mask of shape (B, N, n_intersected_lines)
    p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
    if p_intersect is None:
        return None, None, None, None
    _, p_line_intersect = point_line_distance(
        p, r, p_intersect[..., None, :].expand_as(p)
    )
    intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(
        dim=-1
    )
    return p_intersect, p_line_intersect, intersect_dist_squared, r


def intersect_skew_lines_high_dim(p, r, mask=None):
    # Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions
    dim = p.shape[-1]
    # make sure the heading vectors are l2-normed
    if mask is None:
        mask = torch.ones_like(p[..., 0])
    r = torch.nn.functional.normalize(r, dim=-1)

    eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
    I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
    sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)

    # I_eps = torch.zeros_like(I_min_cov.sum(dim=-3)) + 1e-10
    # p_intersect = torch.pinverse(I_min_cov.sum(dim=-3) + I_eps).matmul(sum_proj)[..., 0]
    p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]

    # I_min_cov.sum(dim=-3): torch.Size([1, 1, 3, 3])
    # sum_proj: torch.Size([1, 1, 3, 1])

    # p_intersect = np.linalg.lstsq(I_min_cov.sum(dim=-3).numpy(), sum_proj.numpy(), rcond=None)[0]

    if torch.any(torch.isnan(p_intersect)):
        print(p_intersect)
        return None, None
        # unreachable debug leftovers (ipdb import is commented out above)
        # ipdb.set_trace()
        # assert False
    return p_intersect, r


def point_line_distance(p1, r1, p2):
    df = p2 - p1
    proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
    line_pt_nearest = p2 - proj_vector
    d = (proj_vector).norm(dim=-1)
    return d, line_pt_nearest


def compute_optical_axis_intersection(cameras, in_ndc=True):
    centers = cameras.get_camera_center()
    principal_points = cameras.principal_point

    one_vec = torch.ones((len(cameras), 1), device=centers.device)
    optical_axis = torch.cat((principal_points, one_vec), -1)

    # optical_axis = torch.cat(
    #     (principal_points, cameras.focal_length[:, 0].unsqueeze(1)), -1
    # )

    pp = cameras.unproject_points(optical_axis, from_ndc=in_ndc, world_coordinates=True)
    pp2 = torch.diagonal(pp, dim1=0, dim2=1).T

    directions = pp2 - centers
    centers = centers.unsqueeze(0).unsqueeze(0)
    directions = directions.unsqueeze(0).unsqueeze(0)

    p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
        p=centers, r=directions, mask=None
    )

    if p_intersect is None:
        dist = None
    else:
        p_intersect = p_intersect.squeeze().unsqueeze(0)
        dist = (p_intersect - centers).norm(dim=-1)

    return p_intersect, dist, p_line_intersect, pp2, r


def normalize_cameras_with_up_axis(cameras, sequence_name, scale=1.0, in_ndc=True):
    """
    Normalizes cameras such that the optical axes point to the origin and the average
    distance to the origin is 1.

    Args:
        cameras (List[camera]).
    """

    # Let distance from first camera to origin be unit
    new_cameras = cameras.clone()
    new_transform = new_cameras.get_world_to_view_transform()

    p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
        cameras,
        in_ndc=in_ndc
    )
    t = Translate(p_intersect)

    # scale = dist.squeeze()[0]
    scale = dist.squeeze().mean()

    # Degenerate case
    if scale == 0:
        print(cameras.T)
        print(new_transform.get_matrix()[:, 3, :3])
        return -1
    assert scale != 0

    new_transform = t.compose(new_transform)
    new_cameras.R = new_transform.get_matrix()[:, :3, :3]
    new_cameras.T = new_transform.get_matrix()[:, 3, :3] / scale * 1.85

    needs_checking = False

    # ===== Rotation normalization
    # Estimate the world 'up' direction assuming that yaw is small
    # and running SVD on the x-vectors of the cameras
    x_vectors = new_cameras.R.transpose(1, 2)[:, 0, :].clone()
    x_vectors -= x_vectors.mean(dim=0, keepdim=True)
    U, S, Vh = torch.linalg.svd(x_vectors)
    V = Vh.mH
    # vector with the smallest variation is to the normal to
    # the plane of x-vectors (assume this to be the up direction)
    if S[0] / S[1] > S[1] / S[2]:
        print('Warning: unexpected singular values in sequence {}: {}'.format(sequence_name, S))
        needs_checking = True
        # return None, None, None, None, None
    estimated_world_up = V[:, 2:]
    # check all cameras have the same y-direction
    for camera_idx in range(len(new_cameras.T)):
        if torch.sign(torch.dot(estimated_world_up[:, 0],
                                new_cameras.R[0].transpose(0, 1)[1, :])) != torch.sign(torch.dot(estimated_world_up[:, 0],
                                new_cameras.R[camera_idx].transpose(0, 1)[1, :])):
            print("Some cameras appear to be flipped in sequence {}".format(sequence_name))
            needs_checking = True
            # return None, None, None, None, None
    flip = torch.sign(torch.dot(estimated_world_up[:, 0], new_cameras.R[0].transpose(0, 1)[1, :])) < 0
    if flip:
        estimated_world_up = V[:, 2:] * -1
    # build the target coordinate basis using the estimated world up
    target_coordinate_basis = torch.cat([V[:, :1],
                                         estimated_world_up,
                                         torch.linalg.cross(V[:, :1], estimated_world_up, dim=0)],
                                        dim=1)
    new_cameras.R = torch.matmul(target_coordinate_basis.T, new_cameras.R)
    return new_cameras, p_intersect, p_line_intersect, pp, r, needs_checking


def dot(x, y):
    if isinstance(x, np.ndarray):
        return np.sum(x * y, -1, keepdims=True)
    else:
        return torch.sum(x * y, -1, keepdim=True)


def length(x, eps=1e-20):
    if isinstance(x, np.ndarray):
        return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
    else:
        return torch.sqrt(torch.clamp(dot(x, x), min=eps))


def safe_normalize(x, eps=1e-20):
    return x / length(x, eps)


def look_at(campos, target, opengl=True):
    # campos: [N, 3], camera/eye position
    # target: [N, 3], object to look at
    # return: [N, 3, 3], rotation matrix
    if not opengl:
        # camera forward aligns with -z
        forward_vector = safe_normalize(target - campos)
        up_vector = np.array([0, 1, 0], dtype=np.float32)
        right_vector = safe_normalize(np.cross(forward_vector, up_vector))
        up_vector = safe_normalize(np.cross(right_vector, forward_vector))
    else:
        # camera forward aligns with +z
        forward_vector = safe_normalize(campos - target)
        up_vector = np.array([0, 1, 0], dtype=np.float32)
        right_vector = safe_normalize(np.cross(up_vector, forward_vector))
        up_vector = safe_normalize(np.cross(forward_vector, right_vector))
    R = np.stack([right_vector, up_vector, forward_vector], axis=1)
    return R


# elevation & azimuth to pose (cam2world) matrix
def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True):
    # radius: scalar
    # elevation: scalar, in (-90, 90), from +y to -y is (-90, 90)
    # azimuth: scalar, in (-180, 180), from +z to +x is (0, 90)
    # return: [4, 4], camera pose matrix
    if is_degree:
        elevation = np.deg2rad(elevation)
        azimuth = np.deg2rad(azimuth)
    x = radius * np.cos(elevation) * np.sin(azimuth)
    y = - radius * np.sin(elevation)
    z = radius * np.cos(elevation) * np.cos(azimuth)
    if target is None:
        target = np.zeros([3], dtype=np.float32)
    campos = np.array([x, y, z]) + target  # [3]
    T = np.eye(4, dtype=np.float32)
    T[:3, :3] = look_at(campos, target, opengl)
    T[:3, 3] = campos

    return T


def mat2latlon(T):
    if not isinstance(T, np.ndarray):
        xyz = T.cpu().detach().numpy()
    else:
        xyz = T.copy()
    r = np.linalg.norm(xyz)
    xyz = xyz / r
    theta = -np.arcsin(xyz[1])
    azi = np.arctan2(xyz[0], xyz[2])
    return np.rad2deg(theta), np.rad2deg(azi), r


def extract_camera_properties(camera_to_world_matrix):
    # Camera position is the translation part of the matrix
    camera_position = camera_to_world_matrix[:3, 3]

    # Extracting the forward direction vector (third column of rotation matrix)
    forward = camera_to_world_matrix[:3, 2]

    return camera_position, forward


def compute_angular_error_batch(rotation1, rotation2):
    R_rel = np.einsum("Bij,Bjk ->Bik", rotation1.transpose(0, 2, 1), rotation2)
    t = (np.trace(R_rel, axis1=1, axis2=2) - 1) / 2
    theta = np.arccos(np.clip(t, -1, 1))
    return theta * 180 / np.pi


def find_mask_center_and_translate(image, mask):
    """
    Calculate the center of mass of the mask and return its offset
    (delta_x, delta_y) from the image center.

    Args:
    - image (torch.Tensor): Input image tensor of shape (N, C, H, W)
    - mask (torch.Tensor): Mask tensor of shape (N, 1, H, W)

    Returns:
    - torch.Tensor of shape (2,) holding the (delta_x, delta_y) offset in pixels
    """
    _, _, h, w = image.shape

    # Calculate the center of mass of the mask
    # Note: mask should be a binary mask of the same spatial dimensions as the image
    y_coords, x_coords = torch.meshgrid(torch.arange(0, h), torch.arange(0, w), indexing='ij')
    total_mass = mask.sum(dim=[2, 3], keepdim=True)
    x_center = (mask * x_coords.to(image.device)).sum(dim=[2, 3], keepdim=True) / total_mass
    y_center = (mask * y_coords.to(image.device)).sum(dim=[2, 3], keepdim=True) / total_mass

    # Calculate the translation needed to move the mask center to the image center
    image_center_x, image_center_y = w // 2, h // 2
    delta_x = x_center.squeeze() - image_center_x
    delta_y = y_center.squeeze() - image_center_y

    return torch.tensor([delta_x, delta_y])


def create_voxel_grid(length, resolution=64):
    """
    Creates a voxel grid.
    length: half-extent of the grid; each axis spans [-length, length].
    resolution: The number of divisions along each axis.
    Returns a 4D tensor of voxel-center positions in homogeneous coordinates,
    of shape (resolution, resolution, resolution, 4).
    """
    x = torch.linspace(-length, length, resolution)
    y = torch.linspace(-length, length, resolution)
    z = torch.linspace(-length, length, resolution)

    xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij')
    voxels = torch.stack([xx, yy, zz, torch.ones_like(xx)], dim=-1)  # Homogeneous coordinates
    return voxels


def project_voxels_to_image(voxels, camera):
    """
    Projects voxel centers into the camera's image plane.
    voxels: 4D tensor of voxel grid in homogeneous coordinates.
    camera: camera object providing projection_matrix and world_view_transform.
    Returns a tensor of projected 2D points in image coordinates.
    """
    device = voxels.device
    # K, R, t = torch.tensor(K, device=device), torch.tensor(R, device=device), torch.tensor(t, device=device)

    # Flatten voxels to shape (N, 4) for matrix multiplication
    N = voxels.nelement() // 4  # Total number of voxels
    voxels_flat = voxels.reshape(-1, 4).t()  # Shape (4, N)

    # # Apply extrinsic parameters (rotation and translation)
    # transformed_voxels = R @ voxels_flat[:3, :] + t[:, None]

    # # Apply intrinsic parameters
    # projected_voxels = K @ transformed_voxels

    projected_voxels = camera.projection_matrix.transpose(0, 1) @ camera.world_view_transform.transpose(0, 1) @ voxels_flat

    # Convert from homogeneous coordinates to 2D
    projected_voxels_2d = (projected_voxels[:2, :] / projected_voxels[3, :]).t()  # Reshape to grid dimensions with 2D points
    projected_voxels_2d = (projected_voxels_2d.reshape(*voxels.shape[:-1], 2) + 1.) * 255 * 0.5

    return projected_voxels_2d


def carve_voxels(voxel_grid, projected_points, mask):
    """
    Updates the voxel grid based on the comparison with the mask.
    voxel_grid: 3D tensor representing the voxel grid.
    projected_points: Projected 2D points in image coordinates.
    mask: Binary mask image.
    """
    # Convert projected points to indices in the mask
    indices_x = torch.clamp(projected_points[..., 0], 0, mask.shape[1] - 1).long()
    indices_y = torch.clamp(projected_points[..., 1], 0, mask.shape[0] - 1).long()

    # Check if projected points are within the object in the mask
    in_object = mask[indices_y, indices_x]

    # Carve out voxels where the projection does not fall within the object
    voxel_grid[in_object == 0] = 0


def sample_points_from_voxel(cameras, masks, length=1, resolution=64, N=5000, inverse=False, device="cuda"):
    """
    Randomly sample N points from solid regions in a voxel grid.

    Args:
    - cameras: cameras used to carve the visual hull.
    - masks (np.ndarray): binary object masks, one per camera.
    - length (float): half-extent of the voxel grid along each axis.
    - resolution (int): number of voxels along each axis.
    - N (int): The number of points to sample.
    - inverse (bool): if True, sample from carved (empty) voxels instead of solid ones.

    Returns:
    - sampled_points (torch.Tensor): A tensor of shape (N, 3) representing the sampled 3D coordinates.
    """
    voxel_grid = create_voxel_grid(length, resolution).to(device)
    voxel_grid_indicator = torch.ones(resolution, resolution, resolution)

    masks = torch.from_numpy(masks).to(device).squeeze()

    for idx, cam in enumerate(cameras):
        projected_points = project_voxels_to_image(voxel_grid, cam)
        carve_voxels(voxel_grid_indicator, projected_points, masks[idx])

    voxel_grid_indicator = voxel_grid_indicator.reshape(resolution, resolution, resolution)

    # Identify the indices of solid voxels
    if inverse:
        solid_indices = torch.nonzero(voxel_grid_indicator == 0)
    else:
        solid_indices = torch.nonzero(voxel_grid_indicator == 1)

    # Randomly select N indices from the solid indices
    if N <= solid_indices.size(0):
        # Randomly select N indices from the solid indices if there are enough solid voxels
        sampled_indices = solid_indices[torch.randperm(solid_indices.size(0))[:N]]
    else:
        # If there are not enough solid voxels, sample with replacement
        sampled_indices = solid_indices[torch.randint(0, solid_indices.size(0), (N,))]

    # Convert indices to coordinates
    # Note: This step assumes the voxel grid spans from 0 to 1 in each dimension.
    # Adjust accordingly if your grid spans a different range.
    sampled_points = sampled_indices.float() / (voxel_grid.size(0) - 1) * 2 * length - length

    return sampled_points


class OrbitCamera:
    def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100):
        self.W = W
        self.H = H
        self.radius = r  # camera distance from center
        self.fovy = np.deg2rad(fovy)  # deg 2 rad
        self.near = near
        self.far = far
        self.center = np.array([0, 0, 0], dtype=np.float32)  # look at this point
        self.rot = R.from_matrix(np.eye(3))
        self.up = np.array([0, 1, 0], dtype=np.float32)  # need to be normalized!

    @property
    def fovx(self):
        return 2 * np.arctan(np.tan(self.fovy / 2) * self.W / self.H)

    @property
    def campos(self):
        return self.pose[:3, 3]

    # pose (c2w)
    @property
    def pose(self):
        # first move camera to radius
        res = np.eye(4, dtype=np.float32)
        res[2, 3] = self.radius  # opengl convention...
        # rotate
        rot = np.eye(4, dtype=np.float32)
        rot[:3, :3] = self.rot.as_matrix()
        res = rot @ res
        # translate
        res[:3, 3] -= self.center
        return res

    # view (w2c)
    @property
    def view(self):
        return np.linalg.inv(self.pose)

    # projection (perspective)
    @property
    def perspective(self):
        y = np.tan(self.fovy / 2)
        aspect = self.W / self.H
        return np.array(
            [
                [1 / (y * aspect), 0, 0, 0],
                [0, -1 / y, 0, 0],
                [
                    0,
                    0,
                    -(self.far + self.near) / (self.far - self.near),
                    -(2 * self.far * self.near) / (self.far - self.near),
                ],
                [0, 0, -1, 0],
            ],
            dtype=np.float32,
        )

    # intrinsics
    @property
    def intrinsics(self):
        focal = self.H / (2 * np.tan(self.fovy / 2))
        return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32)

    @property
    def mvp(self):
        return self.perspective @ np.linalg.inv(self.pose)  # [4, 4]

    def orbit(self, dx, dy):
        # rotate along camera up/side axis!
        side = self.rot.as_matrix()[:3, 0]
        rotvec_x = self.up * np.radians(-0.05 * dx)
        rotvec_y = side * np.radians(-0.05 * dy)
        self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot

    def scale(self, delta):
        self.radius *= 1.1 ** (-delta)

    def pan(self, dx, dy, dz=0):
        # pan in camera coordinate system (careful on the sensitivity!)
        self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([-dx, -dy, dz])
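A minimal usage sketch of the camera helpers above, assuming the package is importable as sparseags.cam_utils (as the file path suggests) and with illustrative argument values:

from sparseags.cam_utils import orbit_camera, OrbitCamera

# [4, 4] cam2world pose (OpenGL convention) for the given elevation/azimuth/radius
pose = orbit_camera(elevation=-30, azimuth=45, radius=2.0)
print(pose[:3, 3])  # camera position on the orbit

# interactive-style orbit camera: pose, projection and pinhole intrinsics
cam = OrbitCamera(W=512, H=512, r=2.0, fovy=49.1)
cam.orbit(dx=100, dy=0)   # drag horizontally -> rotate about the up axis
mvp = cam.mvp             # perspective @ world-to-camera, [4, 4]
print(cam.intrinsics)     # [fx, fy, cx, cy]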
sparseags/dust3r_utils.py ADDED
@@ -0,0 +1,66 @@
import torch
from pytorch3d.renderer import PerspectiveCameras

import sys
sys.path.append('./')
from sparseags.cam_utils import normalize_cameras_with_up_axis

sys.path[0] = sys.path[0] + '/dust3r'
from dust3r.inference import inference
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode


def infer_dust3r(dust3r_model, file_names, device='cuda'):
    batch_size = 1
    schedule = 'cosine'
    lr = 0.01
    niter = 300

    images = load_images(file_names, size=224)
    pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
    output = inference(pairs, dust3r_model, device, batch_size=batch_size)

    scene = global_aligner(output, optimize_pp=True, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
    loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)

    # retrieve useful values from scene:
    imgs = scene.imgs
    cams2world = scene.get_im_poses()
    w2c = torch.linalg.inv(cams2world)
    pps = scene.get_principal_points() * 256 / 224
    focals = scene.get_focals() * 256 / 224

    w2c[:, :2] *= -1  # OpenCV to PyTorch3D
    Rs = w2c[:, :3, :3].transpose(1, 2)
    Ts = w2c[:, :3, 3]

    cameras = PerspectiveCameras(
        focal_length=focals,
        principal_point=pps,
        in_ndc=False,
        R=Rs,
        T=Ts,
    )
    normalized_cameras, _, _, _, _, needs_checking = normalize_cameras_with_up_axis(cameras, None, in_ndc=False)

    if normalized_cameras is None:
        print("It seems something went wrong during camera normalization...")
        return 0

    data = {}
    base_names = [file_name.split('/')[-1].split('.')[0] for file_name in file_names]
    file_names = [file_name.replace('source', 'processed').replace('.png', '_rgba.png') for file_name in file_names]

    for idx, base_name in enumerate(base_names):
        data[base_name] = {}
        data[base_name]["R"] = normalized_cameras.R[idx].cpu().tolist()
        data[base_name]["T"] = normalized_cameras.T[idx].cpu().tolist()
        data[base_name]["needs_checking"] = needs_checking
        data[base_name]["principal_point"] = normalized_cameras.principal_point[idx].cpu().tolist()
        data[base_name]["focal_length"] = normalized_cameras.focal_length[idx].cpu().tolist()
        data[base_name]["flag"] = 1
        data[base_name]["filepath"] = file_names[idx]

    return data
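A minimal sketch of consuming the per-view dictionary returned by infer_dust3r above; the dust3r model handle and the list of image paths are assumed to already exist:

import torch
from pytorch3d.renderer import PerspectiveCameras

data = infer_dust3r(dust3r_model, file_names)  # dust3r_model, file_names: assumed inputs
views = [v for v in data.values() if v["flag"] == 1]
cameras = PerspectiveCameras(
    R=torch.tensor([v["R"] for v in views]),                              # (N, 3, 3)
    T=torch.tensor([v["T"] for v in views]),                              # (N, 3)
    focal_length=torch.tensor([v["focal_length"] for v in views]),        # (N, 1)
    principal_point=torch.tensor([v["principal_point"] for v in views]),  # (N, 2)
    in_ndc=False,
)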
sparseags/guidance_utils/zero123.py ADDED
@@ -0,0 +1,666 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import math
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import PIL
import torch
import torchvision.transforms.functional as TF
from diffusers.configuration_utils import ConfigMixin, FrozenDict, register_to_config
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.modeling_utils import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import deprecate, is_accelerate_available, logging
from diffusers.utils.torch_utils import randn_tensor
from packaging import version
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class CLIPCameraProjection(ModelMixin, ConfigMixin):
    """
    A Projection layer for CLIP embedding and camera embedding.

    Parameters:
        embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `clip_embed`
        additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
            projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
            additional_embeddings`.
    """

    @register_to_config
    def __init__(self, embedding_dim: int = 768, additional_embeddings: int = 4):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.additional_embeddings = additional_embeddings

        self.input_dim = self.embedding_dim + self.additional_embeddings
        self.output_dim = self.embedding_dim

        self.proj = torch.nn.Linear(self.input_dim, self.output_dim)

    def forward(
        self,
        embedding: torch.FloatTensor,
    ):
        """
        The [`PriorTransformer`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, input_dim)`):
                The currently input embeddings.

        Returns:
            The output embedding projection (`torch.FloatTensor` of shape `(batch_size, output_dim)`).
        """
        proj_embedding = self.proj(embedding)
        return proj_embedding


class Zero123Pipeline(DiffusionPipeline):
    r"""
    Pipeline to generate variations from an input image using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    # TODO: feature_extractor is required to encode images (if they are in PIL format),
    # we should give a descriptive message if the pipeline doesn't have one.
    _optional_components = ["safety_checker"]

    def __init__(
        self,
        vae: AutoencoderKL,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        clip_camera_projection: CLIPCameraProjection,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warn(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        is_unet_version_less_0_9_0 = hasattr(
            unet.config, "_diffusers_version"
        ) and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse(
            "0.9.0.dev0"
        )
        is_unet_sample_size_less_64 = (
            hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        )
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate(
                "sample_size<64", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            clip_camera_projection=clip_camera_projection,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [
            self.unet,
            self.image_encoder,
            self.vae,
            self.safety_checker,
        ]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_image(
        self,
        image,
        elevation,
        azimuth,
        distance,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        clip_image_embeddings=None,
        image_camera_embeddings=None,
    ):
        dtype = next(self.image_encoder.parameters()).dtype

        if image_camera_embeddings is None:
            if image is None:
                assert clip_image_embeddings is not None
                image_embeddings = clip_image_embeddings.to(device=device, dtype=dtype)
            else:
                if not isinstance(image, torch.Tensor):
                    image = self.feature_extractor(
                        images=image, return_tensors="pt"
                    ).pixel_values

                image = image.to(device=device, dtype=dtype)
                image_embeddings = self.image_encoder(image).image_embeds
            image_embeddings = image_embeddings.unsqueeze(1)

            bs_embed, seq_len, _ = image_embeddings.shape

            if isinstance(elevation, float):
                elevation = torch.as_tensor(
                    [elevation] * bs_embed, dtype=dtype, device=device
                )
            if isinstance(azimuth, float):
                azimuth = torch.as_tensor(
                    [azimuth] * bs_embed, dtype=dtype, device=device
                )
            if isinstance(distance, float):
                distance = torch.as_tensor(
                    [distance] * bs_embed, dtype=dtype, device=device
                )

            camera_embeddings = torch.stack(
                [
                    torch.deg2rad(elevation),
                    torch.sin(torch.deg2rad(azimuth)),
                    torch.cos(torch.deg2rad(azimuth)),
                    distance,
                ],
                dim=-1,
            )[:, None, :]

            image_embeddings = torch.cat([image_embeddings, camera_embeddings], dim=-1)

            # project (image, camera) embeddings to the same dimension as clip embeddings
            image_embeddings = self.clip_camera_projection(image_embeddings)
        else:
            image_embeddings = image_camera_embeddings.to(device=device, dtype=dtype)
            bs_embed, seq_len, _ = image_embeddings.shape

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
        image_embeddings = image_embeddings.view(
            bs_embed * num_images_per_prompt, seq_len, -1
        )

        if do_classifier_free_guidance:
            negative_prompt_embeds = torch.zeros_like(image_embeddings)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])

        return image_embeddings

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(
                    image, output_type="pil"
                )
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(
                feature_extractor_input, return_tensors="pt"
            ).to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        warnings.warn(
            "The decode_latents method is deprecated and will be removed in a future version. Please"
            " use VaeImageProcessor instead",
            FutureWarning,
        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(self, image, height, width, callback_steps):
        # TODO: check image size or adjust image size to (height, width)

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

        if (callback_steps is None) or (
            callback_steps is not None
            and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype
            )
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def _get_latent_model_input(
        self,
        latents: torch.FloatTensor,
        image: Optional[
            Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]
        ],
        num_images_per_prompt: int,
        do_classifier_free_guidance: bool,
        image_latents: Optional[torch.FloatTensor] = None,
    ):
        if isinstance(image, PIL.Image.Image):
            image_pt = TF.to_tensor(image).unsqueeze(0).to(latents)
        elif isinstance(image, list):
            image_pt = torch.stack([TF.to_tensor(img) for img in image], dim=0).to(
                latents
            )
        elif isinstance(image, torch.Tensor):
            image_pt = image
        else:
            image_pt = None

        if image_pt is None:
            assert image_latents is not None
            image_pt = image_latents.repeat_interleave(num_images_per_prompt, dim=0)
        else:
            image_pt = image_pt * 2.0 - 1.0  # scale to [-1, 1]
            # FIXME: encoded latents should be multiplied with self.vae.config.scaling_factor
            # but zero123 was not trained this way
            image_pt = self.vae.encode(image_pt).latent_dist.mode()
            image_pt = image_pt.repeat_interleave(num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            latent_model_input = torch.cat(
                [
                    torch.cat([latents, latents], dim=0),
                    torch.cat([torch.zeros_like(image_pt), image_pt], dim=0),
                ],
                dim=1,
            )
        else:
            latent_model_input = torch.cat([latents, image_pt], dim=1)

        return latent_model_input

    @torch.no_grad()
    def __call__(
        self,
        image: Optional[
            Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]
        ] = None,
        elevation: Optional[Union[float, torch.FloatTensor]] = None,
        azimuth: Optional[Union[float, torch.FloatTensor]] = None,
        distance: Optional[Union[float, torch.FloatTensor]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 3.0,
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        clip_image_embeddings: Optional[torch.FloatTensor] = None,
        image_camera_embeddings: Optional[torch.FloatTensor] = None,
        image_latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
                The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
                configuration of
                [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
                `CLIPImageProcessor`
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 3.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        # TODO: check input elevation, azimuth, and distance
        # TODO: check image, clip_image_embeddings, image_latents
        self.check_inputs(image, height, width, callback_steps)

        # 2. Define call parameters
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image, list):
            batch_size = len(image)
        elif isinstance(image, torch.Tensor):
            batch_size = image.shape[0]
        else:
            assert image_latents is not None
            assert (
                clip_image_embeddings is not None or image_camera_embeddings is not None
            )
            batch_size = image_latents.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input image
        if isinstance(image, PIL.Image.Image) or isinstance(image, list):
            pil_image = image
        elif isinstance(image, torch.Tensor):
            pil_image = [TF.to_pil_image(image[i]) for i in range(image.shape[0])]
        else:
            pil_image = None
        image_embeddings = self._encode_image(
            pil_image,
            elevation,
            azimuth,
            distance,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            clip_image_embeddings,
            image_camera_embeddings,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        # num_channels_latents = self.unet.config.in_channels
        num_channels_latents = 4  # FIXME: hard-coded
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = self._get_latent_model_input(
                    latents,
                    image,
                    num_images_per_prompt,
                    do_classifier_free_guidance,
                    image_latents,
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_embeddings,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs
                ).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor, return_dict=False
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, image_embeddings.dtype
            )
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize
        )

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )
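A minimal sketch of calling the pipeline above for novel-view synthesis; the checkpoint id mirrors the one used in zero123_6d_utils.py below, and the input/output file paths are placeholders:

import torch
from PIL import Image
from sparseags.guidance_utils.zero123 import Zero123Pipeline

pipe = Zero123Pipeline.from_pretrained(
    "ashawkey/zero123-xl-diffusers",  # same checkpoint id as in zero123_6d_utils.py
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to("cuda")

cond = Image.open("cond_view.png").convert("RGB").resize((256, 256))  # placeholder path
result = pipe(
    image=cond,
    elevation=0.0,   # relative elevation in degrees
    azimuth=30.0,    # relative azimuth in degrees
    distance=0.0,    # relative radius
    num_inference_steps=50,
    guidance_scale=3.0,
)
result.images[0].save("novel_view.png")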
sparseags/guidance_utils/zero123_6d_utils.py ADDED
@@ -0,0 +1,389 @@
1
+ from diffusers import DDIMScheduler
2
+ import torchvision.transforms.functional as TF
3
+
4
+ import numpy as np
5
+ from PIL import Image
6
+ import matplotlib.pyplot as plt
7
+ import torch
8
+ import torch.nn as nn
9
+ import torchvision
10
+ from torchvision.utils import save_image
11
+ from torchvision import transforms
12
+ import torch.nn.functional as F
13
+ from einops import rearrange
14
+
15
+ import sys
16
+ sys.path.append('./')
17
+
18
+ from sparseags.guidance_utils.zero123 import Zero123Pipeline
19
+
20
+
21
+ name_mapping = {
22
+ "model.diffusion_model.input_blocks.1.1.": "down_blocks.0.attentions.0.",
23
+ "model.diffusion_model.input_blocks.2.1.": "down_blocks.0.attentions.1.",
24
+ "model.diffusion_model.input_blocks.4.1.": "down_blocks.1.attentions.0.",
25
+ "model.diffusion_model.input_blocks.5.1.": "down_blocks.1.attentions.1.",
26
+ "model.diffusion_model.input_blocks.7.1.": "down_blocks.2.attentions.0.",
27
+ "model.diffusion_model.input_blocks.8.1.": "down_blocks.2.attentions.1.",
28
+ "model.diffusion_model.middle_block.1.": "mid_block.attentions.0.",
29
+ "model.diffusion_model.output_blocks.3.1.": "up_blocks.1.attentions.0.",
30
+ "model.diffusion_model.output_blocks.4.1.": "up_blocks.1.attentions.1.",
31
+ "model.diffusion_model.output_blocks.5.1.": "up_blocks.1.attentions.2.",
32
+ "model.diffusion_model.output_blocks.6.1.": "up_blocks.2.attentions.0.",
33
+ "model.diffusion_model.output_blocks.7.1.": "up_blocks.2.attentions.1.",
34
+ "model.diffusion_model.output_blocks.8.1.": "up_blocks.2.attentions.2.",
35
+ "model.diffusion_model.output_blocks.9.1.": "up_blocks.3.attentions.0.",
36
+ "model.diffusion_model.output_blocks.10.1.": "up_blocks.3.attentions.1.",
37
+ "model.diffusion_model.output_blocks.11.1.": "up_blocks.3.attentions.2.",
38
+ }
39
+
40
+ class Zero123(nn.Module):
41
+ def __init__(self, device, fp16=True, t_range=[0.02, 0.98], model_key="ashawkey/zero123-xl-diffusers"):
42
+ super().__init__()
43
+
44
+ self.device = device
45
+ self.fp16 = fp16
46
+ self.dtype = torch.float16 if fp16 else torch.float32
47
+
48
+ self.pipe = Zero123Pipeline.from_pretrained(
49
+ model_key,
50
+ trust_remote_code=True,
51
+ torch_dtype=self.dtype,
52
+ ).to(self.device)
53
+
54
+ # load weights from the checkpoint
55
+ ckpt_path = "checkpoints/zero123_6dof_23k.ckpt"
56
+ print(f'[INFO] loading checkpoint from {ckpt_path} ...')
57
+ old_state = torch.load(ckpt_path)
58
+ pretrained_weights = old_state['state_dict']['cc_projection.weight']
59
+ pretrained_biases = old_state['state_dict']['cc_projection.bias']
60
+ linear_layer = torch.nn.Linear(768 + 18, 768)
61
+ linear_layer.weight.data = pretrained_weights
62
+ linear_layer.bias.data = pretrained_biases
63
+ self.pipe.clip_camera_projection.proj = linear_layer.to(dtype=self.dtype, device=self.device)
64
+
65
+ for name in list(old_state['state_dict'].keys()):
66
+ for k, v in name_mapping.items():
67
+ if k in name:
68
+ old_state['state_dict'][name.replace(k, name_mapping[k])] = old_state['state_dict'][name].to(dtype=self.dtype, device=self.device)
69
+
70
+ m, u = self.pipe.unet.load_state_dict(old_state['state_dict'], strict=False)
71
+
72
+ # stable-zero123 has a different camera embedding
73
+ self.use_stable_zero123 = 'stable' in model_key
74
+
75
+ self.pipe.image_encoder.eval()
76
+ self.pipe.vae.eval()
77
+ self.pipe.unet.eval()
78
+ self.pipe.clip_camera_projection.eval()
79
+
80
+ self.vae = self.pipe.vae
81
+ self.unet = self.pipe.unet
82
+
83
+ self.pipe.set_progress_bar_config(disable=True)
84
+
85
+ self.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
86
+ self.num_train_timesteps = self.scheduler.config.num_train_timesteps
87
+
88
+ self.min_step = int(self.num_train_timesteps * t_range[0])
89
+ self.max_step = int(self.num_train_timesteps * t_range[1])
90
+ self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
91
+
92
+ self.embeddings = None
93
+
94
+ @torch.no_grad()
95
+ def get_img_embeds(self, x):
96
+ # x: image tensor in [0, 1]
97
+ x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False)
98
+ x_pil = [TF.to_pil_image(image) for image in x]
99
+ x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype)
100
+ c = self.pipe.image_encoder(x_clip).image_embeds
101
+ v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor
102
+ self.embeddings = [c, v]
103
+
104
+ def get_cam_embeddings(self, polar, azimuth, radius, default_elevation=0):
105
+ if self.use_stable_zero123:
106
+ T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), np.deg2rad([90 + default_elevation] * len(polar))], axis=-1)
107
+ else:
108
+ # original zero123 camera embedding
109
+ T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1)
110
+ T = torch.from_numpy(T).unsqueeze(1).to(dtype=self.dtype, device=self.device) # [8, 1, 4]
111
+ return T
112
+
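+ # A minimal sketch (an assumption for illustration, not part of the pipeline) of how the
+ # (polar, azimuth, radius) deltas fed to get_cam_embeddings could be formed from a target
+ # view and the reference view, in degrees:
+ #   d_polar   = target_polar - ref_polar
+ #   d_azimuth = target_azimuth - ref_azimuth
+ #   d_radius  = target_radius - ref_radius
+ #   T = self.get_cam_embeddings([d_polar], [d_azimuth], [d_radius])  # [1, 1, 4]
+ # For the original zero123 branch the embedding packs
+ # [deg2rad(d_polar), sin(deg2rad(d_azimuth)), cos(deg2rad(d_azimuth)), d_radius].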
113
+ def get_cam_embeddings_6D(self, target_RT, cond_RT):
114
+ T_target = torch.from_numpy(target_RT["c2w"])
115
+ focal_len_target = torch.from_numpy(target_RT["focal_length"])
116
+
117
+ T_cond = torch.from_numpy(cond_RT["c2w"])
118
+ focal_len_cond = torch.from_numpy(cond_RT["focal_length"])
119
+
120
+ focal_len = focal_len_target / focal_len_cond
121
+
122
+ d_T = torch.linalg.inv(T_target) @ T_cond
123
+ d_T = torch.cat([d_T.flatten(), torch.log(focal_len)])
124
+ return d_T.unsqueeze(0).unsqueeze(0).to(dtype=self.dtype, device=self.device)
125
+
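+ # The relative 6-DoF embedding above is 16 values (flattened 4x4 inv(T_target) @ T_cond)
+ # plus 2 log focal-length ratios = 18, matching the cc_projection input size of 768 + 18.
+ # Hedged usage sketch (dict layout mirrors get_T_6d below; names are illustrative):
+ #   target_RT = {"c2w": c2w_target_4x4, "focal_length": np.array([fx_t, fy_t])}
+ #   cond_RT   = {"c2w": c2w_cond_4x4,   "focal_length": np.array([fx_c, fy_c])}
+ #   T = self.get_cam_embeddings_6D(target_RT, cond_RT)  # [1, 1, 18]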
126
+ @torch.no_grad()
127
+ def refine(self, pred_rgb, cam_embed,
128
+ guidance_scale=5, steps=50, strength=0.8, idx=None
129
+ ):
130
+
131
+ ######## Slight modification: infer the batch size from pred_rgb when it is provided ########
132
+ if pred_rgb is not None:
133
+ batch_size = pred_rgb.shape[0]
134
+ else:
135
+ batch_size = 1
136
+
137
+ self.scheduler.set_timesteps(steps)
138
+
139
+ if strength == 0:
140
+ init_step = 0
141
+ latents = torch.randn((1, 4, 32, 32), device=self.device, dtype=self.dtype)
142
+ else:
143
+ init_step = int(steps * strength)
144
+ pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
145
+ latents = self.encode_imgs(pred_rgb_256.to(self.dtype))
146
+ latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])
147
+
148
+ T = cam_embed
149
+ if idx is not None:
150
+ cc_emb = torch.cat([self.embeddings[0][idx].repeat(batch_size, 1, 1), T], dim=-1)
151
+ else:
152
+ cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
153
+ cc_emb = self.pipe.clip_camera_projection(cc_emb)
154
+ cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)
155
+
156
+ if idx is not None:
157
+ vae_emb = self.embeddings[1][idx].repeat(batch_size, 1, 1, 1)
158
+ else:
159
+ vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1)
160
+ vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)
161
+
162
+ for i, t in enumerate(self.scheduler.timesteps[init_step:]):
163
+
164
+ x_in = torch.cat([latents] * 2)
165
+ t_in = torch.cat([t.view(1)]).to(self.device)
166
+
167
+ noise_pred = self.unet(
168
+ torch.cat([x_in, vae_emb], dim=1),
169
+ t_in.to(self.unet.dtype),
170
+ encoder_hidden_states=cc_emb,
171
+ ).sample
172
+
173
+ noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
174
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
175
+
176
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
177
+
178
+ imgs = self.decode_latents(latents) # [1, 3, 256, 256]
179
+ return imgs
180
+
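+ # refine() above is an SDEdit-style partial re-denoising: the render is encoded to latents,
+ # noised to timestep index int(steps * strength), then denoised with classifier-free
+ # guidance conditioned on the selected reference view. Hedged usage sketch (names illustrative):
+ #   cam_embed = self.get_cam_embeddings_6D(target_RT, cond_RT)
+ #   refined = self.refine(render_256, cam_embed, guidance_scale=5, strength=0.8, idx=view_idx)
+ #   # refined: [B, 3, 256, 256] in [0, 1]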
181
+ def train_step(self, pred_rgb, polar, azimuth, radius, step_ratio=None, guidance_scale=5, as_latent=False):
182
+ # pred_rgb: tensor [1, 3, H, W] in [0, 1]
183
+
184
+ batch_size = pred_rgb.shape[0]
185
+
186
+ if as_latent:
187
+ latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
188
+ else:
189
+ pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
190
+ latents = self.encode_imgs(pred_rgb_256.to(self.dtype))
191
+
192
+ if step_ratio is not None:
193
+ # dreamtime-like
194
+ # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
195
+ t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
196
+ t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
197
+ else:
198
+ t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)
199
+
200
+ w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)
201
+
202
+ with torch.no_grad():
203
+ noise = torch.randn_like(latents)
204
+ latents_noisy = self.scheduler.add_noise(latents, noise, t)
205
+
206
+ x_in = torch.cat([latents_noisy] * 2)
207
+ t_in = torch.cat([t] * 2)
208
+
209
+ T = self.get_cam_embeddings(polar, azimuth, radius)
210
+ cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
211
+ cc_emb = self.pipe.clip_camera_projection(cc_emb)
212
+ cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)
213
+
214
+ vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1)
215
+ vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)
216
+
217
+ noise_pred = self.unet(
218
+ torch.cat([x_in, vae_emb], dim=1),
219
+ t_in.to(self.unet.dtype),
220
+ encoder_hidden_states=cc_emb,
221
+ ).sample
222
+
223
+ noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
224
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
225
+
226
+ grad = w * (noise_pred - noise)
227
+ grad = torch.nan_to_num(grad)
228
+
229
+ target = (latents - grad).detach()
230
+ loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum')
231
+
232
+ return loss
233
+
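+ # Note on the loss above: with target = (latents - grad).detach(), the objective
+ # 0.5 * ||latents - target||^2 has gradient d(loss)/d(latents) = grad = w * (noise_pred - noise),
+ # so backpropagation applies exactly the score distillation (SDS) gradient without
+ # differentiating through the diffusion UNet.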
234
+ def angle_between(self, sph_v1, sph_v2):
235
+ def sph2cart(sv):
236
+ r, theta, phi = sv[0], sv[1], sv[2]
237
+ # The polar representation is different from Stable-DreamFusion
238
+ return torch.tensor([r * torch.cos(theta) * torch.cos(phi), r * torch.cos(theta) * torch.sin(phi), r * torch.sin(theta)])
239
+ def unit_vector(v):
240
+ return v / torch.linalg.norm(v)
241
+ def angle_between_2_sph(sv1, sv2):
242
+ v1, v2 = sph2cart(sv1), sph2cart(sv2)
243
+ v1_u, v2_u = unit_vector(v1), unit_vector(v2)
244
+ return torch.arccos(torch.clip(torch.dot(v1_u, v2_u), -1.0, 1.0))
245
+ angles = torch.empty(len(sph_v1), len(sph_v2))
246
+ for i, sv1 in enumerate(sph_v1):
247
+ for j, sv2 in enumerate(sph_v2):
248
+ angles[i][j] = angle_between_2_sph(sv1, sv2)
249
+ return angles
250
+
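+ # Hedged usage sketch for angle_between: each row is (radius, polar, azimuth) with the
+ # angles in radians; the result is the pairwise angle between the camera directions,
+ # used in stage 2 to pick the nearest conditioning view.
+ #   v1 = torch.tensor([[2.0, 0.0, 0.0]])                        # query view
+ #   v2 = torch.tensor([[2.0, 0.0, 0.0], [2.0, 0.0, np.pi / 2]]) # reference views
+ #   angles = self.angle_between(v1, v2)                         # [1, 2], approx. [0, pi/2]
+ #   nearest = torch.argmin(angles.squeeze()).item()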
251
+ def batch_train_step(self, pred_rgb, target_RT, cond_cams, step_ratio=None, guidance_scale=5, as_latent=False, step=None):
252
+ # pred_rgb: tensor [1, 3, H, W] in [0, 1]
253
+
254
+ batch_size = pred_rgb.shape[0]
255
+
256
+ if as_latent:
257
+ latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
258
+ else:
259
+ pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
260
+ latents = self.encode_imgs(pred_rgb_256.to(self.dtype))
261
+
262
+ if step_ratio is not None:
263
+ # dreamtime-like
264
+ # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
265
+ t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
266
+ t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
267
+ else:
268
+ t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)
269
+
270
+ w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)
271
+
272
+ with torch.no_grad():
273
+ noise = torch.randn_like(latents)
274
+ latents_noisy = self.scheduler.add_noise(latents, noise, t)
275
+
276
+ x_in = torch.cat([latents_noisy] * 2 * self.num_views)
277
+ t_in = torch.cat([t] * 2 * self.num_views)
278
+
279
+ cc_embs = []
280
+ vae_embs = []
281
+ noise_preds = []
282
+ for idx in range(self.num_views):
283
+ cond_RT = {
284
+ "c2w": cond_cams[idx].c2w,
285
+ "focal_length": cond_cams[idx].focal_length,
286
+ }
287
+ T = self.get_cam_embeddings_6D(target_RT, cond_RT)
288
+ cc_emb = torch.cat([self.embeddings[0][idx].repeat(batch_size, 1, 1), T], dim=-1)
289
+ cc_emb = self.pipe.clip_camera_projection(cc_emb)
290
+ cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)
291
+
292
+ vae_emb = self.embeddings[1][idx].repeat(batch_size, 1, 1, 1)
293
+ vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)
294
+
295
+ cc_embs.append(cc_emb)
296
+ vae_embs.append(vae_emb)
297
+
298
+ cc_emb = torch.cat(cc_embs, dim=0)
299
+ vae_emb = torch.cat(vae_embs, dim=0)
300
+ noise_pred = self.unet(
301
+ torch.cat([x_in, vae_emb], dim=1),
302
+ t_in.to(self.unet.dtype),
303
+ encoder_hidden_states=cc_emb,
304
+ ).sample
305
+
306
+ noise_pred_chunks = noise_pred.chunk(self.num_views)
307
+ for idx in range(self.num_views):
308
+ noise_pred_cond, noise_pred_uncond = noise_pred_chunks[idx][0], noise_pred_chunks[idx][1]
309
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
310
+ noise_preds.append(noise_pred)
311
+
312
+ noise_pred = torch.stack(noise_preds).mean(dim=0)  # average the guided noise predictions over all views
313
+
314
+ grad = w * (noise_pred - noise)
315
+ grad = torch.nan_to_num(grad)
316
+
317
+ target = (latents - grad).detach()
318
+ loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum')
319
+
320
+ return loss
321
+
322
+ def decode_latents(self, latents):
323
+ latents = 1 / self.vae.config.scaling_factor * latents
324
+
325
+ imgs = self.vae.decode(latents).sample
326
+ imgs = (imgs / 2 + 0.5).clamp(0, 1)
327
+
328
+ return imgs
329
+
330
+ def encode_imgs(self, imgs, mode=False):
331
+ # imgs: [B, 3, H, W]
332
+
333
+ imgs = 2 * imgs - 1
334
+
335
+ posterior = self.vae.encode(imgs).latent_dist
336
+ if mode:
337
+ latents = posterior.mode()
338
+ else:
339
+ latents = posterior.sample()
340
+ latents = latents * self.vae.config.scaling_factor
341
+
342
+ return latents
343
+
344
+
345
+ def process_im(im, bg_remover=None, device="cuda"):  # standalone helper: the rembg session and device are passed in explicitly
346
+ if im.shape[-1] == 3:
347
+ if bg_remover is None:
348
+ bg_remover = rembg.new_session()
349
+ im = rembg.remove(im, session=bg_remover)
350
+
351
+ im = im.astype(np.float32) / 255.0
352
+
353
+ input_mask = im[..., 3:]
354
+ input_img = im[..., :3] * input_mask + (1 - input_mask)
355
+ input_img = input_img[..., ::-1].copy()
356
+ image = torch.from_numpy(input_img).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)
357
+ image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
358
+
359
+ return image
360
+
361
+
362
+ def get_T_6d(target_RT, cond_RT, use_objaverse):
363
+ if use_objaverse:
364
+ new_row = torch.tensor([[0., 0., 0., 1.]])
365
+
366
+ T_target = torch.from_numpy(target_RT) # world to cam matrix
367
+ T_target = torch.cat((T_target, new_row), dim=0)
368
+ T_target = torch.linalg.inv(T_target) # Cam to world matrix
369
+ T_target[:3, :] = T_target[[1, 2, 0]]
370
+
371
+ T_cond = torch.from_numpy(cond_RT)
372
+ T_cond = torch.cat((T_cond, new_row), dim=0)
373
+ T_cond = torch.linalg.inv(T_cond)
374
+ T_cond[:3, :] = T_cond[[1, 2, 0]]
375
+
376
+ focal_len = torch.tensor([1., 1.])
377
+
378
+ else:
379
+ T_target = torch.from_numpy(target_RT["c2w"])
380
+ focal_len_target = torch.from_numpy(target_RT["focal_length"])
381
+
382
+ T_cond = torch.from_numpy(cond_RT["c2w"])
383
+ focal_len_cond = torch.from_numpy(cond_RT["focal_length"])
384
+
385
+ focal_len = focal_len_target / focal_len_cond
386
+
387
+ d_T = torch.linalg.inv(T_target) @ T_cond
388
+ d_T = torch.cat([d_T.flatten(), torch.log(focal_len)])
389
+ return d_T
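+ # Hedged usage sketch for get_T_6d (names illustrative): with use_objaverse=True the
+ # inputs are the 3x4 world-to-camera matrices stored with the Objaverse renders;
+ # otherwise they are dicts like the ones built in main_stage2:
+ #   d_T = get_T_6d({"c2w": c2w_t, "focal_length": np.array([fx_t, fy_t])},
+ #                  {"c2w": c2w_c, "focal_length": np.array([fx_c, fy_c])},
+ #                  use_objaverse=False)  # shape [18]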
sparseags/main_stage1.py ADDED
@@ -0,0 +1,669 @@
1
+ import os
2
+ import cv2
3
+ import sys
4
+ import json
5
+ import time
6
+ import tqdm
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+
13
+ import rembg
14
+ from liegroups.torch import SE3
15
+
16
+ import sys
17
+ sys.path.append('./')
18
+
19
+ from sparseags.cam_utils import orbit_camera, OrbitCamera, mat2latlon, find_mask_center_and_translate
20
+ from sparseags.render_utils.gs_renderer import Renderer, Camera, FoVCamera, CustomCamera
21
+ from sparseags.mesh_utils.grid_put import mipmap_linear_grid_put_2d
22
+ from sparseags.mesh_utils.mesh import Mesh, safe_normalize
23
+
24
+
25
+ class GUI:
26
+ def __init__(self, opt):
27
+ self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
28
+ self.gui = opt.gui # enable gui
29
+ self.W = opt.W
30
+ self.H = opt.H
31
+
32
+ self.mode = "image"
33
+ self.seed = 0
34
+
35
+ self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
36
+ self.need_update = True # update buffer_image
37
+
38
+ # models
39
+ self.device = torch.device("cuda")
40
+ self.bg_remover = None
41
+
42
+ self.guidance_sd = None
43
+ self.guidance_zero123 = None
44
+ self.guidance_dino = None
45
+
46
+ self.enable_sd = False
47
+ self.enable_zero123 = False
48
+ self.enable_dino = False
49
+
50
+ # renderer
51
+ self.renderer = Renderer(sh_degree=self.opt.sh_degree)
52
+ self.renderer.enable_dino = self.opt.lambda_dino > 0
53
+ self.renderer.gaussians.enable_dino = self.opt.lambda_dino > 0
54
+ self.renderer.gaussians.dino_feat_dim = 36
55
+ self.gaussain_scale_factor = 1
56
+
57
+ # input image
58
+ self.input_img = None
59
+ self.input_mask = None
60
+ self.input_img_torch = None
61
+ self.input_mask_torch = None
62
+
63
+ # training stuff
64
+ self.training = False
65
+ self.optimizer = None
66
+ self.step = 0
67
+ self.train_steps = 1 # steps per rendering loop
68
+
69
+ # load input data
70
+ self.load_input(self.opt.camera_path, self.opt.order_path)
71
+
72
+ self.cam = OrbitCamera(opt.W, opt.H, r=3, fovy=opt.fovy)
73
+
74
+ # override if provide a checkpoint
75
+ if self.opt.load is not None:
76
+ self.renderer.initialize(self.opt.load)
77
+ else:
78
+ # initialize gaussians to a blob
79
+ self.renderer.initialize(num_pts=self.opt.num_pts, radius=0.3, mode='sphere') # 0.5 for radius 3
80
+
81
+ # initialize gaussians to a carved voxel
82
+ # self.renderer.initialize(num_pts=self.opt.num_pts, radius=0.5, cameras=self.cams, masks=self.input_mask, mode='carve') # 0.5
83
+
84
+ def seed_everything(self):
85
+ try:
86
+ seed = int(self.seed)
87
+ except:
88
+ seed = np.random.randint(0, 1000000)
89
+
90
+ os.environ["PYTHONHASHSEED"] = str(seed)
91
+ np.random.seed(seed)
92
+ torch.manual_seed(seed)
93
+ torch.cuda.manual_seed(seed)
94
+ torch.backends.cudnn.deterministic = True
95
+ torch.backends.cudnn.benchmark = True
96
+
97
+ self.last_seed = seed
98
+
99
+ def prepare_train(self):
100
+ self.step = 0
101
+
102
+ # setup training
103
+ self.renderer.gaussians.training_setup(self.opt)
104
+ # do progressive sh-level
105
+ self.renderer.gaussians.active_sh_degree = 0
106
+ self.optimizer = self.renderer.gaussians.optimizer
107
+
108
+ self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != ""
109
+ self.enable_zero123 = self.opt.lambda_zero123 > 0 and self.input_img is not None
110
+ self.enable_dino = self.opt.lambda_dino > 0
111
+
112
+ # lazy load guidance model
113
+ if self.guidance_zero123 is None and self.enable_zero123:
114
+ print(f"[INFO] loading zero123...")
115
+ from sparseags.guidance_utils.zero123_6d_utils import Zero123
116
+ self.guidance_zero123 = Zero123(self.device, model_key='ashawkey/zero123-xl-diffusers')
117
+ print(f"[INFO] loaded zero123!")
118
+ self.guidance_zero123.opt = self.opt
119
+ self.guidance_zero123.num_views = self.num_views
120
+
121
+ # input image
122
+ if self.input_img is not None:
123
+ import torchvision.transforms as transforms
124
+ from PIL import Image
125
+ self.input_img_torch = torch.from_numpy(self.input_img).permute(0, 3, 1, 2).to(self.device)
126
+ self.input_mask_torch = torch.from_numpy(self.input_mask).permute(0, 3, 1, 2).to(self.device)
127
+
128
+ # prepare embeddings
129
+ with torch.no_grad():
130
+ if self.enable_zero123:
131
+ self.guidance_zero123.get_img_embeds(self.input_img_torch)
132
+
133
+ def train_step(self):
134
+ starter = torch.cuda.Event(enable_timing=True)
135
+ ender = torch.cuda.Event(enable_timing=True)
136
+ starter.record()
137
+
138
+ for _ in range(self.train_steps):
139
+
140
+ self.step += 1
141
+ step_ratio = min(1, self.step / self.opt.iters)
142
+
143
+ # update lr
144
+ self.renderer.gaussians.update_learning_rate(self.step)
145
+
146
+ loss = 0
147
+
148
+ ### known view
149
+ for choice in range(self.num_views):
150
+ # For multiview training
151
+ cur_cam = self.cams[choice]
152
+
153
+ bg_size = self.renderer.gaussians.dino_feat_dim if self.enable_dino else 3
154
+ bg_color = torch.ones(
155
+ bg_size,
156
+ dtype=torch.float32,
157
+ device="cuda",
158
+ )
159
+ out = self.renderer.render(cur_cam, bg_color=bg_color)
160
+
161
+ # rgb loss
162
+ image = out["image"]
163
+ loss = loss + 10000 * step_ratio * F.mse_loss(image, self.input_img_torch[choice])
164
+
165
+ # mask loss
166
+ mask = out["alpha"]
167
+ loss = loss + 1000 * step_ratio * F.mse_loss(mask, self.input_mask_torch[choice])
168
+
169
+ # dino loss
170
+ if self.enable_dino:
171
+ feature = out["feature"]
172
+ loss = loss + 1000 * step_ratio * F.mse_loss(feature, self.guidance_dino.embeddings[choice])
173
+
174
+ ### novel view (manual batch)
175
+ render_resolution = 128 if step_ratio < 0.3 else (256 if step_ratio < 0.6 else 512)
176
+ images = []
177
+ masks = []
178
+ vers, hors, radii = [], [], []
179
+ # avoid too large elevation (> 80 or < -80)
180
+ min_ver = max(-60 + np.array(self.opt.ref_polars).min(), -80)  # use +/- 30 instead of +/- 60 for CO3D
181
+ max_ver = min(60 + np.array(self.opt.ref_polars).max(), 80)
182
+
183
+ for _ in range(self.opt.batch_size):
184
+ # render random view
185
+ ver = np.random.randint(min_ver, max_ver) - self.opt.ref_polars[0]
186
+ hor = np.random.randint(-180, 180)
187
+ radius = 0
188
+
189
+ vers.append(ver)
190
+ hors.append(hor)
191
+ radii.append(radius)
192
+
193
+ pose = orbit_camera(
194
+ self.opt.ref_polars[0] + ver,
195
+ self.opt.ref_azimuths[0] + hor,
196
+ np.array(self.opt.ref_radii).mean() + radius,
197
+ )
198
+
199
+ # Azimuth
200
+ # [-180, -135): -4, [-135, -90): -3, [-90, -45): -2, [-45, 0): -1
201
+ # [0, 45): 0, [45, 90): 1, [90, 135): 2, [135, 180): 3.
202
+ # Elevation: [0, 90): 0 [-90, 0): 1
203
+ idx_ver, idx_hor = int((self.opt.ref_polars[0]+ver) < 0), hor // 45
204
+
205
+ flag = 0
206
+ cx, cy = self.pp_pools[idx_ver, idx_hor+4].tolist()
207
+ cnt = 0
208
+ fx, fy = self.fx, self.fy
209
+
210
+ # in each iter we modify cx, cy, fx, fy to make sure the rendered object is at the center and has a reasonable size
211
+ while not flag:
212
+
213
+ if cnt >= 10:
214
+ # print(f"[ERROR] Something might be wrong here!")
215
+ break
216
+
217
+ flag_principal_point, flag_focal_length = 0, 0
218
+
219
+ # we modified the field of view. Otherwise, the rendered object will be too small
220
+ # cur_cam = FoVCamera(pose, render_resolution, render_resolution, self.fovy, self.fovx, self.cam.near, self.cam.far)
221
+ cur_cam = Camera(pose, render_resolution, render_resolution, fx, fy, cx, cy, self.cam.near, self.cam.far)
222
+
223
+ bg_size = self.renderer.gaussians.dino_feat_dim if self.enable_dino else 3
224
+ bg_color = torch.ones(bg_size, dtype=torch.float32, device="cuda") if np.random.rand() > self.opt.invert_bg_prob else torch.zeros(bg_size, dtype=torch.float32, device="cuda")
225
+ out = self.renderer.render(cur_cam, bg_color=bg_color)
226
+
227
+ image = out["image"].unsqueeze(0)
228
+ mask = out["alpha"].unsqueeze(0)
229
+ delta_xy = find_mask_center_and_translate(image.detach(), mask.detach()) / render_resolution * 256
230
+
231
+ # (1) check if the principal points are appropriate
232
+ if delta_xy[0].abs() > 10 or delta_xy[1].abs() > 10:
233
+ cx -= delta_xy[0]
234
+ cy -= delta_xy[1]
235
+ self.pp_pools[idx_ver, idx_hor+4] = torch.tensor([cx, cy]) # Update pp_pools
236
+ else:
237
+ flag_principal_point = 1
238
+
239
+ num_pixs_mask = (mask > 0.5).float().sum().item()
240
+ target_num_pixs = render_resolution ** 2 / (1.2 ** 2)
241
+
242
+ mask_to_compute = (mask > 0.5).squeeze().detach().cpu().numpy()
243
+ y_indices, x_indices = np.where(mask_to_compute > 0)
244
+
245
+ if len(x_indices) == 0 or len(y_indices) == 0:
246
+ # empty mask: no object visible in this render; skip this centering attempt
247
+ continue
248
+
249
+ # find the bounding box coordinates
250
+ x1, y1 = np.min(x_indices), np.min(y_indices)
251
+ x2, y2 = np.max(x_indices), np.max(y_indices)
252
+
253
+ bbox = np.array([x1, y1, x2, y2])
254
+ extents = (bbox[2:] - bbox[:2]).max()
255
+ num_pixs_mask = extents ** 2
256
+
257
+ # (2) check if the focal lengths are appropriate
258
+ if abs(num_pixs_mask - target_num_pixs) > 0.05 * render_resolution ** 2:
259
+ if num_pixs_mask == 0:
260
+ pass
261
+ else:
262
+ fx = fx * np.sqrt(target_num_pixs / num_pixs_mask)
263
+ fy = fy * np.sqrt(target_num_pixs / num_pixs_mask)
264
+ else:
265
+ flag_focal_length = 1
266
+
267
+ if flag_principal_point * flag_focal_length == 1:
268
+ flag = 1
269
+
270
+ cnt += 1
271
+
272
+ images.append(image)
273
+ masks.append(mask)
274
+
275
+ images = torch.cat(images, dim=0)
276
+
277
+ if self.enable_zero123:
278
+ target_RT = {
279
+ "c2w": pose,
280
+ "focal_length": np.array(fx, fy),
281
+ }
282
+ loss = loss + self.opt.lambda_zero123 * self.guidance_zero123.batch_train_step(images, target_RT, self.cams, step_ratio=step_ratio if self.opt.anneal_timestep else None)
283
+
284
+ if self.enable_dino:
285
+ loss_dino = self.guidance_dino.train_step(
286
+ images,
287
+ out["feature"],
288
+ step_ratio=step_ratio if self.opt.anneal_timestep else None
289
+ )
290
+ loss = loss + self.opt.lambda_dino * loss_dino
291
+
292
+ # optimize step
293
+ loss.backward()
294
+ self.optimizer.step()
295
+ self.optimizer.zero_grad()
296
+ latlons = [mat2latlon(cam.c2w[:3, 3]) for cam in self.cams]
297
+ if self.opt.opt_cam:
298
+ for i, cam in enumerate(self.cams):
299
+ w2c = cam.w2c @ SE3.exp(cam.cam_params.detach()).as_matrix()
300
+ w2c[:2, :3] *= -1
301
+ w2c[:2, 3] *= -1
302
+ self.camera_tracks[i].append(w2c.tolist())
303
+ self.opt.ref_polars = [float(cam[0]) for cam in latlons]
304
+ self.opt.ref_azimuths = [float(cam[1]) for cam in latlons]
305
+ self.opt.ref_radii = [float(cam[2]) for cam in latlons]
306
+
307
+ # densify and prune
308
+ if self.step >= self.opt.density_start_iter and self.step <= self.opt.density_end_iter:
309
+ viewspace_point_tensor, visibility_filter, radii = out["viewspace_points"], out["visibility_filter"], out["radii"]
310
+ self.renderer.gaussians.max_radii2D[visibility_filter] = torch.max(self.renderer.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
311
+ self.renderer.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
312
+
313
+ if self.step % self.opt.densification_interval == 0:
314
+ self.renderer.gaussians.densify_and_prune(self.opt.densify_grad_threshold, min_opacity=0.01, extent=4, max_screen_size=1)
315
+
316
+ # if self.step % self.opt.opacity_reset_interval == 0:
317
+ # self.renderer.gaussians.reset_opacity()
318
+
319
+ if self.step % 100 == 0 and self.renderer.gaussians.max_sh_degree != 0:
320
+ self.renderer.gaussians.oneupSHdegree()
321
+
322
+ ender.record()
323
+ torch.cuda.synchronize()
324
+ t = starter.elapsed_time(ender)
325
+
326
+ self.need_update = True
327
+
328
+ def load_input(self, camera_path, order_path=None):
329
+ # load image
330
+ print(f'[INFO] load data from {camera_path}...')
331
+
332
+ if order_path is not None:
333
+ with open(order_path, 'r') as f:
334
+ indices = json.load(f)
335
+ else:
336
+ indices = None
337
+
338
+ with open(camera_path, 'r') as f:
339
+ data = json.load(f)
340
+
341
+ self.cam_params = {}
342
+ for k, v in data.items():
343
+ if indices is None:
344
+ self.cam_params[k] = data[k]
345
+ else:
346
+ if int(k) in indices or k in indices:
347
+ self.cam_params[k] = data[k]
348
+
349
+ if self.opt.all_views:
350
+ for k, v in self.cam_params.items():
351
+ self.cam_params[k]['opt_cam'] = 1
352
+ self.cam_params[k]['flag'] = 1
353
+ else:
354
+ for k, v in self.cam_params.items():
355
+ if int(self.cam_params[k]['flag']):
356
+ self.cam_params[k]['opt_cam'] = 1
357
+ else:
358
+ self.cam_params[k]['opt_cam'] = 0
359
+
360
+ img_paths = [v["filepath"] for k, v in self.cam_params.items() if v["flag"]]
361
+ self.num_views = len(img_paths)
362
+ print(f"[INFO] Number of views: {self.num_views}")
363
+
364
+ for filepath in img_paths:
365
+ print(filepath)
366
+
367
+ images, masks = [], []
368
+
369
+ for i in range(self.num_views):
370
+ img = cv2.imread(img_paths[i], cv2.IMREAD_UNCHANGED)
371
+ if img.shape[-1] == 3:
372
+ if self.bg_remover is None:
373
+ self.bg_remover = rembg.new_session()
374
+ img = rembg.remove(img, session=self.bg_remover)
375
+
376
+ img = img.astype(np.float32) / 255.0
377
+
378
+ # Non-integer cropping creates non-zero mask values
379
+ input_mask = (img[..., 3:] > 0.5).astype(np.float32)
380
+
381
+ # white bg
382
+ input_img = img[..., :3] * input_mask + (1 - input_mask)
383
+ # bgr to rgb
384
+ input_img = input_img[..., ::-1].copy()
385
+
386
+ images.append(input_img); masks.append(input_mask)
387
+
388
+ images = np.stack(images, axis=0)
389
+ masks = np.stack(masks, axis=0)
390
+ self.input_img = images[:self.num_views]
391
+ self.input_mask = masks[:self.num_views]
392
+ self.all_input_images = images
393
+
394
+ self.cams = [CustomCamera(v, index=int(k), opt_pose=self.opt.opt_cam and v['opt_cam']) for k, v in self.cam_params.items() if v["flag"]]
395
+ cam_centers = [mat2latlon(cam.camera_center) for cam in self.cams]
396
+ self.opt.ref_polars = [float(cam[0]) for cam in cam_centers]
397
+ self.opt.ref_azimuths = [float(cam[1]) for cam in cam_centers]
398
+ self.opt.ref_radii = [float(cam[2]) for cam in cam_centers]
399
+ self.fx = np.array([cam.fx for cam in self.cams], dtype=np.float32).mean()
400
+ self.fy = np.array([cam.fy for cam in self.cams], dtype=np.float32).mean()
401
+ self.cx = 128
402
+ self.cy = 128
403
+ if self.opt.opt_cam:
404
+ self.camera_tracks = {}
405
+ for i, cam in enumerate(self.cams):
406
+ self.camera_tracks[i] = []
407
+
408
+ # Azimuth Mapping: [-180, -135): -4, [-135, -90): -3, [-90, -45): -2, [-45, 0): -1,
409
+ # [0, 45): 0, [45, 90): 1, [90, 135): 2, [135, 180): 3.
410
+ # Elevation Mapping: [0, 90): 0, [-90, 0): 1.
411
+
412
+ # Principal Point Pool: Tensor (2, 8, 2), where:
413
+ # - 2: Elevation groups, 8: Azimuth intervals, 2: x, y coordinates (init to 128).
414
+
415
+ # we created a "pool" for principal points
416
+ # we use these principal points to render image to make sure object is at the center
417
+ self.pp_pools = torch.full((2, 8, 2), 128)
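+ # Index sketch for the pool above (this mirrors the lookup in train_step): for a sampled
+ # view at elevation ref_polars[0] + ver and azimuth offset hor (degrees in [-180, 180)):
+ #   idx_ver = int((self.opt.ref_polars[0] + ver) < 0)  # 0: elevation >= 0, 1: elevation < 0
+ #   idx_hor = hor // 45                                # -4 .. 3
+ #   cx, cy = self.pp_pools[idx_ver, idx_hor + 4].tolist()
+ # so each of the 2 x 8 buckets keeps its own running estimate of the principal point.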
418
+ if self.opt.opt_cam:
419
+ self.renderer.gaussians.cam_params = [cam.cam_params for cam in self.cams[:] if cam.opt_pose]
420
+
421
+ @torch.no_grad()
422
+ def save_video(self, post_fix=None):
423
+ xyz = self.renderer.gaussians._xyz
424
+ center = self.renderer.gaussians._xyz.mean(dim=0)
425
+ squared_distances = torch.sum((xyz - center) ** 2, dim=1)
426
+ max_distance_squared = torch.max(squared_distances)
427
+ radius = torch.sqrt(max_distance_squared) + 1.0
428
+ radius = radius.detach().cpu().numpy()
429
+
430
+ render_resolution = 256
431
+ images = []
432
+ frame_rate = 30
433
+ image_size = (render_resolution, render_resolution) # Size of each image
434
+ video_path = self.opt.save_path + f'_rendered_video_{post_fix}.mp4'
435
+
436
+ azimuth = np.arange(0, 360, 3, dtype=np.int32)
437
+
438
+ for azi in tqdm.tqdm(azimuth):
439
+ target = center.detach().cpu().numpy()
440
+ pose = orbit_camera(-30, azi, radius, target=target)
441
+ cur_cam = FoVCamera(
442
+ pose,
443
+ render_resolution,
444
+ render_resolution,
445
+ self.cam.fovy,
446
+ self.cam.fovx,
447
+ self.cam.near,
448
+ self.cam.far,
449
+ )
450
+
451
+ out = self.renderer.render(cur_cam)
452
+ img = out["image"].detach().cpu().numpy() # [3, H, W] in [0, 1]
453
+ img = np.transpose(img, (1, 2, 0))
454
+ image = (img * 255).astype(np.uint8)
455
+ images.append(image)
456
+
457
+ images = np.stack(images, axis=0)
458
+ # ~4 seconds, 120 frames at 30 fps
459
+ import imageio
460
+ imageio.mimwrite(video_path, images, fps=30, quality=8, macro_block_size=1)
461
+
462
+
463
+ @torch.no_grad()
464
+ def save_model(self, mode='geo', texture_size=1024):
465
+ os.makedirs(self.opt.outdir, exist_ok=True)
466
+ if mode == 'geo':
467
+ path = os.path.join(self.opt.outdir, self.opt.save_path + '_mesh.ply')
468
+ mesh = self.renderer.gaussians.extract_mesh(path, self.opt.density_thresh)
469
+ mesh.write_ply(path)
470
+
471
+ elif mode == 'geo+tex':
472
+ path = os.path.join(self.opt.outdir, self.opt.save_path + '_mesh.' + self.opt.mesh_format)
473
+ mesh = self.renderer.gaussians.extract_mesh(path, self.opt.density_thresh)
474
+
475
+ # perform texture extraction
476
+ print(f"[INFO] unwrap uv...")
477
+ h = w = texture_size
478
+ mesh.auto_uv()
479
+ mesh.auto_normal()
480
+
481
+ albedo = torch.zeros((h, w, 3), device=self.device, dtype=torch.float32)
482
+ cnt = torch.zeros((h, w, 1), device=self.device, dtype=torch.float32)
483
+ if self.enable_dino:
484
+ feature = torch.zeros((h, w, self.renderer.gaussians.dino_feat_dim), device=self.device, dtype=torch.float32)
485
+
486
+ # self.prepare_train() # tmp fix for not loading 0123
487
+ # vers = [0]
488
+ # hors = [0]
489
+ vers = [0] * 8 + [-45] * 8 + [45] * 8 + [-89.9, 89.9]
490
+ hors = [0, 45, -45, 90, -90, 135, -135, 180] * 3 + [0, 0]
491
+
492
+ render_resolution = 512
493
+
494
+ import nvdiffrast.torch as dr
495
+
496
+ if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
497
+ glctx = dr.RasterizeGLContext()
498
+ else:
499
+ glctx = dr.RasterizeCudaContext()
500
+
501
+ for ver, hor in zip(vers, hors):
502
+ # render image
503
+ pose = orbit_camera(ver, hor, self.cam.radius)
504
+
505
+ cur_cam = FoVCamera(
506
+ pose,
507
+ render_resolution,
508
+ render_resolution,
509
+ self.cam.fovy,
510
+ self.cam.fovx,
511
+ self.cam.near,
512
+ self.cam.far,
513
+ )
514
+
515
+ cur_out = self.renderer.render(cur_cam)
516
+
517
+ rgbs = cur_out["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1]
518
+ if self.enable_dino:
519
+ features = cur_out["feature"].unsqueeze(0) # [1, 384, 512, 512]
520
+
521
+ # enhance texture quality with zero123 [not working well]
522
+ # if self.opt.guidance_model == 'zero123':
523
+ # rgbs = self.guidance.refine(rgbs, [ver], [hor], [0])
524
+ # import kiui
525
+ # kiui.vis.plot_image(rgbs)
526
+
527
+ # get coordinate in texture image
528
+ pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)
529
+ proj = torch.from_numpy(self.cam.perspective.astype(np.float32)).to(self.device)
530
+
531
+ v_cam = torch.matmul(F.pad(mesh.v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
532
+ v_clip = v_cam @ proj.T
533
+ rast, rast_db = dr.rasterize(glctx, v_clip, mesh.f, (render_resolution, render_resolution))
534
+
535
+ depth, _ = dr.interpolate(-v_cam[..., [2]], rast, mesh.f) # [1, H, W, 1]
536
+ depth = depth.squeeze(0) # [H, W, 1]
537
+
538
+ alpha = (rast[0, ..., 3:] > 0).float()
539
+
540
+ uvs, _ = dr.interpolate(mesh.vt.unsqueeze(0), rast, mesh.ft) # [1, 512, 512, 2] in [0, 1]
541
+
542
+ # use normal to produce a back-project mask
543
+ normal, _ = dr.interpolate(mesh.vn.unsqueeze(0).contiguous(), rast, mesh.fn)
544
+ normal = safe_normalize(normal[0])
545
+
546
+ # rotated normal (where [0, 0, 1] always faces camera)
547
+ rot_normal = normal @ pose[:3, :3]
548
+ viewcos = rot_normal[..., [2]]
549
+
550
+ mask = (alpha > 0) & (viewcos > 0.5) # [H, W, 1]
551
+ mask = mask.view(-1)
552
+
553
+ uvs = uvs.view(-1, 2).clamp(0, 1)[mask]
554
+ rgbs = rgbs.view(3, -1).permute(1, 0)[mask].contiguous()
555
+
556
+ # update texture image
557
+ cur_albedo, cur_cnt = mipmap_linear_grid_put_2d(
558
+ h, w,
559
+ uvs[..., [1, 0]] * 2 - 1,
560
+ rgbs,
561
+ min_resolution=256,
562
+ return_count=True,
563
+ )
564
+
565
+ if self.enable_dino:
566
+ features = features.view(features.shape[1], -1).permute(1, 0)[mask].contiguous()
567
+ cur_feature, _ = mipmap_linear_grid_put_2d(
568
+ h, w,
569
+ uvs[..., [1, 0]] * 2 - 1,
570
+ features,
571
+ min_resolution=256,
572
+ return_count=True,
573
+ )
574
+
575
+ # albedo += cur_albedo
576
+ # cnt += cur_cnt
577
+ mask = cnt.squeeze(-1) < 0.1
578
+ albedo[mask] += cur_albedo[mask]
579
+ cnt[mask] += cur_cnt[mask]
580
+
581
+ if self.enable_dino:
582
+ feature[mask] += cur_feature[mask]
583
+
584
+ mask = cnt.squeeze(-1) > 0
585
+ albedo[mask] = albedo[mask] / cnt[mask].repeat(1, 3)
586
+
587
+ if self.enable_dino:
588
+ feature[mask] = feature[mask] / cnt[mask].repeat(1, feature.shape[-1])
589
+
590
+ mask = mask.view(h, w)
591
+
592
+ albedo = albedo.detach().cpu().numpy()
593
+ mask = mask.detach().cpu().numpy()
594
+
595
+ if self.enable_dino:
596
+ feature = feature.detach().cpu().numpy()
597
+
598
+ # dilate texture
599
+ from sklearn.neighbors import NearestNeighbors
600
+ from scipy.ndimage import binary_dilation, binary_erosion
601
+
602
+ inpaint_region = binary_dilation(mask, iterations=32)
603
+ inpaint_region[mask] = 0
604
+
605
+ search_region = mask.copy()
606
+ not_search_region = binary_erosion(search_region, iterations=3)
607
+ search_region[not_search_region] = 0
608
+
609
+ search_coords = np.stack(np.nonzero(search_region), axis=-1)
610
+ inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)
611
+
612
+ knn = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(
613
+ search_coords
614
+ )
615
+ _, indices = knn.kneighbors(inpaint_coords)
616
+
617
+ albedo[tuple(inpaint_coords.T)] = albedo[tuple(search_coords[indices[:, 0]].T)]
618
+
619
+ mesh.albedo = torch.from_numpy(albedo).to(self.device)
620
+ # mesh.write(path)
621
+
622
+ if self.enable_dino:
623
+ feature[tuple(inpaint_coords.T)] = feature[tuple(search_coords[indices[:, 0]].T)]
624
+ mesh.feature = torch.from_numpy(feature).to(self.device)
625
+
626
+ mesh.write(path, self.enable_dino)
627
+
628
+ else:
629
+ path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply')
630
+ self.renderer.gaussians.save_ply(path)
631
+
632
+ print(f"[INFO] save model to {path}.")
633
+
634
+ # no gui mode
635
+ def train(self, iters=500):
636
+ if iters > 0:
637
+ self.prepare_train()
638
+ for i in tqdm.trange(iters):
639
+ self.train_step()
640
+ # do a last prune
641
+ self.renderer.gaussians.prune(min_opacity=0.01, extent=1, max_screen_size=1)
642
+ if self.opt.opt_cam:
643
+ for cam in self.cams:
644
+ try:
645
+ self.cam_params[str(cam.index)]["R"] = cam.rotation.tolist()
646
+ self.cam_params[str(cam.index)]["T"] = cam.translation.tolist()
647
+ except KeyError:
648
+ self.cam_params[f"{cam.index:03}"]["R"] = cam.rotation.tolist()
649
+ self.cam_params[f"{cam.index:03}"]["T"] = cam.translation.tolist()
650
+ with open(self.opt.camera_path.replace(".json", "_updated.json"), "w") as file:
651
+ json.dump(self.cam_params, file, indent=4)
652
+ self.save_model(mode='model')
653
+ self.save_model(mode='geo+tex')
654
+
655
+
656
+ if __name__ == "__main__":
657
+ import argparse
658
+ from omegaconf import OmegaConf
659
+
660
+ parser = argparse.ArgumentParser()
661
+ parser.add_argument("--config", required=True, help="path to the yaml config file")
662
+ args, extras = parser.parse_known_args()
663
+
664
+ # override default config from cli
665
+ opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
666
+
667
+ gui = GUI(opt)
668
+
669
+ gui.train(opt.iters)
sparseags/main_stage2.py ADDED
@@ -0,0 +1,410 @@
1
+ import os
2
+ import cv2
3
+ import json
4
+ import time
5
+ import copy
6
+ import tqdm
7
+ import rembg
8
+ import trimesh
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import numpy as np
12
+ import pandas as pd
13
+
14
+ from kiui.lpips import LPIPS
15
+
16
+ import sys
17
+ sys.path.append('./')
18
+
19
+ from sparseags.cam_utils import orbit_camera, OrbitCamera, mat2latlon, find_mask_center_and_translate
20
+ from sparseags.render_utils.gs_renderer import CustomCamera
21
+ from sparseags.mesh_utils.mesh_renderer import Renderer
22
+
23
+
24
+ class GUI:
25
+ def __init__(self, opt):
26
+ self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
27
+ self.gui = opt.gui # enable gui
28
+ self.W = opt.W
29
+ self.H = opt.H
30
+
31
+ self.mode = "image"
32
+ self.seed = 0
33
+
34
+ self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
35
+ self.need_update = True # update buffer_image
36
+
37
+ # models
38
+ self.device = torch.device("cuda")
39
+ self.bg_remover = None
40
+
41
+ self.guidance_sd = None
42
+ self.guidance_zero123 = None
43
+ self.guidance_dino = None
44
+
45
+ self.enable_sd = False
46
+ self.enable_zero123 = False
47
+ self.enable_dino = False
48
+
49
+ # renderer
50
+ self.renderer = Renderer(opt).to(self.device)
51
+
52
+ # input image
53
+ self.input_img = None
54
+ self.input_mask = None
55
+ self.input_img_torch = None
56
+ self.input_mask_torch = None
57
+ self.overlay_input_img = False
58
+ self.overlay_input_img_ratio = 0.5
59
+
60
+ # input text
61
+ self.prompt = ""
62
+ self.negative_prompt = ""
63
+
64
+ # training stuff
65
+ self.training = False
66
+ self.optimizer = None
67
+ self.step = 0
68
+ self.train_steps = 1 # steps per rendering loop
69
+
70
+ # load input data
71
+ self.load_input(self.opt.camera_path, self.opt.order_path)
72
+
73
+ # override prompt from cmdline
74
+ if self.opt.prompt is not None:
75
+ self.prompt = self.opt.prompt
76
+ if self.opt.negative_prompt is not None:
77
+ self.negative_prompt = self.opt.negative_prompt
78
+
79
+ def seed_everything(self):
80
+ try:
81
+ seed = int(self.seed)
82
+ except:
83
+ seed = np.random.randint(0, 1000000)
84
+
85
+ os.environ["PYTHONHASHSEED"] = str(seed)
86
+ np.random.seed(seed)
87
+ torch.manual_seed(seed)
88
+ torch.cuda.manual_seed(seed)
89
+ torch.backends.cudnn.deterministic = True
90
+ torch.backends.cudnn.benchmark = True
91
+
92
+ self.last_seed = seed
93
+
94
+ def prepare_train(self):
95
+
96
+ self.step = 0
97
+
98
+ # setup training
99
+ self.optimizer = torch.optim.Adam(self.renderer.get_params())
100
+
101
+ cameras = [CustomCamera(v, index=int(k)) for k, v in self.cam_params.items() if v["flag"]]
102
+ cam_centers = [mat2latlon(cam.camera_center) for cam in cameras]
103
+ self.opt.ref_polars = [float(cam[0]) for cam in cam_centers]
104
+ self.opt.ref_azimuths = [float(cam[1]) for cam in cam_centers]
105
+ self.opt.ref_radii = [float(cam[2]) for cam in cam_centers]
106
+ self.cams = [(cam.c2w, cam.perspective, cam.focal_length) for cam in cameras]
107
+ self.cam = copy.deepcopy(cameras[0])
108
+
109
+ # Azimuth Mapping: [-180, -135): -4, [-135, -90): -3, [-90, -45): -2, [-45, 0): -1,
110
+ # [0, 45): 0, [45, 90): 1, [90, 135): 2, [135, 180): 3.
111
+ # Elevation Mapping: [0, 90): 0, [-90, 0): 1.
112
+
113
+ # Principal Point Pool: Tensor (2, 8, 2), where:
114
+ # - 2: Elevation groups, 8: Azimuth intervals, 2: x, y coordinates (init to 128).
115
+
116
+ # we created a "pool" for principal points
117
+ # we use these principal points to render image to make sure object is at the center
118
+ self.pp_pools = torch.full((2, 8, 2), 128)
119
+
120
+ # The intrinsics is the average over all cams
121
+ self.cam.fx = np.array([cam.fx for cam in cameras], dtype=np.float32).mean()
122
+ self.cam.fy = np.array([cam.fy for cam in cameras], dtype=np.float32).mean()
123
+ self.cam.cx = np.array([cam.cx for cam in cameras], dtype=np.float32).mean()
124
+ self.cam.cy = np.array([cam.cy for cam in cameras], dtype=np.float32).mean()
125
+
126
+ self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != ""
127
+ self.enable_zero123 = self.opt.lambda_zero123 > 0 and self.input_img is not None
128
+ self.enable_dino = self.opt.lambda_dino > 0
129
+
130
+ # lazy load guidance model
131
+ if self.guidance_sd is None and self.enable_sd:
132
+ if self.opt.mvdream:
133
+ print(f"[INFO] loading MVDream...")
134
+ from guidance.mvdream_utils import MVDream
135
+ self.guidance_sd = MVDream(self.device)
136
+ print(f"[INFO] loaded MVDream!")
137
+ else:
138
+ print(f"[INFO] loading SD...")
139
+ from guidance.sd_utils import StableDiffusion
140
+ self.guidance_sd = StableDiffusion(self.device)
141
+ print(f"[INFO] loaded SD!")
142
+
143
+ if self.guidance_zero123 is None and self.enable_zero123:
144
+ print(f"[INFO] loading zero123...")
145
+ from sparseags.guidance_utils.zero123_6d_utils import Zero123
146
+ self.guidance_zero123 = Zero123(self.device, model_key='ashawkey/zero123-xl-diffusers')
147
+ print(f"[INFO] loaded zero123!")
148
+
149
+ if self.guidance_dino is None and self.enable_dino:
150
+ print(f"[INFO] loading dino...")
151
+ from guidance.dino_utils import Dino
152
+ self.guidance_dino = Dino(self.device, n_components=36, model_key="dinov2_vits14")
153
+ self.guidance_dino.fit_pca(self.all_input_images)
154
+ print(f"[INFO] loaded dino!")
155
+
156
+ # input image
157
+ if self.input_img is not None:
158
+ self.input_img_torch = torch.from_numpy(self.input_img).permute(0, 3, 1, 2).to(self.device)
159
+ self.input_img_torch = F.interpolate(self.input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)
160
+
161
+ self.input_mask_torch = torch.from_numpy(self.input_mask).permute(0, 3, 1, 2).to(self.device)
162
+ self.input_mask_torch = F.interpolate(self.input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)
163
+ self.input_img_torch_channel_last = self.input_img_torch.permute(0, 2, 3, 1).contiguous()
164
+
165
+ # prepare embeddings
166
+ with torch.no_grad():
167
+
168
+ if self.enable_sd:
169
+ self.guidance_sd.get_text_embeds([self.prompt], [self.negative_prompt])
170
+
171
+ if self.enable_zero123:
172
+ self.guidance_zero123.get_img_embeds(self.input_img_torch)
173
+
174
+ if self.enable_dino:
175
+ self.guidance_dino.embeddings = self.guidance_dino.get_dino_embeds(self.input_img_torch, upscale=True, reduced=True, learned_up=True) # [8, 18, 18, 36]
176
+
177
+ def train_step(self):
178
+ starter = torch.cuda.Event(enable_timing=True)
179
+ ender = torch.cuda.Event(enable_timing=True)
180
+ starter.record()
181
+
182
+
183
+ for _ in range(self.train_steps):
184
+
185
+ self.step += 1
186
+ step_ratio = min(1, self.step / self.opt.iters_refine)
187
+
188
+ loss = 0
189
+
190
+ ### known view
191
+ for choice in range(self.num_views):
192
+ ssaa = min(2.0, max(0.125, 2 * np.random.random()))
193
+ out = self.renderer.render(*self.cams[choice][:2], self.opt.ref_size, self.opt.ref_size, ssaa=ssaa)
194
+
195
+ # rgb loss
196
+ image = out["image"] # [H, W, 3] in [0, 1]
197
+ valid_mask = (out["alpha"] > 0).detach()
198
+ loss = loss + F.mse_loss(image * valid_mask, self.input_img_torch_channel_last[choice] * valid_mask)
199
+
200
+ if self.enable_dino:
201
+ feature = out["feature"]
202
+ loss = loss + F.mse_loss(feature * valid_mask, self.guidance_dino.embeddings[choice] * valid_mask)
203
+
204
+ ### novel view (manual batch)
205
+ render_resolution = 512
206
+ images = []
207
+ vers, hors, radii = [], [], []
208
+ # avoid too large elevation (> 80 or < -80), and make sure it always cover [-30, 30]
209
+ # min_ver = max(min(-30, -30 - self.opt.elevation), -80 - self.opt.elevation)
210
+ # max_ver = min(max(30, 30 - self.opt.elevation), 80 - self.opt.elevation)
211
+ # min_ver = max(min(-30, -30 + np.array(self.opt.ref_polars).min()), -80)
212
+ # max_ver = min(max(30, 30 + np.array(self.opt.ref_polars).max()), 80)
213
+ min_ver = max(-30 + np.array(self.opt.ref_polars).min(), -80)
214
+ max_ver = min(30 + np.array(self.opt.ref_polars).max(), 80)
215
+
216
+ for _ in range(self.opt.batch_size):
217
+
218
+ # render random view
219
+ ver = np.random.randint(min_ver, max_ver) - self.opt.ref_polars[0]
220
+ hor = np.random.randint(-180, 180)
221
+ radius = 0
222
+
223
+ vers.append(ver)
224
+ hors.append(hor)
225
+ radii.append(radius)
226
+
227
+ pose = orbit_camera(self.opt.ref_polars[0] + ver, self.opt.ref_azimuths[0] + hor, np.array(self.opt.ref_radii).mean() + radius)
228
+
229
+ # random render resolution
230
+ ssaa = min(2.0, max(0.125, 2 * np.random.random()))
231
+
232
+ # Azimuth
233
+ # [-180, -135): -4, [-135, -90): -3, [-90, -45): -2, [-45, 0): -1
234
+ # [0, 45): 0, [45, 90): 1, [90, 135): 2, [135, 180): 3.
235
+ # Elevation: [0, 90): 0 [-90, 0): 1
236
+ idx_ver, idx_hor = int((self.opt.ref_polars[0]+ver) < 0), hor // 45
237
+
238
+ flag = 0
239
+ cx, cy = self.pp_pools[idx_ver, idx_hor+4].tolist()
240
+ cnt = 0
241
+
242
+ while not flag:
243
+
244
+ self.cam.cx = cx
245
+ self.cam.cy = cy
246
+
247
+ if cnt >= 5:
248
+ print(f"[ERROR] Something must be wrong here!")
249
+ break
250
+
251
+ # We modified the field of view. Otherwise, the rendered object will be too small
252
+ out = self.renderer.render(pose, self.cam.perspective, render_resolution, render_resolution, ssaa=ssaa)
253
+
254
+ image = out["image"]
255
+ image = image.permute(2, 0, 1).contiguous().unsqueeze(0)
256
+ mask = out["alpha"] > 0
257
+ mask = mask.permute(2, 0, 1).contiguous().unsqueeze(0)
258
+ delta_xy = find_mask_center_and_translate(image.detach(), mask.detach()) / render_resolution * 256
259
+
260
+ if delta_xy[0].abs() > 10 or delta_xy[1].abs() > 10:
261
+ cx -= delta_xy[0]
262
+ cy -= delta_xy[1]
263
+ self.pp_pools[idx_ver, idx_hor+4] = torch.tensor([cx, cy]) # Update pp_pools
264
+ cnt += 1
265
+ else:
266
+ flag = 1
267
+
268
+ images.append(image)
269
+
270
+ images = torch.cat(images, dim=0)
271
+
272
+ # guidance loss
273
+ strength = step_ratio * 0.15 + 0.8
274
+ if self.enable_zero123:
275
+ v1 = torch.stack([torch.tensor([radius]) + self.opt.ref_radii[0], torch.deg2rad(torch.tensor([ver]) + self.opt.ref_polars[0]), torch.deg2rad(torch.tensor([hor]) + self.opt.ref_azimuths[0])], dim=-1) # polar,azimuth,radius are all actually delta wrt default
276
+ v2 = torch.stack([torch.tensor(self.opt.ref_radii), torch.deg2rad(torch.tensor(self.opt.ref_polars)), torch.deg2rad(torch.tensor(self.opt.ref_azimuths))], dim=-1)
277
+ angles = torch.rad2deg(self.guidance_zero123.angle_between(v1, v2)).to(self.device)
278
+ choice = torch.argmin(angles.squeeze()).item()
279
+
280
+ cond_RT = {
281
+ "c2w": self.cams[choice][0],
282
+ "focal_length": self.cams[choice][-1],
283
+ }
284
+ target_RT = {
285
+ "c2w": pose,
286
+ "focal_length": np.array(self.cam.fx, self.cam.fy),
287
+ }
288
+ cam_embed = self.guidance_zero123.get_cam_embeddings_6D(target_RT, cond_RT)
289
+
290
+ # Additionally add an idx parameter to choose the correct viewpoints
291
+ refined_images = self.guidance_zero123.refine(images, cam_embed, strength=strength, idx=choice).float()
292
+ refined_images = F.interpolate(refined_images, (render_resolution, render_resolution), mode="bilinear", align_corners=False)
293
+ loss = loss + self.opt.lambda_zero123 * F.mse_loss(images, refined_images)
294
+
295
+ if self.enable_dino:
296
+ loss_dino = self.guidance_dino.train_step(
297
+ images,
298
+ out["feature"].permute(2, 0, 1).contiguous(),
299
+ step_ratio=step_ratio if self.opt.anneal_timestep else None
300
+ )
301
+ loss = loss + self.opt.lambda_dino * loss_dino
302
+
303
+ # optimize step
304
+ loss.backward()
305
+ self.optimizer.step()
306
+ self.optimizer.zero_grad()
307
+
308
+ ender.record()
309
+ torch.cuda.synchronize()
310
+ t = starter.elapsed_time(ender)
311
+
312
+ self.need_update = True
313
+
314
+ def load_input(self, camera_path, order_path=None):
315
+ # load image
316
+ print(f'[INFO] load data from {camera_path}...')
317
+
318
+ if order_path is not None:
319
+ with open(order_path, 'r') as f:
320
+ indices = json.load(f)
321
+ else:
322
+ indices = None
323
+
324
+ with open(camera_path, 'r') as f:
325
+ data = json.load(f)
326
+
327
+ self.cam_params = {}
328
+ for k, v in data.items():
329
+ if indices is None:
330
+ self.cam_params[k] = data[k]
331
+ else:
332
+ if int(k) in indices or k in indices:
333
+ self.cam_params[k] = data[k]
334
+
335
+ if self.opt.all_views:
336
+ for k, v in self.cam_params.items(): v['flag'] = 1  # mark every view as usable, as in main_stage1
337
+
338
+ img_paths = [v["filepath"] for k, v in self.cam_params.items() if v["flag"]]
339
+ self.num_views = len(img_paths)
340
+ print(f"[INFO] Number of views: {self.num_views}")
341
+
342
+ for filepath in img_paths:
343
+ print(filepath)
344
+
345
+ images, masks = [], []
346
+
347
+ for i in range(len(img_paths)):
348
+ img = cv2.imread(img_paths[i], cv2.IMREAD_UNCHANGED)
349
+ if img.shape[-1] == 3:
350
+ if self.bg_remover is None:
351
+ self.bg_remover = rembg.new_session()
352
+ img = rembg.remove(img, session=self.bg_remover)
353
+
354
+ img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)
355
+ img = img.astype(np.float32) / 255.0
356
+
357
+ input_mask = img[..., 3:]
358
+ # white bg
359
+ input_img = img[..., :3] * input_mask + (1 - input_mask)
360
+ # bgr to rgb
361
+ input_img = input_img[..., ::-1].copy()
362
+
363
+ images.append(input_img), masks.append(input_mask)
364
+
365
+ images = np.stack(images, axis=0)
366
+ masks = np.stack(masks, axis=0)
367
+ self.input_img = images[:self.num_views]
368
+ self.input_mask = masks[:self.num_views]
369
+ self.all_input_images = images
370
+
371
+ def save_model(self):
372
+ os.makedirs(self.opt.outdir, exist_ok=True)
373
+
374
+ path = os.path.join(self.opt.outdir, self.opt.save_path + '.' + self.opt.mesh_format)
375
+ self.renderer.export_mesh(path)
376
+
377
+ print(f"[INFO] save model to {path}.")
378
+
379
+ # no gui mode
380
+ def train(self, iters=500):
381
+ if iters > 0:
382
+ self.prepare_train()
383
+ for i in tqdm.trange(iters):
384
+ self.train_step()
385
+ # save
386
+ self.save_model()
387
+
388
+
389
+ if __name__ == "__main__":
390
+ import argparse
391
+ from omegaconf import OmegaConf
392
+
393
+ parser = argparse.ArgumentParser()
394
+ parser.add_argument("--config", required=True, help="path to the yaml config file")
395
+ args, extras = parser.parse_known_args()
396
+
397
+ # override default config from cli
398
+ opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
399
+
400
+ # auto find mesh from stage 1
401
+ if opt.mesh is None:
402
+ default_path = os.path.join(opt.outdir, opt.save_path + '_mesh.' + opt.mesh_format)
403
+ if os.path.exists(default_path):
404
+ opt.mesh = default_path
405
+ else:
406
+ raise ValueError(f"Cannot find mesh from {default_path}, must specify --mesh explicitly!")
407
+
408
+ gui = GUI(opt)
409
+
410
+ gui.train(opt.iters_refine)
sparseags/mesh_utils/grid_put.py ADDED
@@ -0,0 +1,301 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def stride_from_shape(shape):
6
+ stride = [1]
7
+ for x in reversed(shape[1:]):
8
+ stride.append(stride[-1] * x)
9
+ return list(reversed(stride))
10
+
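+ # Worked example: stride_from_shape((H, W)) == [W, 1], so a 2D index (i, j)
+ # flattens to i * W + j, i.e. row-major (C-contiguous) order.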
11
+
12
+ def scatter_add_nd(input, indices, values):
13
+ # input: [..., C], D dimension + C channel
14
+ # indices: [N, D], long
15
+ # values: [N, C]
16
+
17
+ D = indices.shape[-1]
18
+ C = input.shape[-1]
19
+ size = input.shape[:-1]
20
+ stride = stride_from_shape(size)
21
+
22
+ assert len(size) == D
23
+
24
+ input = input.view(-1, C) # [HW, C]
25
+ flatten_indices = (indices * torch.tensor(stride, dtype=torch.long, device=indices.device)).sum(-1) # [N]
26
+
27
+ input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
28
+
29
+ return input.view(*size, C)
30
+
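+ # Hedged usage sketch (values are illustrative): accumulate per-pixel colors into an
+ # H x W x C buffer; duplicate indices are summed in place.
+ #   buf = torch.zeros(4, 8, 3)                           # [H, W, C]
+ #   idx = torch.tensor([[0, 0], [0, 0], [2, 5]])         # [N, 2], long
+ #   val = torch.ones(3, 3)                               # [N, C]
+ #   buf = scatter_add_nd(buf, idx, val)                  # buf[0, 0] == 2, buf[2, 5] == 1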
31
+
32
+ def scatter_add_nd_with_count(input, count, indices, values, weights=None):
33
+ # input: [..., C], D dimension + C channel
34
+ # count: [..., 1], D dimension
35
+ # indices: [N, D], long
36
+ # values: [N, C]
37
+
38
+ D = indices.shape[-1]
39
+ C = input.shape[-1]
40
+ size = input.shape[:-1]
41
+ stride = stride_from_shape(size)
42
+
43
+ assert len(size) == D
44
+
45
+ input = input.view(-1, C) # [HW, C]
46
+ count = count.view(-1, 1)
47
+
48
+ flatten_indices = (indices * torch.tensor(stride, dtype=torch.long, device=indices.device)).sum(-1) # [N]
49
+
50
+ if weights is None:
51
+ weights = torch.ones_like(values[..., :1])
52
+
53
+ input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
54
+ count.scatter_add_(0, flatten_indices.unsqueeze(1), weights)
55
+
56
+ return input.view(*size, C), count.view(*size, 1)
57
+
58
+ def nearest_grid_put_2d(H, W, coords, values, return_count=False):
59
+ # coords: [N, 2], float in [-1, 1]
60
+ # values: [N, C]
61
+
62
+ C = values.shape[-1]
63
+
64
+ indices = (coords * 0.5 + 0.5) * torch.tensor(
65
+ [H - 1, W - 1], dtype=torch.float32, device=coords.device
66
+ )
67
+ indices = indices.round().long() # [N, 2]
68
+
69
+ result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C]
70
+ count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1]
71
+ weights = torch.ones_like(values[..., :1]) # [N, 1]
72
+
73
+ result, count = scatter_add_nd_with_count(result, count, indices, values, weights)
74
+
75
+ if return_count:
76
+ return result, count
77
+
78
+ mask = (count.squeeze(-1) > 0)
79
+ result[mask] = result[mask] / count[mask].repeat(1, C)
80
+
81
+ return result
82
+
83
+
84
+ def linear_grid_put_2d(H, W, coords, values, return_count=False):
85
+ # coords: [N, 2], float in [-1, 1]
86
+ # values: [N, C]
87
+
88
+ C = values.shape[-1]
89
+
90
+ indices = (coords * 0.5 + 0.5) * torch.tensor(
91
+ [H - 1, W - 1], dtype=torch.float32, device=coords.device
92
+ )
93
+ indices_00 = indices.floor().long() # [N, 2]
94
+ indices_00[:, 0].clamp_(0, H - 2)
95
+ indices_00[:, 1].clamp_(0, W - 2)
96
+ indices_01 = indices_00 + torch.tensor(
97
+ [0, 1], dtype=torch.long, device=indices.device
98
+ )
99
+ indices_10 = indices_00 + torch.tensor(
100
+ [1, 0], dtype=torch.long, device=indices.device
101
+ )
102
+ indices_11 = indices_00 + torch.tensor(
103
+ [1, 1], dtype=torch.long, device=indices.device
104
+ )
105
+
106
+ h = indices[..., 0] - indices_00[..., 0].float()
107
+ w = indices[..., 1] - indices_00[..., 1].float()
108
+ w_00 = (1 - h) * (1 - w)
109
+ w_01 = (1 - h) * w
110
+ w_10 = h * (1 - w)
111
+ w_11 = h * w
112
+
113
+ result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C]
114
+ count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1]
115
+ weights = torch.ones_like(values[..., :1]) # [N, 1]
116
+
117
+ result, count = scatter_add_nd_with_count(result, count, indices_00, values * w_00.unsqueeze(1), weights* w_00.unsqueeze(1))
118
+ result, count = scatter_add_nd_with_count(result, count, indices_01, values * w_01.unsqueeze(1), weights* w_01.unsqueeze(1))
119
+ result, count = scatter_add_nd_with_count(result, count, indices_10, values * w_10.unsqueeze(1), weights* w_10.unsqueeze(1))
120
+ result, count = scatter_add_nd_with_count(result, count, indices_11, values * w_11.unsqueeze(1), weights* w_11.unsqueeze(1))
121
+
122
+ if return_count:
123
+ return result, count
124
+
125
+ mask = (count.squeeze(-1) > 0)
126
+ result[mask] = result[mask] / count[mask].repeat(1, C)
127
+
128
+ return result
129
+
130
+ def mipmap_linear_grid_put_2d(H, W, coords, values, min_resolution=32, return_count=False):
131
+ # coords: [N, 2], float in [-1, 1]
132
+ # values: [N, C]
133
+
134
+ C = values.shape[-1]
135
+
136
+ result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C]
137
+ count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1]
138
+
139
+ cur_H, cur_W = H, W
140
+
141
+ while min(cur_H, cur_W) > min_resolution:
142
+
143
+ # try to fill the holes
144
+ mask = (count.squeeze(-1) == 0)
145
+ if not mask.any():
146
+ break
147
+
148
+ cur_result, cur_count = linear_grid_put_2d(cur_H, cur_W, coords, values, return_count=True)
149
+ result[mask] = result[mask] + F.interpolate(cur_result.permute(2,0,1).unsqueeze(0).contiguous(), (H, W), mode='bilinear', align_corners=False).squeeze(0).permute(1,2,0).contiguous()[mask]
150
+ count[mask] = count[mask] + F.interpolate(cur_count.view(1, 1, cur_H, cur_W), (H, W), mode='bilinear', align_corners=False).view(H, W, 1)[mask]
151
+ cur_H //= 2
152
+ cur_W //= 2
153
+
154
+ if return_count:
155
+ return result, count
156
+
157
+ mask = (count.squeeze(-1) > 0)
158
+ result[mask] = result[mask] / count[mask].repeat(1, C)
159
+
160
+ return result
161
+
162
+ def nearest_grid_put_3d(H, W, D, coords, values, return_count=False):
163
+ # coords: [N, 3], float in [-1, 1]
164
+ # values: [N, C]
165
+
166
+ C = values.shape[-1]
167
+
168
+ indices = (coords * 0.5 + 0.5) * torch.tensor(
169
+ [H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device
170
+ )
171
+ indices = indices.round().long() # [N, 3]
172
+
173
+ result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, D, C]
174
+ count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, D, 1]
175
+ weights = torch.ones_like(values[..., :1]) # [N, 1]
176
+
177
+ result, count = scatter_add_nd_with_count(result, count, indices, values, weights)
178
+
179
+ if return_count:
180
+ return result, count
181
+
182
+ mask = (count.squeeze(-1) > 0)
183
+ result[mask] = result[mask] / count[mask].repeat(1, C)
184
+
185
+ return result
186
+
187
+
188
+ def linear_grid_put_3d(H, W, D, coords, values, return_count=False):
189
+ # coords: [N, 3], float in [-1, 1]
190
+ # values: [N, C]
191
+
192
+ C = values.shape[-1]
193
+
194
+ indices = (coords * 0.5 + 0.5) * torch.tensor(
195
+ [H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device
196
+ )
197
+ indices_000 = indices.floor().long() # [N, 3]
198
+ indices_000[:, 0].clamp_(0, H - 2)
199
+ indices_000[:, 1].clamp_(0, W - 2)
200
+ indices_000[:, 2].clamp_(0, D - 2)
201
+
202
+ indices_001 = indices_000 + torch.tensor([0, 0, 1], dtype=torch.long, device=indices.device)
203
+ indices_010 = indices_000 + torch.tensor([0, 1, 0], dtype=torch.long, device=indices.device)
204
+ indices_011 = indices_000 + torch.tensor([0, 1, 1], dtype=torch.long, device=indices.device)
205
+ indices_100 = indices_000 + torch.tensor([1, 0, 0], dtype=torch.long, device=indices.device)
206
+ indices_101 = indices_000 + torch.tensor([1, 0, 1], dtype=torch.long, device=indices.device)
207
+ indices_110 = indices_000 + torch.tensor([1, 1, 0], dtype=torch.long, device=indices.device)
208
+ indices_111 = indices_000 + torch.tensor([1, 1, 1], dtype=torch.long, device=indices.device)
209
+
210
+ h = indices[..., 0] - indices_000[..., 0].float()
211
+ w = indices[..., 1] - indices_000[..., 1].float()
212
+ d = indices[..., 2] - indices_000[..., 2].float()
213
+
214
+ w_000 = (1 - h) * (1 - w) * (1 - d)
215
+ w_001 = (1 - h) * w * (1 - d)
216
+ w_010 = h * (1 - w) * (1 - d)
217
+ w_011 = h * w * (1 - d)
218
+ w_100 = (1 - h) * (1 - w) * d
219
+ w_101 = (1 - h) * w * d
220
+ w_110 = h * (1 - w) * d
221
+ w_111 = h * w * d
222
+
223
+ result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, D, C]
224
+ count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, D, 1]
225
+ weights = torch.ones_like(values[..., :1]) # [N, 1]
226
+
227
+ result, count = scatter_add_nd_with_count(result, count, indices_000, values * w_000.unsqueeze(1), weights * w_000.unsqueeze(1))
228
+ result, count = scatter_add_nd_with_count(result, count, indices_001, values * w_001.unsqueeze(1), weights * w_001.unsqueeze(1))
229
+ result, count = scatter_add_nd_with_count(result, count, indices_010, values * w_010.unsqueeze(1), weights * w_010.unsqueeze(1))
230
+ result, count = scatter_add_nd_with_count(result, count, indices_011, values * w_011.unsqueeze(1), weights * w_011.unsqueeze(1))
231
+ result, count = scatter_add_nd_with_count(result, count, indices_100, values * w_100.unsqueeze(1), weights * w_100.unsqueeze(1))
232
+ result, count = scatter_add_nd_with_count(result, count, indices_101, values * w_101.unsqueeze(1), weights * w_101.unsqueeze(1))
233
+ result, count = scatter_add_nd_with_count(result, count, indices_110, values * w_110.unsqueeze(1), weights * w_110.unsqueeze(1))
234
+ result, count = scatter_add_nd_with_count(result, count, indices_111, values * w_111.unsqueeze(1), weights * w_111.unsqueeze(1))
235
+
236
+ if return_count:
237
+ return result, count
238
+
239
+ mask = (count.squeeze(-1) > 0)
240
+ result[mask] = result[mask] / count[mask].repeat(1, C)
241
+
242
+ return result
243
+
244
+ def mipmap_linear_grid_put_3d(H, W, D, coords, values, min_resolution=32, return_count=False):
245
+ # coords: [N, 3], float in [-1, 1]
246
+ # values: [N, C]
247
+
248
+ C = values.shape[-1]
249
+
250
+ result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, D, C]
251
+ count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, D, 1]
252
+ cur_H, cur_W, cur_D = H, W, D
253
+
254
+ while min(min(cur_H, cur_W), cur_D) > min_resolution:
255
+
256
+ # try to fill the holes
257
+ mask = (count.squeeze(-1) == 0)
258
+ if not mask.any():
259
+ break
260
+
261
+ cur_result, cur_count = linear_grid_put_3d(cur_H, cur_W, cur_D, coords, values, return_count=True)
262
+ result[mask] = result[mask] + F.interpolate(cur_result.permute(3,0,1,2).unsqueeze(0).contiguous(), (H, W, D), mode='trilinear', align_corners=False).squeeze(0).permute(1,2,3,0).contiguous()[mask]
263
+ count[mask] = count[mask] + F.interpolate(cur_count.view(1, 1, cur_H, cur_W, cur_D), (H, W, D), mode='trilinear', align_corners=False).view(H, W, D, 1)[mask]
264
+ cur_H //= 2
265
+ cur_W //= 2
266
+ cur_D //= 2
267
+
268
+ if return_count:
269
+ return result, count
270
+
271
+ mask = (count.squeeze(-1) > 0)
272
+ result[mask] = result[mask] / count[mask].repeat(1, C)
273
+
274
+ return result
275
+
276
+
277
+ def grid_put(shape, coords, values, mode='linear-mipmap', min_resolution=32, return_raw=False):
278
+ # shape: [D], list/tuple
279
+ # coords: [N, D], float in [-1, 1]
280
+ # values: [N, C]
281
+
282
+ D = len(shape)
283
+ assert D in [2, 3], f'only support D == 2 or 3, but got D == {D}'
284
+
285
+ if mode == 'nearest':
286
+ if D == 2:
287
+ return nearest_grid_put_2d(*shape, coords, values, return_raw)
288
+ else:
289
+ return nearest_grid_put_3d(*shape, coords, values, return_raw)
290
+ elif mode == 'linear':
291
+ if D == 2:
292
+ return linear_grid_put_2d(*shape, coords, values, return_raw)
293
+ else:
294
+ return linear_grid_put_3d(*shape, coords, values, return_raw)
295
+ elif mode == 'linear-mipmap':
296
+ if D == 2:
297
+ return mipmap_linear_grid_put_2d(*shape, coords, values, min_resolution, return_raw)
298
+ else:
299
+ return mipmap_linear_grid_put_3d(*shape, coords, values, min_resolution, return_raw)
300
+ else:
301
+ raise NotImplementedError(f"got mode {mode}")
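A minimal usage sketch of the grid_put dispatcher defined above (illustrative only; the point count, resolution, and random inputs are assumptions, and grid_put is assumed to be in scope from this module):

import torch

# Splat N random colors into a 256x256 grid; the first coordinate maps to the H axis.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
N, C, H, W = 4096, 3, 256, 256
coords = torch.rand(N, 2, device=device) * 2 - 1   # [N, 2], float in [-1, 1]
values = torch.rand(N, C, device=device)           # [N, C]
tex = grid_put((H, W), coords, values, mode='linear-mipmap', min_resolution=32)
print(tex.shape)  # torch.Size([256, 256, 3])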
sparseags/mesh_utils/mesh.py ADDED
@@ -0,0 +1,638 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import trimesh
5
+ import numpy as np
6
+
7
+
8
+ def dot(x, y):
9
+ return torch.sum(x * y, -1, keepdim=True)
10
+
11
+
12
+ def length(x, eps=1e-20):
13
+ return torch.sqrt(torch.clamp(dot(x, x), min=eps))
14
+
15
+
16
+ def safe_normalize(x, eps=1e-20):
17
+ return x / length(x, eps)
18
+
19
+ class Mesh:
20
+ def __init__(
21
+ self,
22
+ v=None,
23
+ f=None,
24
+ vn=None,
25
+ fn=None,
26
+ vt=None,
27
+ ft=None,
28
+ albedo=None,
29
+ vc=None, # vertex color
30
+ device=None,
31
+ ):
32
+ self.device = device
33
+ self.v = v
34
+ self.vn = vn
35
+ self.vt = vt
36
+ self.f = f
37
+ self.fn = fn
38
+ self.ft = ft
39
+ # only support a single albedo
40
+ self.albedo = albedo
41
+ # support vertex color if there is no albedo
42
+ self.vc = vc
43
+
44
+ self.ori_center = 0
45
+ self.ori_scale = 1
46
+
47
+ @classmethod
48
+ def load(cls, path=None, resize=True, renormal=True, retex=False, front_dir='+z', **kwargs):
49
+ # assume init with kwargs
50
+ if path is None:
51
+ mesh = cls(**kwargs)
52
+ # obj supports face uv
53
+ elif path.endswith(".obj"):
54
+ mesh = cls.load_obj(path, **kwargs)
55
+ # trimesh only supports vertex uv, but can load more formats
56
+ else:
57
+ mesh = cls.load_trimesh(path, **kwargs)
58
+
59
+ print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}")
60
+ # auto-normalize
61
+ if resize:
62
+ mesh.auto_size()
63
+ # auto-fix normal
64
+ if renormal or mesh.vn is None:
65
+ mesh.auto_normal()
66
+ print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}")
67
+ # auto-fix texcoords
68
+ if retex or (mesh.albedo is not None and mesh.vt is None):
69
+ mesh.auto_uv(cache_path=path)
70
+ print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}")
71
+
72
+ # rotate front dir to +z
73
+ if front_dir != "+z":
74
+ # axis switch
75
+ if "-z" in front_dir:
76
+ T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32)
77
+ elif "+x" in front_dir:
78
+ T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
79
+ elif "-x" in front_dir:
80
+ T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
81
+ elif "+y" in front_dir:
82
+ T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
83
+ elif "-y" in front_dir:
84
+ T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
85
+ else:
86
+ T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
87
+ # rotation (how many 90 degrees)
88
+ if '1' in front_dir:
89
+ T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
90
+ elif '2' in front_dir:
91
+ T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
92
+ elif '3' in front_dir:
93
+ T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
94
+ mesh.v @= T
95
+ mesh.vn @= T
96
+
97
+ return mesh
98
+
99
+ # load from obj file
100
+ @classmethod
101
+ def load_obj(cls, path, albedo_path=None, device=None, enable_dino=False):
102
+ assert os.path.splitext(path)[-1] == ".obj"
103
+
104
+ mesh = cls()
105
+
106
+ # device
107
+ if device is None:
108
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
109
+
110
+ mesh.device = device
111
+
112
+ # load obj
113
+ with open(path, "r") as f:
114
+ lines = f.readlines()
115
+
116
+ def parse_f_v(fv):
117
+ # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided)
118
+ # supported forms:
119
+ # f v1 v2 v3
120
+ # f v1/vt1 v2/vt2 v3/vt3
121
+ # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3
122
+ # f v1//vn1 v2//vn2 v3//vn3
123
+ xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")]
124
+ xs.extend([-1] * (3 - len(xs)))
125
+ return xs[0], xs[1], xs[2]
126
+
127
+ # NOTE: we ignore usemtl, and assume the mesh ONLY uses one material (first in mtl)
128
+ vertices, texcoords, normals = [], [], []
129
+ faces, tfaces, nfaces = [], [], []
130
+ mtl_path = None
+ feature_path = None  # guard: map_Ft may be absent from the .mtl (avoids a NameError below)
131
+
132
+ for line in lines:
133
+ split_line = line.split()
134
+ # empty line
135
+ if len(split_line) == 0:
136
+ continue
137
+ prefix = split_line[0].lower()
138
+ # mtllib
139
+ if prefix == "mtllib":
140
+ mtl_path = split_line[1]
141
+ # usemtl
142
+ elif prefix == "usemtl":
143
+ pass # ignored
144
+ # v/vn/vt
145
+ elif prefix == "v":
146
+ vertices.append([float(v) for v in split_line[1:]])
147
+ elif prefix == "vn":
148
+ normals.append([float(v) for v in split_line[1:]])
149
+ elif prefix == "vt":
150
+ val = [float(v) for v in split_line[1:]]
151
+ texcoords.append([val[0], 1.0 - val[1]])
152
+ elif prefix == "f":
153
+ vs = split_line[1:]
154
+ nv = len(vs)
155
+ v0, t0, n0 = parse_f_v(vs[0])
156
+ for i in range(nv - 2): # triangulate (assume vertices are ordered)
157
+ v1, t1, n1 = parse_f_v(vs[i + 1])
158
+ v2, t2, n2 = parse_f_v(vs[i + 2])
159
+ faces.append([v0, v1, v2])
160
+ tfaces.append([t0, t1, t2])
161
+ nfaces.append([n0, n1, n2])
162
+
163
+ mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
164
+ mesh.vt = (
165
+ torch.tensor(texcoords, dtype=torch.float32, device=device)
166
+ if len(texcoords) > 0
167
+ else None
168
+ )
169
+ mesh.vn = (
170
+ torch.tensor(normals, dtype=torch.float32, device=device)
171
+ if len(normals) > 0
172
+ else None
173
+ )
174
+
175
+ mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
176
+ mesh.ft = (
177
+ torch.tensor(tfaces, dtype=torch.int32, device=device)
178
+ if len(texcoords) > 0
179
+ else None
180
+ )
181
+ mesh.fn = (
182
+ torch.tensor(nfaces, dtype=torch.int32, device=device)
183
+ if len(normals) > 0
184
+ else None
185
+ )
186
+
187
+ # see if there is vertex color
188
+ use_vertex_color = False
189
+ if mesh.v.shape[1] == 6:
190
+ use_vertex_color = True
191
+ mesh.vc = mesh.v[:, 3:]
192
+ mesh.v = mesh.v[:, :3]
193
+ print(f"[load_obj] use vertex color: {mesh.vc.shape}")
194
+
195
+ # try to load texture image
196
+ if not use_vertex_color:
197
+ # try to retrieve mtl file
198
+ mtl_path_candidates = []
199
+ if mtl_path is not None:
200
+ mtl_path_candidates.append(mtl_path)
201
+ mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path))
202
+ mtl_path_candidates.append(path.replace(".obj", ".mtl"))
203
+
204
+ mtl_path = None
205
+ for candidate in mtl_path_candidates:
206
+ if os.path.exists(candidate):
207
+ mtl_path = candidate
208
+ break
209
+
210
+ # if albedo_path is not provided, try retrieve it from mtl
211
+ if mtl_path is not None and albedo_path is None:
212
+ with open(mtl_path, "r") as f:
213
+ lines = f.readlines()
214
+ for line in lines:
215
+ split_line = line.split()
216
+ # empty line
217
+ if len(split_line) == 0:
218
+ continue
219
+ prefix = split_line[0]
220
+ # NOTE: simply use the first map_Kd as albedo!
221
+ if "map_Kd" in prefix:
222
+ albedo_path = os.path.join(os.path.dirname(path), split_line[1])
223
+ print(f"[load_obj] use texture from: {albedo_path}")
224
+ # break
225
+ if "map_Ft" in prefix:
226
+ feature_path = os.path.join(os.path.dirname(path), split_line[1])
227
+ print(f"[load_obj] use feature from: {feature_path}")
228
+ break
229
+
230
+ # still not found albedo_path, or the path doesn't exist
231
+ if albedo_path is None or not os.path.exists(albedo_path):
232
+ # init an empty texture
233
+ print(f"[load_obj] init empty albedo!")
234
+ # albedo = np.random.rand(1024, 1024, 3).astype(np.float32)
235
+ albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) # default color
236
+ else:
237
+ albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)
238
+ albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)
239
+ albedo = albedo.astype(np.float32) / 255
240
+ print(f"[load_obj] load texture: {albedo.shape}")
241
+
242
+ # import matplotlib.pyplot as plt
243
+ # plt.imshow(albedo)
244
+ # plt.show()
245
+ if enable_dino and feature_path is not None and os.path.exists(feature_path):
246
+ feature = torch.load(feature_path).to(device)
247
+ mesh.feature = feature
248
+ print(f"[load_obj] load feature: {feature.shape}")
249
+
250
+ mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)
251
+
252
+ return mesh
253
+
254
+ @classmethod
255
+ def load_trimesh(cls, path, device=None, enable_dino=False):
256
+ mesh = cls()
257
+
258
+ # device
259
+ if device is None:
260
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
261
+
262
+ mesh.device = device
263
+
264
+ # use trimesh to load ply/glb, assuming the scene only contains a single mesh...
265
+ _data = trimesh.load(path)
266
+ if isinstance(_data, trimesh.Scene):
267
+ if len(_data.geometry) == 1:
268
+ _mesh = list(_data.geometry.values())[0]
269
+ else:
270
+ # manual concat, will lose texture
271
+ _concat = []
272
+ for g in _data.geometry.values():
273
+ if isinstance(g, trimesh.Trimesh):
274
+ _concat.append(g)
275
+ _mesh = trimesh.util.concatenate(_concat)
276
+ else:
277
+ _mesh = _data
278
+
279
+ if _mesh.visual.kind == 'vertex':
280
+ vertex_colors = _mesh.visual.vertex_colors
281
+ vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255
282
+ mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device)
283
+ print(f"[load_trimesh] use vertex color: {mesh.vc.shape}")
284
+ elif _mesh.visual.kind == 'texture':
285
+ _material = _mesh.visual.material
286
+ if isinstance(_material, trimesh.visual.material.PBRMaterial):
287
+ texture = np.array(_material.baseColorTexture).astype(np.float32) / 255
288
+ elif isinstance(_material, trimesh.visual.material.SimpleMaterial):
289
+ texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255
290
+ else:
291
+ raise NotImplementedError(f"material type {type(_material)} not supported!")
292
+ mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)
293
+ print(f"[load_trimesh] load texture: {texture.shape}")
294
+ else:
295
+ texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])
296
+ mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)
297
+ print(f"[load_trimesh] failed to load texture.")
298
+
299
+ vertices = _mesh.vertices
300
+
301
+ try:
302
+ texcoords = _mesh.visual.uv
303
+ texcoords[:, 1] = 1 - texcoords[:, 1]
304
+ except Exception as e:
305
+ texcoords = None
306
+
307
+ try:
308
+ normals = _mesh.vertex_normals
309
+ except Exception as e:
310
+ normals = None
311
+
312
+ # trimesh only supports vertex uv...
313
+ faces = tfaces = nfaces = _mesh.faces
314
+
315
+ mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
316
+ mesh.vt = (
317
+ torch.tensor(texcoords, dtype=torch.float32, device=device)
318
+ if texcoords is not None
319
+ else None
320
+ )
321
+ mesh.vn = (
322
+ torch.tensor(normals, dtype=torch.float32, device=device)
323
+ if normals is not None
324
+ else None
325
+ )
326
+
327
+ mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
328
+ mesh.ft = (
329
+ torch.tensor(tfaces, dtype=torch.int32, device=device)
330
+ if texcoords is not None
331
+ else None
332
+ )
333
+ mesh.fn = (
334
+ torch.tensor(nfaces, dtype=torch.int32, device=device)
335
+ if normals is not None
336
+ else None
337
+ )
338
+
339
+ return mesh
340
+
341
+ # aabb
342
+ def aabb(self):
343
+ return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values
344
+
345
+ # unit size
346
+ @torch.no_grad()
347
+ def auto_size(self):
348
+ vmin, vmax = self.aabb()
349
+ self.ori_center = (vmax + vmin) / 2
350
+ self.ori_scale = 1.2 / torch.max(vmax - vmin).item()
351
+ self.v = (self.v - self.ori_center) * self.ori_scale
352
+
353
+ def auto_normal(self):
354
+ i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()
355
+ v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]
356
+
357
+ face_normals = torch.cross(v1 - v0, v2 - v0)
358
+
359
+ # Splat face normals to vertices
360
+ vn = torch.zeros_like(self.v)
361
+ vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
362
+ vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
363
+ vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
364
+
365
+ # Normalize, replace zero (degenerated) normals with some default value
366
+ vn = torch.where(
367
+ dot(vn, vn) > 1e-20,
368
+ vn,
369
+ torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),
370
+ )
371
+ vn = safe_normalize(vn)
372
+
373
+ self.vn = vn
374
+ self.fn = self.f
375
+
376
+ def auto_uv(self, cache_path=None, vmap=True):
377
+ # try to load cache
378
+ if cache_path is not None:
379
+ cache_path = os.path.splitext(cache_path)[0] + "_uv.npz"
380
+ if cache_path is not None and os.path.exists(cache_path):
381
+ data = np.load(cache_path)
382
+ vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"]
383
+ else:
384
+ import xatlas
385
+
386
+ v_np = self.v.detach().cpu().numpy()
387
+ f_np = self.f.detach().int().cpu().numpy()
388
+ atlas = xatlas.Atlas()
389
+ atlas.add_mesh(v_np, f_np)
390
+ chart_options = xatlas.ChartOptions()
391
+ # chart_options.max_iterations = 4
392
+ atlas.generate(chart_options=chart_options)
393
+ vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
394
+
395
+ # save to cache
396
+ if cache_path is not None:
397
+ np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping)
398
+
399
+ vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)
400
+ ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)
401
+ self.vt = vt
402
+ self.ft = ft
403
+
404
+ if vmap:
405
+ # remap v/f to vt/ft, so each v corresponds to a unique vt. (necessary for gltf)
406
+ vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device)
407
+ self.align_v_to_vt(vmapping)
408
+
409
+ def align_v_to_vt(self, vmapping=None):
410
+ # remap v/f and vn/fn to vt/ft.
411
+ if vmapping is None:
412
+ ft = self.ft.view(-1).long()
413
+ f = self.f.view(-1).long()
414
+ vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device)
415
+ vmapping[ft] = f # scatter, randomly choose one if index is not unique
416
+
417
+ self.v = self.v[vmapping]
418
+ self.f = self.ft
419
+ # assume fn == f
420
+ if self.vn is not None:
421
+ self.vn = self.vn[vmapping]
422
+ self.fn = self.ft
423
+
424
+ def to(self, device):
425
+ self.device = device
426
+ for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo"]:
427
+ tensor = getattr(self, name)
428
+ if tensor is not None:
429
+ setattr(self, name, tensor.to(device))
430
+ return self
431
+
432
+ def write(self, path, enable_dino=False):
433
+ if path.endswith(".ply"):
434
+ self.write_ply(path)
435
+ elif path.endswith(".obj"):
436
+ self.write_obj(path, enable_dino)
437
+ elif path.endswith(".glb") or path.endswith(".gltf"):
438
+ self.write_glb(path)
439
+ else:
440
+ raise NotImplementedError(f"format {path} not supported!")
441
+
442
+ # write to ply file (only geom)
443
+ def write_ply(self, path):
444
+
445
+ v_np = self.v.detach().cpu().numpy()
446
+ f_np = self.f.detach().cpu().numpy()
447
+
448
+ _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np)
449
+ _mesh.export(path)
450
+
451
+ # write to gltf/glb file (geom + texture)
452
+ def write_glb(self, path):
453
+
454
+ assert self.vn is not None and self.vt is not None # should be improved to support export without texture...
455
+
456
+ # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]
457
+ if self.v.shape[0] != self.vt.shape[0]:
458
+ self.align_v_to_vt()
459
+
460
+ # assume f == fn == ft
461
+
462
+ import pygltflib
463
+
464
+ f_np = self.f.detach().cpu().numpy().astype(np.uint32)
465
+ v_np = self.v.detach().cpu().numpy().astype(np.float32)
466
+ # vn_np = self.vn.detach().cpu().numpy().astype(np.float32)
467
+ vt_np = self.vt.detach().cpu().numpy().astype(np.float32)
468
+
469
+ albedo = self.albedo.detach().cpu().numpy()
470
+ albedo = (albedo * 255).astype(np.uint8)
471
+ albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)
472
+
473
+ f_np_blob = f_np.flatten().tobytes()
474
+ v_np_blob = v_np.tobytes()
475
+ # vn_np_blob = vn_np.tobytes()
476
+ vt_np_blob = vt_np.tobytes()
477
+ albedo_blob = cv2.imencode('.png', albedo)[1].tobytes()
478
+
479
+ gltf = pygltflib.GLTF2(
480
+ scene=0,
481
+ scenes=[pygltflib.Scene(nodes=[0])],
482
+ nodes=[pygltflib.Node(mesh=0)],
483
+ meshes=[pygltflib.Mesh(primitives=[
484
+ pygltflib.Primitive(
485
+ # indices to accessors (0 is triangles)
486
+ attributes=pygltflib.Attributes(
487
+ POSITION=1, TEXCOORD_0=2,
488
+ ),
489
+ indices=0, material=0,
490
+ )
491
+ ])],
492
+ materials=[
493
+ pygltflib.Material(
494
+ pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(
495
+ baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0),
496
+ metallicFactor=0.0,
497
+ roughnessFactor=1.0,
498
+ ),
499
+ alphaCutoff=0,
500
+ doubleSided=True,
501
+ )
502
+ ],
503
+ textures=[
504
+ pygltflib.Texture(sampler=0, source=0),
505
+ ],
506
+ samplers=[
507
+ pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT),
508
+ ],
509
+ images=[
510
+ # use embedded (buffer) image
511
+ pygltflib.Image(bufferView=3, mimeType="image/png"),
512
+ ],
513
+ buffers=[
514
+ pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob) + len(vt_np_blob) + len(albedo_blob))
515
+ ],
516
+ # buffer view (based on dtype)
517
+ bufferViews=[
518
+ # triangles; as flatten (element) array
519
+ pygltflib.BufferView(
520
+ buffer=0,
521
+ byteLength=len(f_np_blob),
522
+ target=pygltflib.ELEMENT_ARRAY_BUFFER, # GL_ELEMENT_ARRAY_BUFFER (34963)
523
+ ),
524
+ # positions; as vec3 array
525
+ pygltflib.BufferView(
526
+ buffer=0,
527
+ byteOffset=len(f_np_blob),
528
+ byteLength=len(v_np_blob),
529
+ byteStride=12, # vec3
530
+ target=pygltflib.ARRAY_BUFFER, # GL_ARRAY_BUFFER (34962)
531
+ ),
532
+ # texcoords; as vec2 array
533
+ pygltflib.BufferView(
534
+ buffer=0,
535
+ byteOffset=len(f_np_blob) + len(v_np_blob),
536
+ byteLength=len(vt_np_blob),
537
+ byteStride=8, # vec2
538
+ target=pygltflib.ARRAY_BUFFER,
539
+ ),
540
+ # texture; as none target
541
+ pygltflib.BufferView(
542
+ buffer=0,
543
+ byteOffset=len(f_np_blob) + len(v_np_blob) + len(vt_np_blob),
544
+ byteLength=len(albedo_blob),
545
+ ),
546
+ ],
547
+ accessors=[
548
+ # 0 = triangles
549
+ pygltflib.Accessor(
550
+ bufferView=0,
551
+ componentType=pygltflib.UNSIGNED_INT, # GL_UNSIGNED_INT (5125)
552
+ count=f_np.size,
553
+ type=pygltflib.SCALAR,
554
+ max=[int(f_np.max())],
555
+ min=[int(f_np.min())],
556
+ ),
557
+ # 1 = positions
558
+ pygltflib.Accessor(
559
+ bufferView=1,
560
+ componentType=pygltflib.FLOAT, # GL_FLOAT (5126)
561
+ count=len(v_np),
562
+ type=pygltflib.VEC3,
563
+ max=v_np.max(axis=0).tolist(),
564
+ min=v_np.min(axis=0).tolist(),
565
+ ),
566
+ # 2 = texcoords
567
+ pygltflib.Accessor(
568
+ bufferView=2,
569
+ componentType=pygltflib.FLOAT,
570
+ count=len(vt_np),
571
+ type=pygltflib.VEC2,
572
+ max=vt_np.max(axis=0).tolist(),
573
+ min=vt_np.min(axis=0).tolist(),
574
+ ),
575
+ ],
576
+ )
577
+
578
+ # set actual data
579
+ gltf.set_binary_blob(f_np_blob + v_np_blob + vt_np_blob + albedo_blob)
580
+
581
+ # glb = b"".join(gltf.save_to_bytes())
582
+ gltf.save(path)
583
+
584
+ # write to obj file (geom + texture)
585
+ def write_obj(self, path, enable_dino=False):
586
+
587
+ mtl_path = path.replace(".obj", ".mtl")
588
+ albedo_path = path.replace(".obj", "_albedo.png")
589
+ feature_path = path.replace(".obj", "_feature.pt")
590
+
591
+ v_np = self.v.detach().cpu().numpy()
592
+ vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None
593
+ vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None
594
+ f_np = self.f.detach().cpu().numpy()
595
+ ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None
596
+ fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None
597
+
598
+ with open(path, "w") as fp:
599
+ fp.write(f"mtllib {os.path.basename(mtl_path)} \n")
600
+
601
+ for v in v_np:
602
+ fp.write(f"v {v[0]} {v[1]} {v[2]} \n")
603
+
604
+ if vt_np is not None:
605
+ for v in vt_np:
606
+ fp.write(f"vt {v[0]} {1 - v[1]} \n")
607
+
608
+ if vn_np is not None:
609
+ for v in vn_np:
610
+ fp.write(f"vn {v[0]} {v[1]} {v[2]} \n")
611
+
612
+ fp.write(f"usemtl defaultMat \n")
613
+ for i in range(len(f_np)):
614
+ fp.write(
615
+ f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \
616
+ {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \
617
+ {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n'
618
+ )
619
+
620
+ with open(mtl_path, "w") as fp:
621
+ fp.write(f"newmtl defaultMat \n")
622
+ fp.write(f"Ka 1 1 1 \n")
623
+ fp.write(f"Kd 1 1 1 \n")
624
+ fp.write(f"Ks 0 0 0 \n")
625
+ fp.write(f"Tr 1 \n")
626
+ fp.write(f"illum 1 \n")
627
+ fp.write(f"Ns 0 \n")
628
+ fp.write(f"map_Kd {os.path.basename(albedo_path)} \n")
629
+ if enable_dino:
630
+ fp.write(f"map_Ft {os.path.basename(feature_path)} \n")
631
+
632
+ albedo = self.albedo.detach().cpu().numpy()
633
+ albedo = (albedo * 255).astype(np.uint8)
634
+ cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))
635
+
636
+ if enable_dino:
637
+ feature = self.feature.detach().cpu()
638
+ torch.save(feature, feature_path)
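A short usage sketch for the Mesh class above (illustrative only; the file paths are hypothetical, and the OBJ is assumed to carry a texture so that UVs are available for GLB export):

from sparseags.mesh_utils.mesh import Mesh

# Load a textured OBJ, auto-normalize it, and re-export it as GLB (uses pygltflib).
mesh = Mesh.load("assets/example.obj", resize=True)   # hypothetical path
print(mesh.v.shape, mesh.f.shape)
mesh.write("assets/example.glb")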
sparseags/mesh_utils/mesh_renderer.py ADDED
@@ -0,0 +1,268 @@
1
+ import os
2
+ import math
3
+ import cv2
4
+ import trimesh
5
+ import numpy as np
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ import nvdiffrast.torch as dr
12
+ from sparseags.mesh_utils.mesh import Mesh, safe_normalize
13
+
14
+
15
+ def scale_img_nhwc(x, size, mag='bilinear', min='bilinear'):
16
+ assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), "Trying to magnify image in one dimension and minify in the other"
17
+ y = x.permute(0, 3, 1, 2) # NHWC -> NCHW
18
+ if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger
19
+ y = torch.nn.functional.interpolate(y, size, mode=min)
20
+ else: # Magnification
21
+ if mag == 'bilinear' or mag == 'bicubic':
22
+ y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True)
23
+ else:
24
+ y = torch.nn.functional.interpolate(y, size, mode=mag)
25
+ return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC
26
+
27
+ def scale_img_hwc(x, size, mag='bilinear', min='bilinear'):
28
+ return scale_img_nhwc(x[None, ...], size, mag, min)[0]
29
+
30
+ def scale_img_nhw(x, size, mag='bilinear', min='bilinear'):
31
+ return scale_img_nhwc(x[..., None], size, mag, min)[..., 0]
32
+
33
+ def scale_img_hw(x, size, mag='bilinear', min='bilinear'):
34
+ return scale_img_nhwc(x[None, ..., None], size, mag, min)[0, ..., 0]
35
+
36
+ def trunc_rev_sigmoid(x, eps=1e-6):
37
+ x = x.clamp(eps, 1 - eps)
38
+ return torch.log(x / (1 - x))
39
+
40
+ def make_divisible(x, m=8):
41
+ return int(math.ceil(x / m) * m)
42
+
43
+ class Renderer(nn.Module):
44
+ def __init__(self, opt):
45
+
46
+ super().__init__()
47
+
48
+ self.opt = opt
49
+ self.enable_dino = self.opt.lambda_dino > 0
50
+
51
+ self.mesh = Mesh.load(self.opt.mesh, resize=False, enable_dino=self.enable_dino)
52
+
53
+ if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
54
+ self.glctx = dr.RasterizeGLContext()
55
+ else:
56
+ self.glctx = dr.RasterizeCudaContext()
57
+
58
+ self.v_offsets = torch.zeros_like(self.mesh.v)
59
+ self.raw_albedo = trunc_rev_sigmoid(self.mesh.albedo)
60
+
61
+ # extract trainable parameters
62
+ if opt.trainable_texture:
63
+ self.v_offsets = nn.Parameter(self.v_offsets)
64
+ self.raw_albedo = nn.Parameter(self.raw_albedo)
65
+
66
+ if self.enable_dino:
67
+ self.raw_feature = nn.Parameter((self.mesh.feature))
68
+
69
+
70
+ def get_params(self):
71
+
72
+ params = [
73
+ {'params': self.raw_albedo, 'lr': self.opt.texture_lr},
74
+ ]
75
+
76
+ if self.enable_dino:
77
+ params.append({'params': self.raw_feature, 'lr': self.opt.texture_lr})
78
+
79
+ if self.opt.train_geo:
80
+ params.append({'params': self.v_offsets, 'lr': self.opt.geom_lr})
81
+
82
+ return params
83
+
84
+ @torch.no_grad()
85
+ def export_mesh(self, save_path):
86
+ self.mesh.v = (self.mesh.v + self.v_offsets).detach()
87
+ self.mesh.albedo = torch.sigmoid(self.raw_albedo.detach())
88
+ if self.enable_dino:
89
+ self.mesh.feature = self.raw_feature.detach()
90
+ self.mesh.write(save_path, self.enable_dino)
91
+
92
+
93
+ def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1, texture_filter='linear-mipmap-linear'):
94
+
95
+ # do super-sampling
96
+ if ssaa != 1:
97
+ h = make_divisible(h0 * ssaa, 8)
98
+ w = make_divisible(w0 * ssaa, 8)
99
+ else:
100
+ h, w = h0, w0
101
+
102
+ results = {}
103
+
104
+ # get v
105
+ if self.opt.train_geo:
106
+ v = self.mesh.v + self.v_offsets # [N, 3]
107
+ else:
108
+ v = self.mesh.v
109
+
110
+ pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
111
+ proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)
112
+
113
+ # get v_clip and render rgb
114
+ v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
115
+ v_clip = v_cam @ proj.T
116
+
117
+ rast, rast_db = dr.rasterize(self.glctx, v_clip, self.mesh.f, (h, w))
118
+
119
+ alpha = (rast[0, ..., 3:] > 0).float()
120
+ depth, _ = dr.interpolate(-v_cam[..., [2]], rast, self.mesh.f) # [1, H, W, 1]
121
+ depth = depth.squeeze(0) # [H, W, 1]
122
+
123
+ texc, texc_db = dr.interpolate(self.mesh.vt.unsqueeze(0).contiguous(), rast, self.mesh.ft, rast_db=rast_db, diff_attrs='all')
124
+ albedo = dr.texture(self.raw_albedo.unsqueeze(0), texc, uv_da=texc_db, filter_mode=texture_filter) # [1, H, W, 3]
125
+ albedo = torch.sigmoid(albedo)
126
+ if self.enable_dino:
127
+ # NOTE: backward error when use filter_mode='linear-mipmap-linear'
128
+ feature = dr.texture(self.raw_feature.unsqueeze(0), texc, uv_da=texc_db, filter_mode='linear')
129
+ # feature = torch.sigmoid(feature)
130
+ # get vn and render normal
131
+ if self.opt.train_geo:
132
+ i0, i1, i2 = self.mesh.f[:, 0].long(), self.mesh.f[:, 1].long(), self.mesh.f[:, 2].long()
133
+ v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]
134
+
135
+ face_normals = torch.cross(v1 - v0, v2 - v0)
136
+ face_normals = safe_normalize(face_normals)
137
+
138
+ vn = torch.zeros_like(v)
139
+ vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
140
+ vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
141
+ vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)
142
+
143
+ vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))
144
+ else:
145
+ vn = self.mesh.vn
146
+
147
+ normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, self.mesh.fn)
148
+ normal = safe_normalize(normal[0])
149
+
150
+ # rotated normal (where [0, 0, 1] always faces camera)
151
+ rot_normal = normal @ pose[:3, :3]
152
+ viewcos = rot_normal[..., [2]]
153
+
154
+ # antialias
155
+ albedo = dr.antialias(albedo, rast, v_clip, self.mesh.f).squeeze(0) # [H, W, 3]
156
+ albedo = alpha * albedo + (1 - alpha) * bg_color
157
+
158
+ if self.enable_dino:
159
+ feature = dr.antialias(feature, rast, v_clip, self.mesh.f).squeeze(0) # [H, W, 3]
160
+ feature = alpha * feature + (1 - alpha) * bg_color
161
+
162
+ # ssaa
163
+ if ssaa != 1:
164
+ albedo = scale_img_hwc(albedo, (h0, w0))
165
+ alpha = scale_img_hwc(alpha, (h0, w0))
166
+ depth = scale_img_hwc(depth, (h0, w0))
167
+ normal = scale_img_hwc(normal, (h0, w0))
168
+ viewcos = scale_img_hwc(viewcos, (h0, w0))
169
+ if self.enable_dino:
170
+ feature = scale_img_hwc(feature, (h0, w0))
171
+
172
+ results['image'] = albedo.clamp(0, 1)
173
+ results['alpha'] = alpha
174
+ results['depth'] = depth
175
+ results['normal'] = (normal + 1) / 2
176
+ results['viewcos'] = viewcos
177
+ results['feature'] = feature if self.enable_dino else None # [H, W, 384]
178
+
179
+ return results
180
+
181
+
182
+ def render_batch(self, pose, proj, h0, w0, ssaa=1, bg_color=1, texture_filter='linear-mipmap-linear'):
183
+
184
+ # do super-sampling
185
+ if ssaa != 1:
186
+ h = make_divisible(h0 * ssaa, 8)
187
+ w = make_divisible(w0 * ssaa, 8)
188
+ else:
189
+ h, w = h0, w0
190
+
191
+ results = {}
192
+
193
+ # get v
194
+ if self.opt.train_geo:
195
+ v = self.mesh.v + self.v_offsets # [N, 3]
196
+ else:
197
+ v = self.mesh.v
198
+
199
+ bs = pose.shape[0]
200
+ pose = pose.to(v.device)
201
+ proj = proj.to(v.device).transpose(1, 2)
202
+
203
+ # get v_clip and render rgb
204
+ v_cam = torch.bmm(F.pad(v, pad=(0, 1), mode='constant', value=1.0).expand(bs, -1, -1), torch.linalg.inv(pose).transpose(1, 2)).float()
205
+ v_clip = torch.bmm(v_cam, proj)
206
+
207
+ rast, rast_db = dr.rasterize(self.glctx, v_clip, self.mesh.f, (h, w))
208
+
209
+ alpha = (rast[..., 3:] > 0).float()
210
+ depth, _ = dr.interpolate(-v_cam[..., [2]], rast, self.mesh.f) # [1, H, W, 1]
211
+
212
+ texc, texc_db = dr.interpolate(self.mesh.vt.expand(bs, -1, -1).contiguous(), rast, self.mesh.ft, rast_db=rast_db, diff_attrs='all')
213
+ albedo = dr.texture(self.raw_albedo.detach().unsqueeze(0).contiguous(), texc, uv_da=texc_db, filter_mode=texture_filter) # [1, H, W, 3]
214
+ albedo = torch.sigmoid(albedo)
215
+ if self.enable_dino:
216
+ # NOTE: backward error when use filter_mode='linear-mipmap-linear'
217
+ feature = dr.texture(self.raw_feature.unsqueeze(0), texc, uv_da=texc_db, filter_mode='linear')
218
+ # feature = torch.sigmoid(feature)
219
+ # get vn and render normal
220
+ if self.opt.train_geo:
221
+ i0, i1, i2 = self.mesh.f[:, 0].long(), self.mesh.f[:, 1].long(), self.mesh.f[:, 2].long()
222
+ v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]
223
+
224
+ face_normals = torch.cross(v1 - v0, v2 - v0)
225
+ face_normals = safe_normalize(face_normals)
226
+
227
+ vn = torch.zeros_like(v)
228
+ vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
229
+ vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
230
+ vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)
231
+
232
+ vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))
233
+ else:
234
+ vn = self.mesh.vn
235
+
236
+ normal, _ = dr.interpolate(vn.expand(bs, -1, -1).contiguous(), rast, self.mesh.fn)
237
+ normal = safe_normalize(normal).reshape(bs, -1, 3)
238
+
239
+ # rotated normal (where [0, 0, 1] always faces camera)
240
+ rot_normal = torch.bmm(normal, pose[:, :3, :3]).reshape(bs, h, w, 3)
241
+ viewcos = rot_normal[..., [2]]
242
+
243
+ # antialias
244
+ albedo = dr.antialias(albedo, rast, v_clip, self.mesh.f) # [H, W, 3]
245
+ albedo = alpha * albedo + (1 - alpha) * bg_color
246
+
247
+ if self.enable_dino:
248
+ feature = dr.antialias(feature, rast, v_clip, self.mesh.f).squeeze(0) # [H, W, 3]
249
+ feature = alpha * feature + (1 - alpha) * bg_color
250
+
251
+ # ssaa
252
+ if ssaa != 1:
253
+ albedo = scale_img_hwc(albedo, (h0, w0))
254
+ alpha = scale_img_hwc(alpha, (h0, w0))
255
+ depth = scale_img_hwc(depth, (h0, w0))
256
+ normal = scale_img_hwc(normal, (h0, w0))
257
+ viewcos = scale_img_hwc(viewcos, (h0, w0))
258
+ if self.enable_dino:
259
+ feature = scale_img_hwc(feature, (h0, w0))
260
+
261
+ results['image'] = albedo.clamp(0, 1)
262
+ results['alpha'] = alpha
263
+ results['depth'] = depth
264
+ results['normal'] = (normal + 1) / 2
265
+ results['viewcos'] = viewcos
266
+ results['feature'] = feature if self.enable_dino else None # [H, W, 384]
267
+
268
+ return results
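A hedged usage sketch for the Renderer above (illustrative only; the opt fields, mesh path, and the pose/projection matrices are assumptions, and nvdiffrast with a CUDA device is required):

import numpy as np
from types import SimpleNamespace

# Minimal options object covering the fields the Renderer reads.
opt = SimpleNamespace(mesh="logs/example_mesh.obj", lambda_dino=0.0, gui=False,
                      force_cuda_rast=True, trainable_texture=True, train_geo=False,
                      texture_lr=0.1, geom_lr=1e-4)
renderer = Renderer(opt)

pose = np.eye(4, dtype=np.float32)
pose[2, 3] = 2.0                      # camera-to-world pose, 2 units along +z
proj = np.eye(4, dtype=np.float32)    # placeholder projection matrix
out = renderer.render(pose, proj, 512, 512, ssaa=2)
print(out['image'].shape, out['depth'].shape)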
sparseags/mesh_utils/mesh_utils.py ADDED
@@ -0,0 +1,147 @@
1
+ import numpy as np
2
+ import pymeshlab as pml
3
+
4
+
5
+ def poisson_mesh_reconstruction(points, normals=None):
6
+ # points/normals: [N, 3] np.ndarray
7
+
8
+ import open3d as o3d
9
+
10
+ pcd = o3d.geometry.PointCloud()
11
+ pcd.points = o3d.utility.Vector3dVector(points)
12
+
13
+ # outlier removal
14
+ pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10)
15
+
16
+ # normals
17
+ if normals is None:
18
+ pcd.estimate_normals()
19
+ else:
20
+ pcd.normals = o3d.utility.Vector3dVector(normals[ind])
21
+
22
+ # visualize
23
+ o3d.visualization.draw_geometries([pcd], point_show_normal=False)
24
+
25
+ mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
26
+ pcd, depth=9
27
+ )
28
+ vertices_to_remove = densities < np.quantile(densities, 0.1)
29
+ mesh.remove_vertices_by_mask(vertices_to_remove)
30
+
31
+ # visualize
32
+ o3d.visualization.draw_geometries([mesh])
33
+
34
+ vertices = np.asarray(mesh.vertices)
35
+ triangles = np.asarray(mesh.triangles)
36
+
37
+ print(
38
+ f"[INFO] poisson mesh reconstruction: {points.shape} --> {vertices.shape} / {triangles.shape}"
39
+ )
40
+
41
+ return vertices, triangles
42
+
43
+
44
+ def decimate_mesh(
45
+ verts, faces, target, backend="pymeshlab", remesh=False, optimalplacement=True
46
+ ):
47
+ # optimalplacement: default is True, but for flat meshes it must be set to False to prevent spike artifacts.
48
+
49
+ _ori_vert_shape = verts.shape
50
+ _ori_face_shape = faces.shape
51
+
52
+ if backend == "pyfqmr":
53
+ import pyfqmr
54
+
55
+ solver = pyfqmr.Simplify()
56
+ solver.setMesh(verts, faces)
57
+ solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)
58
+ verts, faces, normals = solver.getMesh()
59
+ else:
60
+ m = pml.Mesh(verts, faces)
61
+ ms = pml.MeshSet()
62
+ ms.add_mesh(m, "mesh") # will copy!
63
+
64
+ # filters
65
+ # ms.meshing_decimation_clustering(threshold=pml.Percentage(1))
66
+ ms.meshing_decimation_quadric_edge_collapse(
67
+ targetfacenum=int(target), optimalplacement=optimalplacement
68
+ )
69
+
70
+ if remesh:
71
+ # ms.apply_coord_taubin_smoothing()
72
+ ms.meshing_isotropic_explicit_remeshing(
73
+ iterations=3, targetlen=pml.Percentage(1)
74
+ )
75
+
76
+ # extract mesh
77
+ m = ms.current_mesh()
78
+ verts = m.vertex_matrix()
79
+ faces = m.face_matrix()
80
+
81
+ print(
82
+ f"[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}"
83
+ )
84
+
85
+ return verts, faces
86
+
87
+
88
+ def clean_mesh(
89
+ verts,
90
+ faces,
91
+ v_pct=1,
92
+ min_f=64,
93
+ min_d=20,
94
+ repair=True,
95
+ remesh=True,
96
+ remesh_size=0.01,
97
+ ):
98
+ # verts: [N, 3]
99
+ # faces: [N, 3]
100
+
101
+ _ori_vert_shape = verts.shape
102
+ _ori_face_shape = faces.shape
103
+
104
+ m = pml.Mesh(verts, faces)
105
+ ms = pml.MeshSet()
106
+ ms.add_mesh(m, "mesh") # will copy!
107
+
108
+ # filters
109
+ ms.meshing_remove_unreferenced_vertices() # verts not referenced by any face
110
+
111
+ if v_pct > 0:
112
+ ms.meshing_merge_close_vertices(
113
+ threshold=pml.Percentage(v_pct)
114
+ ) # merge vertices closer than v_pct% of the bounding box diagonal
115
+
116
+ ms.meshing_remove_duplicate_faces() # faces defined by the same verts
117
+ ms.meshing_remove_null_faces() # faces with area == 0
118
+
119
+ if min_d > 0:
120
+ ms.meshing_remove_connected_component_by_diameter(
121
+ mincomponentdiag=pml.Percentage(min_d)
122
+ )
123
+
124
+ if min_f > 0:
125
+ ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)
126
+
127
+ if repair:
128
+ # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)
129
+ ms.meshing_repair_non_manifold_edges(method=0)
130
+ ms.meshing_repair_non_manifold_vertices(vertdispratio=0)
131
+
132
+ if remesh:
133
+ # ms.apply_coord_taubin_smoothing()
134
+ ms.meshing_isotropic_explicit_remeshing(
135
+ iterations=3, targetlen=pml.AbsoluteValue(remesh_size)
136
+ )
137
+
138
+ # extract mesh
139
+ m = ms.current_mesh()
140
+ verts = m.vertex_matrix()
141
+ faces = m.face_matrix()
142
+
143
+ print(
144
+ f"[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}"
145
+ )
146
+
147
+ return verts, faces
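A hedged usage sketch for the cleanup helpers above (illustrative only; the toy icosphere input and parameter values are assumptions):

import numpy as np
import trimesh

# Build a dense toy mesh, then clean and decimate it with the helpers above.
sphere = trimesh.creation.icosphere(subdivisions=5)
verts = np.asarray(sphere.vertices, dtype=np.float64)
faces = np.asarray(sphere.faces, dtype=np.int32)   # cast to int32 to be safe with pymeshlab
verts, faces = clean_mesh(verts, faces, remesh=True, remesh_size=0.02)
verts, faces = decimate_mesh(verts, faces, target=5000)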
sparseags/render_utils/gs_renderer.py ADDED
@@ -0,0 +1,1102 @@
1
+ import os
2
+ import math
3
+ import numpy as np
4
+ from typing import NamedTuple
5
+ from plyfile import PlyData, PlyElement
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from liegroups.torch import SE3
11
+ from simple_knn._C import distCUDA2
12
+
13
+ from sparseags.sh_utils import eval_sh, SH2RGB, RGB2SH
14
+ from sparseags.mesh_utils.mesh import Mesh
15
+ from sparseags.mesh_utils.mesh_utils import decimate_mesh, clean_mesh
16
+ from sparseags.cam_utils import sample_points_from_voxel
17
+
18
+ import kiui
19
+
20
+
21
+ def inverse_sigmoid(x):
22
+ return torch.log(x/(1-x))
23
+
24
+
25
+ def get_expon_lr_func(
26
+ lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
27
+ ):
28
+
29
+ def helper(step):
30
+ if lr_init == lr_final:
31
+ # constant lr, ignore other params
32
+ return lr_init
33
+ if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
34
+ # Disable this parameter
35
+ return 0.0
36
+ if lr_delay_steps > 0:
37
+ # A kind of reverse cosine decay.
38
+ delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
39
+ 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
40
+ )
41
+ else:
42
+ delay_rate = 1.0
43
+ t = np.clip(step / max_steps, 0, 1)
44
+ log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
45
+ return delay_rate * log_lerp
46
+
47
+ return helper
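An illustrative evaluation of the schedule above (the hyperparameters are arbitrary assumptions):

lr_fn = get_expon_lr_func(lr_init=1e-3, lr_final=1e-5, lr_delay_steps=100,
                          lr_delay_mult=0.1, max_steps=1000)
for step in (0, 100, 500, 1000):
    print(step, lr_fn(step))  # starts near 1e-4 (delayed), then log-lerps down to 1e-5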
48
+
49
+
50
+ def strip_lowerdiag(L):
51
+ uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
52
+
53
+ uncertainty[:, 0] = L[:, 0, 0]
54
+ uncertainty[:, 1] = L[:, 0, 1]
55
+ uncertainty[:, 2] = L[:, 0, 2]
56
+ uncertainty[:, 3] = L[:, 1, 1]
57
+ uncertainty[:, 4] = L[:, 1, 2]
58
+ uncertainty[:, 5] = L[:, 2, 2]
59
+ return uncertainty
60
+
61
+ def strip_symmetric(sym):
62
+ return strip_lowerdiag(sym)
63
+
64
+
65
+ def gaussian_3d_coeff(xyzs, covs):
66
+ # xyzs: [N, 3]
67
+ # covs: [N, 6]
68
+ x, y, z = xyzs[:, 0], xyzs[:, 1], xyzs[:, 2]
69
+ a, b, c, d, e, f = covs[:, 0], covs[:, 1], covs[:, 2], covs[:, 3], covs[:, 4], covs[:, 5]
70
+
71
+ # eps must be small enough !!!
72
+ inv_det = 1 / (a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24)
73
+ inv_a = (d * f - e**2) * inv_det
74
+ inv_b = (e * c - b * f) * inv_det
75
+ inv_c = (e * b - c * d) * inv_det
76
+ inv_d = (a * f - c**2) * inv_det
77
+ inv_e = (b * c - e * a) * inv_det
78
+ inv_f = (a * d - b**2) * inv_det
79
+
80
+ power = -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f) - x * y * inv_b - x * z * inv_c - y * z * inv_e
81
+
82
+ power[power > 0] = -1e10 # abnormal values... make weights 0
83
+
84
+ return torch.exp(power)
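An illustrative check of gaussian_3d_coeff above (the isotropic covariance is an assumption):

import torch

sigma = 0.1
offsets = torch.tensor([[0.0, 0.0, 0.0],
                        [0.3, 0.0, 0.0]])  # offsets from the gaussian center
covs = torch.tensor([[sigma**2, 0.0, 0.0, sigma**2, 0.0, sigma**2]]).repeat(2, 1)
print(gaussian_3d_coeff(offsets, covs))    # ~[1.0, exp(-4.5)]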
85
+
86
+
87
+ def build_rotation(r):
88
+ norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
89
+
90
+ q = r / norm[:, None]
91
+
92
+ R = torch.zeros((q.size(0), 3, 3), device='cuda')
93
+
94
+ r = q[:, 0]
95
+ x = q[:, 1]
96
+ y = q[:, 2]
97
+ z = q[:, 3]
98
+
99
+ R[:, 0, 0] = 1 - 2 * (y*y + z*z)
100
+ R[:, 0, 1] = 2 * (x*y - r*z)
101
+ R[:, 0, 2] = 2 * (x*z + r*y)
102
+ R[:, 1, 0] = 2 * (x*y + r*z)
103
+ R[:, 1, 1] = 1 - 2 * (x*x + z*z)
104
+ R[:, 1, 2] = 2 * (y*z - r*x)
105
+ R[:, 2, 0] = 2 * (x*z - r*y)
106
+ R[:, 2, 1] = 2 * (y*z + r*x)
107
+ R[:, 2, 2] = 1 - 2 * (x*x + y*y)
108
+ return R
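A quick sanity sketch for build_rotation above (the identity quaternion is an assumption; the function allocates on "cuda", so a GPU is required):

import torch

q = torch.tensor([[1.0, 0.0, 0.0, 0.0]], device='cuda')  # [w, x, y, z]
print(build_rotation(q))                                  # ~3x3 identity matrix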
109
+
110
+
111
+ def build_scaling_rotation(s, r):
112
+ L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
113
+ R = build_rotation(r)
114
+
115
+ L[:,0,0] = s[:,0]
116
+ L[:,1,1] = s[:,1]
117
+ L[:,2,2] = s[:,2]
118
+
119
+ L = R @ L
120
+ return L
121
+
122
+
123
+ class BasicPointCloud(NamedTuple):
124
+ points: np.array
125
+ colors: np.array
126
+ normals: np.array
127
+
128
+
129
+ class GaussianModel:
130
+
131
+ def setup_functions(self):
132
+ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
133
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
134
+ actual_covariance = L @ L.transpose(1, 2)
135
+ symm = strip_symmetric(actual_covariance)
136
+ return symm
137
+
138
+ self.scaling_activation = torch.exp
139
+ self.scaling_inverse_activation = torch.log
140
+
141
+ self.covariance_activation = build_covariance_from_scaling_rotation
142
+
143
+ self.opacity_activation = torch.sigmoid
144
+ self.inverse_opacity_activation = inverse_sigmoid
145
+
146
+ self.rotation_activation = torch.nn.functional.normalize
147
+
148
+
149
+ def __init__(self, sh_degree : int):
150
+ self.active_sh_degree = 0
151
+ self.max_sh_degree = sh_degree
152
+ self._xyz = torch.empty(0)
153
+ self._features_dc = torch.empty(0)
154
+ self._features_rest = torch.empty(0)
155
+ self._scaling = torch.empty(0)
156
+ self._rotation = torch.empty(0)
157
+ self._opacity = torch.empty(0)
158
+ self.max_radii2D = torch.empty(0)
159
+ self.xyz_gradient_accum = torch.empty(0)
160
+ self.denom = torch.empty(0)
161
+ self.optimizer = None
162
+ self.percent_dense = 0
163
+ self.spatial_lr_scale = 0
164
+ self.setup_functions()
165
+
166
+ def capture(self):
167
+ return (
168
+ self.active_sh_degree,
169
+ self._xyz,
170
+ self._features_dc,
171
+ self._features_rest,
172
+ self._scaling,
173
+ self._rotation,
174
+ self._opacity,
175
+ self.max_radii2D,
176
+ self.xyz_gradient_accum,
177
+ self.denom,
178
+ self.optimizer.state_dict(),
179
+ self.spatial_lr_scale,
180
+ )
181
+
182
+ def restore(self, model_args, training_args):
183
+ (self.active_sh_degree,
184
+ self._xyz,
185
+ self._features_dc,
186
+ self._features_rest,
187
+ self._scaling,
188
+ self._rotation,
189
+ self._opacity,
190
+ self.max_radii2D,
191
+ xyz_gradient_accum,
192
+ denom,
193
+ opt_dict,
194
+ self.spatial_lr_scale) = model_args
195
+ self.training_setup(training_args)
196
+ self.xyz_gradient_accum = xyz_gradient_accum
197
+ self.denom = denom
198
+ self.optimizer.load_state_dict(opt_dict)
199
+
200
+ @property
201
+ def get_scaling(self):
202
+ return self.scaling_activation(self._scaling)
203
+
204
+ @property
205
+ def get_rotation(self):
206
+ return self.rotation_activation(self._rotation)
207
+
208
+ @property
209
+ def get_xyz(self):
210
+ return self._xyz
211
+
212
+ @property
213
+ def get_features(self):
214
+ features_dc = self._features_dc
215
+ features_rest = self._features_rest
216
+ if self.enable_dino:
217
+ return torch.cat((features_dc, features_rest[..., :3]), dim=1), features_rest[..., 3:].reshape(features_rest.shape[0], 1, -1)[..., :self.dino_feat_dim]
218
+ else:
219
+ return torch.cat((features_dc, features_rest), dim=1)
220
+
221
+ @property
222
+ def get_opacity(self):
223
+ return self.opacity_activation(self._opacity)
224
+
225
+ @torch.no_grad()
226
+ def extract_fields(self, resolution=128, num_blocks=16, relax_ratio=1.5):
227
+ # resolution: resolution of field
228
+
229
+ block_size = 2 / num_blocks
230
+
231
+ assert resolution % block_size == 0
232
+ split_size = resolution // num_blocks
233
+
234
+ opacities = self.get_opacity
235
+
236
+ # pre-filter low opacity gaussians to save computation
237
+ mask = (opacities > 0.005).squeeze(1)
238
+
239
+ opacities = opacities[mask]
240
+ xyzs = self.get_xyz[mask]
241
+ stds = self.get_scaling[mask]
242
+
243
+ # normalize to ~ [-1, 1]
244
+ mn, mx = xyzs.amin(0), xyzs.amax(0)
245
+ self.center = (mn + mx) / 2
246
+ self.scale = 1.8 / (mx - mn).amax().item()
247
+
248
+ xyzs = (xyzs - self.center) * self.scale
249
+ stds = stds * self.scale
250
+
251
+ covs = self.covariance_activation(stds, 1, self._rotation[mask])
252
+
253
+ # tile
254
+ device = opacities.device
255
+ occ = torch.zeros([resolution] * 3, dtype=torch.float32, device=device)
256
+
257
+ X = torch.linspace(-1, 1, resolution).split(split_size)
258
+ Y = torch.linspace(-1, 1, resolution).split(split_size)
259
+ Z = torch.linspace(-1, 1, resolution).split(split_size)
260
+
261
+ # loop over blocks (assumes the max extent of a gaussian is smaller than relax_ratio * block_size !!!)
262
+ for xi, xs in enumerate(X):
263
+ for yi, ys in enumerate(Y):
264
+ for zi, zs in enumerate(Z):
265
+ xx, yy, zz = torch.meshgrid(xs, ys, zs)
266
+ # sample points [M, 3]
267
+ pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).to(device)
268
+ # in-tile gaussians mask
269
+ vmin, vmax = pts.amin(0), pts.amax(0)
270
+ vmin -= block_size * relax_ratio
271
+ vmax += block_size * relax_ratio
272
+ mask = (xyzs < vmax).all(-1) & (xyzs > vmin).all(-1)
273
+ # if hit no gaussian, continue to next block
274
+ if not mask.any():
275
+ continue
276
+ mask_xyzs = xyzs[mask] # [L, 3]
277
+ mask_covs = covs[mask] # [L, 6]
278
+ mask_opas = opacities[mask].view(1, -1) # [L, 1] --> [1, L]
279
+
280
+ # query per point-gaussian pair.
281
+ g_pts = pts.unsqueeze(1).repeat(1, mask_covs.shape[0], 1) - mask_xyzs.unsqueeze(0) # [M, L, 3]
282
+ g_covs = mask_covs.unsqueeze(0).repeat(pts.shape[0], 1, 1) # [M, L, 6]
283
+
284
+ # batch on gaussian to avoid OOM
285
+ batch_g = 1024
286
+ val = 0
287
+ for start in range(0, g_covs.shape[1], batch_g):
288
+ end = min(start + batch_g, g_covs.shape[1])
289
+ w = gaussian_3d_coeff(g_pts[:, start:end].reshape(-1, 3), g_covs[:, start:end].reshape(-1, 6)).reshape(pts.shape[0], -1) # [M, l]
290
+ val += (mask_opas[:, start:end] * w).sum(-1)
291
+
292
+ # kiui.lo(val, mask_opas, w)
293
+
294
+ occ[xi * split_size: xi * split_size + len(xs),
295
+ yi * split_size: yi * split_size + len(ys),
296
+ zi * split_size: zi * split_size + len(zs)] = val.reshape(len(xs), len(ys), len(zs))
297
+
298
+ kiui.lo(occ, verbose=1)
299
+
300
+ return occ
301
+
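+ # Note on the per-point density above (descriptive only): every surviving Gaussian contributes
+ # opacity_i * exp(-0.5 * d^T Sigma_i^{-1} d) at a query point offset d from its center, and
+ # gaussian_3d_coeff (assumed to be defined elsewhere in this file) evaluates that exponential
+ # from the 6 unique entries of the packed 3D covariance.
+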
302
+ def extract_mesh(self, path, density_thresh=1, resolution=128, decimate_target=1e5):
303
+
304
+ os.makedirs(os.path.dirname(path), exist_ok=True)
305
+
306
+ occ = self.extract_fields(resolution).detach().cpu().numpy()
307
+
308
+ import mcubes
309
+ vertices, triangles = mcubes.marching_cubes(occ, density_thresh)
310
+ vertices = vertices / (resolution - 1.0) * 2 - 1
311
+
312
+ # transform back to the original space
313
+ vertices = vertices / self.scale + self.center.detach().cpu().numpy()
314
+
315
+ vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.015)
316
+ if decimate_target > 0 and triangles.shape[0] > decimate_target:
317
+ vertices, triangles = decimate_mesh(vertices, triangles, decimate_target)
318
+
319
+ v = torch.from_numpy(vertices.astype(np.float32)).contiguous().cuda()
320
+ f = torch.from_numpy(triangles.astype(np.int32)).contiguous().cuda()
321
+
322
+ print(
323
+ f"[INFO] marching cubes result: {v.shape} ({v.min().item()}-{v.max().item()}), {f.shape}"
324
+ )
325
+
326
+ mesh = Mesh(v=v, f=f, device='cuda')
327
+
328
+ return mesh
329
+
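+ # Illustrative usage (a rough sketch; `gm` is an already-trained GaussianModel and the output
+ # path is a placeholder):
+ #   occ = gm.extract_fields(resolution=128)                    # [128, 128, 128] density grid
+ #   mesh = gm.extract_mesh("logs/mesh.obj", density_thresh=1)  # marching cubes + clean/decimate
+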
330
+ def get_covariance(self, scaling_modifier = 1):
331
+ return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
332
+
333
+ def oneupSHdegree(self):
334
+ if self.active_sh_degree < self.max_sh_degree:
335
+ self.active_sh_degree += 1
336
+
337
+ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float = 1):
338
+ self.spatial_lr_scale = spatial_lr_scale
339
+ fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
340
+ fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
341
+ features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
342
+ features[:, :3, 0] = fused_color
343
+ features[:, 3:, 1:] = 0.0
344
+
345
+ print("Number of points at initialisation : ", fused_point_cloud.shape[0])
346
+
347
+ dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
348
+ scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
349
+ rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
350
+ rots[:, 0] = 1
351
+
352
+ opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
353
+
354
+ self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
355
+ self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
356
+ if self.enable_dino:
357
+ # Override the original features with appended DINO semantic features
358
+ _features_rest = features[:,:,1:].transpose(1, 2).contiguous().cuda()
359
+ dim_rest = _features_rest.shape[1]
360
+ _semantic_features = torch.randn(self._xyz.shape[0], dim_rest, self.dino_feat_dim//dim_rest + 1).cuda()
361
+ self._features_rest = nn.Parameter(torch.cat([_features_rest, _semantic_features], dim=-1).requires_grad_(True))
362
+ else:
363
+ self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
364
+ self._scaling = nn.Parameter(scales.requires_grad_(True))
365
+ self._rotation = nn.Parameter(rots.requires_grad_(True))
366
+ self._opacity = nn.Parameter(opacities.requires_grad_(True))
367
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
368
+
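+ # Initialization summary for create_from_pcd (descriptive only): scales are stored in log space
+ # as log(sqrt(mean squared nearest-neighbor distance from distCUDA2)), so the exp activation
+ # recovers a per-point radius; rotations start at the identity quaternion (w=1); opacities start
+ # at inverse_sigmoid(0.1), i.e. roughly 0.1 after the sigmoid activation.
+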
369
+ def training_setup(self, training_args):
370
+ self.percent_dense = training_args.percent_dense
371
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
372
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
373
+
374
+ l = [
375
+ {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
376
+ {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
377
+ {'params': [self._features_rest], 'lr': training_args.feature_lr / 20, "name": "f_rest"}, # /20
378
+ {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
379
+ {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
380
+ {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"},
381
+ ]
382
+
383
+ if training_args.opt_cam:
384
+ l.append({'params': self.cam_params, 'lr': training_args.camera_lr, "name": "cam_params"})
385
+
386
+ self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
387
+ self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
388
+ lr_final=training_args.position_lr_final*self.spatial_lr_scale,
389
+ lr_delay_mult=training_args.position_lr_delay_mult,
390
+ max_steps=training_args.position_lr_max_steps)
391
+
392
+ def update_learning_rate(self, iteration):
393
+ ''' Learning rate scheduling per step '''
394
+ for param_group in self.optimizer.param_groups:
395
+ if param_group["name"] == "xyz":
396
+ if iteration > 500:
397
+ iteration = iteration % 500
398
+ lr = self.xyz_scheduler_args(iteration)
399
+ param_group['lr'] = lr
400
+
401
+ def construct_list_of_attributes(self):
402
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
403
+ # All channels except the 3 DC
404
+ for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
405
+ l.append('f_dc_{}'.format(i))
406
+ for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
407
+ l.append('f_rest_{}'.format(i))
408
+ l.append('opacity')
409
+ for i in range(self._scaling.shape[1]):
410
+ l.append('scale_{}'.format(i))
411
+ for i in range(self._rotation.shape[1]):
412
+ l.append('rot_{}'.format(i))
413
+ return l
414
+
415
+ def save_ply(self, path):
416
+ os.makedirs(os.path.dirname(path), exist_ok=True)
417
+
418
+ xyz = self._xyz.detach().cpu().numpy()
419
+ normals = np.zeros_like(xyz)
420
+ f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
421
+ f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
422
+ opacities = self._opacity.detach().cpu().numpy()
423
+ scale = self._scaling.detach().cpu().numpy()
424
+ rotation = self._rotation.detach().cpu().numpy()
425
+
426
+ dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
427
+
428
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
429
+ attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
430
+ elements[:] = list(map(tuple, attributes))
431
+ el = PlyElement.describe(elements, 'vertex')
432
+ PlyData([el]).write(path)
433
+
434
+ def reset_opacity(self):
435
+ opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
436
+ optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
437
+ self._opacity = optimizable_tensors["opacity"]
438
+
439
+ def load_ply(self, path):
440
+ plydata = PlyData.read(path)
441
+
442
+ xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
443
+ np.asarray(plydata.elements[0]["y"]),
444
+ np.asarray(plydata.elements[0]["z"])), axis=1)
445
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
446
+
447
+ print("Number of points at loading : ", xyz.shape[0])
448
+
449
+ features_dc = np.zeros((xyz.shape[0], 3, 1))
450
+ features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
451
+ features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
452
+ features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
453
+
454
+ extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
455
+ assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
456
+ features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
457
+ for idx, attr_name in enumerate(extra_f_names):
458
+ features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
459
+ # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
460
+ features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
461
+
462
+ scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
463
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
464
+ for idx, attr_name in enumerate(scale_names):
465
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
466
+
467
+ rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
468
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
469
+ for idx, attr_name in enumerate(rot_names):
470
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
471
+
472
+ self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
473
+ self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
474
+ self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
475
+ self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
476
+ self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
477
+ self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
478
+
479
+ self.active_sh_degree = self.max_sh_degree
480
+
481
+ def replace_tensor_to_optimizer(self, tensor, name):
482
+ optimizable_tensors = {}
483
+ for group in self.optimizer.param_groups:
484
+ if len(group["params"]) != 1:
485
+ continue
486
+ if group["name"] == name:
487
+ stored_state = self.optimizer.state.get(group['params'][0], None)
488
+ stored_state["exp_avg"] = torch.zeros_like(tensor)
489
+ stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
490
+
491
+ del self.optimizer.state[group['params'][0]]
492
+ group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
493
+ self.optimizer.state[group['params'][0]] = stored_state
494
+
495
+ optimizable_tensors[group["name"]] = group["params"][0]
496
+ return optimizable_tensors
497
+
498
+ def _prune_optimizer(self, mask):
499
+ optimizable_tensors = {}
500
+ for group in self.optimizer.param_groups:
501
+ if len(group["params"]) != 1:
502
+ continue
503
+
504
+ stored_state = self.optimizer.state.get(group['params'][0], None)
505
+ if stored_state is not None:
506
+ stored_state["exp_avg"] = stored_state["exp_avg"][mask]
507
+ stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
508
+
509
+ del self.optimizer.state[group['params'][0]]
510
+ group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
511
+ self.optimizer.state[group['params'][0]] = stored_state
512
+
513
+ optimizable_tensors[group["name"]] = group["params"][0]
514
+ else:
515
+ group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
516
+ optimizable_tensors[group["name"]] = group["params"][0]
517
+ return optimizable_tensors
518
+
519
+ def prune_points(self, mask):
520
+ valid_points_mask = ~mask
521
+ optimizable_tensors = self._prune_optimizer(valid_points_mask)
522
+
523
+ self._xyz = optimizable_tensors["xyz"]
524
+ self._features_dc = optimizable_tensors["f_dc"]
525
+ self._features_rest = optimizable_tensors["f_rest"]
526
+ self._opacity = optimizable_tensors["opacity"]
527
+ self._scaling = optimizable_tensors["scaling"]
528
+ self._rotation = optimizable_tensors["rotation"]
529
+
530
+ self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
531
+
532
+ self.denom = self.denom[valid_points_mask]
533
+ self.max_radii2D = self.max_radii2D[valid_points_mask]
534
+
535
+ def cat_tensors_to_optimizer(self, tensors_dict):
536
+ optimizable_tensors = {}
537
+ for group in self.optimizer.param_groups:
538
+ if len(group["params"]) != 1:
539
+ continue
540
+ assert len(group["params"]) == 1
541
+ extension_tensor = tensors_dict[group["name"]]
542
+ stored_state = self.optimizer.state.get(group['params'][0], None)
543
+ if stored_state is not None:
544
+
545
+ stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
546
+ stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
547
+
548
+ del self.optimizer.state[group['params'][0]]
549
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
550
+ self.optimizer.state[group['params'][0]] = stored_state
551
+
552
+ optimizable_tensors[group["name"]] = group["params"][0]
553
+ else:
554
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
555
+ optimizable_tensors[group["name"]] = group["params"][0]
556
+
557
+ return optimizable_tensors
558
+
559
+ def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation):
560
+ d = {"xyz": new_xyz,
561
+ "f_dc": new_features_dc,
562
+ "f_rest": new_features_rest,
563
+ "opacity": new_opacities,
564
+ "scaling" : new_scaling,
565
+ "rotation" : new_rotation}
566
+
567
+ optimizable_tensors = self.cat_tensors_to_optimizer(d)
568
+ self._xyz = optimizable_tensors["xyz"]
569
+ self._features_dc = optimizable_tensors["f_dc"]
570
+ self._features_rest = optimizable_tensors["f_rest"]
571
+ self._opacity = optimizable_tensors["opacity"]
572
+ self._scaling = optimizable_tensors["scaling"]
573
+ self._rotation = optimizable_tensors["rotation"]
574
+
575
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
576
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
577
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
578
+
579
+ def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
580
+ n_init_points = self.get_xyz.shape[0]
581
+ # Extract points that satisfy the gradient condition
582
+ padded_grad = torch.zeros((n_init_points), device="cuda")
583
+ padded_grad[:grads.shape[0]] = grads.squeeze()
584
+ selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
585
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
586
+ torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent
587
+ )
588
+
589
+ stds = self.get_scaling[selected_pts_mask].repeat(N,1)
590
+ means = torch.zeros((stds.size(0), 3), device="cuda")
591
+ samples = torch.normal(mean=means, std=stds)
592
+ rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
593
+ new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
594
+ new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
595
+ new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
596
+ new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
597
+ new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
598
+ new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
599
+
600
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
601
+
602
+ prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
603
+ self.prune_points(prune_filter)
604
+
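+ # densify_and_split, in brief: each selected (large, high-gradient) Gaussian spawns N=2 children
+ # sampled from its own anisotropic distribution (std = its scaling, rotated into world space),
+ # the children's scales are shrunk by 1/(0.8*N), all other attributes are copied, and the
+ # parents are removed via the appended boolean prune filter.
+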
605
+ def densify_and_clone(self, grads, grad_threshold, scene_extent):
606
+ # Extract points that satisfy the gradient condition
607
+ selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
608
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
609
+ torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent
610
+ )
611
+
612
+ new_xyz = self._xyz[selected_pts_mask]
613
+ new_features_dc = self._features_dc[selected_pts_mask]
614
+ new_features_rest = self._features_rest[selected_pts_mask]
615
+ new_opacities = self._opacity[selected_pts_mask]
616
+ new_scaling = self._scaling[selected_pts_mask]
617
+ new_rotation = self._rotation[selected_pts_mask]
618
+
619
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)
620
+
621
+ def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
622
+ grads = self.xyz_gradient_accum / self.denom
623
+ grads[grads.isnan()] = 0.0
624
+
625
+ self.densify_and_clone(grads, max_grad, extent)
626
+ self.densify_and_split(grads, max_grad, extent)
627
+
628
+ prune_mask = (self.get_opacity < min_opacity).squeeze()
629
+ if max_screen_size:
630
+ big_points_vs = self.max_radii2D > max_screen_size
631
+ big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
632
+ prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
633
+ self.prune_points(prune_mask)
634
+
635
+ torch.cuda.empty_cache()
636
+
637
+ def prune(self, min_opacity, extent, max_screen_size):
638
+
639
+ prune_mask = (self.get_opacity < min_opacity).squeeze()
640
+ if max_screen_size:
641
+ big_points_vs = self.max_radii2D > max_screen_size
642
+ big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
643
+ prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
644
+ self.prune_points(prune_mask)
645
+
646
+ torch.cuda.empty_cache()
647
+
648
+
649
+ def add_densification_stats(self, viewspace_point_tensor, update_filter):
650
+ self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
651
+ self.denom[update_filter] += 1
652
+
653
+
654
+ def getProjectionMatrix(znear, zfar, fx, fy, cx, cy):
655
+ # TODO: remove hard-coded image size
656
+ P = torch.zeros(4, 4)
657
+
658
+ z_sign = 1.0
659
+
660
+ P[0, 0] = 2 * fx / 256
661
+ P[1, 1] = 2 * fy / 256
662
+ P[0, 2] = 2 * (cx / 256) - 1
663
+ P[1, 2] = 2 * (cy / 256) - 1
664
+ P[2, 2] = z_sign * zfar / (zfar - znear)
665
+ P[3, 2] = z_sign
666
+ P[2, 3] = -(zfar * znear) / (zfar - znear)
667
+ return P
668
+
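+ # Sketch of the intrinsics-to-clip-space mapping above (assumes the hard-coded 256x256 image
+ # noted in the TODO): a camera-space point (x, y, z) projects to pixel u = fx * x / z + cx, which
+ # the matrix expresses in NDC via P[0, 0] = 2*fx/W and P[0, 2] = 2*cx/W - 1 (and likewise for v
+ # with fy, cy, H); the last two rows map depth in [znear, zfar] to [0, 1] after the perspective
+ # divide.
+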
669
+
670
+ def getProjectionMatrixFoV(znear, zfar, fovX, fovY):
671
+ tanHalfFovY = math.tan((fovY / 2))
672
+ tanHalfFovX = math.tan((fovX / 2))
673
+
674
+ P = torch.zeros(4, 4)
675
+
676
+ z_sign = 1.0
677
+
678
+ P[0, 0] = 1 / tanHalfFovX
679
+ P[1, 1] = 1 / tanHalfFovY
680
+ P[3, 2] = z_sign
681
+ P[2, 2] = z_sign * zfar / (zfar - znear)
682
+ P[2, 3] = -(zfar * znear) / (zfar - znear)
683
+ return P
684
+
685
+
686
+ class Camera:
687
+ def __init__(self, c2w, width, height, fx, fy, cx, cy, znear=0.01, zfar=100, opt_pose=False):
688
+ # c2w (pose) should be in NeRF convention.
689
+
690
+ self.image_width = width
691
+ self.image_height = height
692
+ self.fx, self.fy = fx, fy
693
+ self.cx, self.cy = cx, cy
694
+ self.FoVy = 2 * np.arctan(256 / 2 / self.fy)
695
+ self.FoVx = 2 * np.arctan(256 / 2 / self.fx)
696
+ self.znear = znear
697
+ self.zfar = zfar
698
+ self.opt_pose = opt_pose
699
+
700
+ self.projection_matrix = (
701
+ getProjectionMatrix(
702
+ znear=self.znear,
703
+ zfar=self.zfar,
704
+ fx=self.fx,
705
+ fy=self.fy,
706
+ cx=self.cx,
707
+ cy=self.cy,
708
+ )
709
+ .transpose(0, 1)
710
+ .cuda()
711
+ )
712
+
713
+ w2c = np.linalg.inv(c2w)
714
+
715
+ # OpenGL to OpenCV
716
+ w2c[1:3] *= -1
717
+
718
+ self.world_view_transform = torch.tensor(w2c).transpose(0, 1).cuda()
719
+ self.full_proj_transform = self.world_view_transform @ self.projection_matrix
720
+ self.camera_center = torch.tensor(c2w[:3, 3]).cuda()
721
+
722
+
723
+ class FoVCamera:
724
+ def __init__(self, c2w, width, height, fovy, fovx, znear, zfar, cam_params=None, opt_pose=False):
725
+ # c2w (pose) should be in NeRF convention.
726
+
727
+ self.image_width = width
728
+ self.image_height = height
729
+ self.FoVy = fovy
730
+ self.FoVx = fovx
731
+ self.znear = znear
732
+ self.zfar = zfar
733
+ self.opt_pose = opt_pose
734
+
735
+ self.projection_matrix = (
736
+ getProjectionMatrixFoV(
737
+ znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy
738
+ )
739
+ .transpose(0, 1)
740
+ .cuda()
741
+ )
742
+
743
+ w2c = np.linalg.inv(c2w)
744
+
745
+ # OpenGL to OpenCV
746
+ w2c[1:3] *= -1
747
+
748
+ self.world_view_transform = torch.tensor(w2c).transpose(0, 1).cuda()
749
+ self.full_proj_transform = self.world_view_transform @ self.projection_matrix
750
+ self.camera_center = torch.tensor(c2w[:3, 3]).cuda()
751
+
752
+
753
+ class CustomCamera:
754
+ def __init__(self, cam_params=None, index=None, c2w=None, opt_pose=False):
755
+ # TODO: remove hard-coded image size
756
+ # c2w (pose) should be in NeRF convention.
757
+ # This is the camera class that supports pose optimization.
758
+
759
+ self.image_width, self.image_height = 256, 256
760
+ self.fx, self.fy = cam_params["focal_length"]
761
+ self.cx, self.cy = cam_params["principal_point"]
762
+ self.FoVy = 2 * np.arctan(self.image_height / 2 / self.fy)
763
+ self.FoVx = 2 * np.arctan(self.image_width / 2 / self.fx)
764
+ self.R = torch.tensor(cam_params["R"])
765
+ self.T = torch.tensor(cam_params["T"])
766
+ self.znear = 0.01
767
+ self.zfar = 100
768
+ self.opt_pose = opt_pose
769
+ self.index = index
770
+
771
+ self.projection_matrix = (
772
+ getProjectionMatrix(
773
+ znear=self.znear,
774
+ zfar=self.zfar,
775
+ fx=self.fx,
776
+ fy=self.fy,
777
+ cx=self.cx,
778
+ cy=self.cy,
779
+ )
780
+ .transpose(0, 1)
781
+ .cuda()
782
+ )
783
+
784
+ if not opt_pose:
785
+ if c2w is not None:
786
+ w2c = torch.from_numpy(c2w)
787
+ w2c[1:3] *= -1 # OpenGL to OpenCV
788
+ else:
789
+ R = self.R.T # note the transpose here
790
+ T = self.T
791
+ upper = torch.cat([R, T[:, None]], dim=1) # Upper 3x4 part of the matrix
792
+ lower = torch.tensor([[0, 0, 0, 1]], device=R.device, dtype=R.dtype) # Last row
793
+ w2c = torch.cat([upper, lower], dim=0)
794
+
795
+ w2c[:2] *= -1 # PyTorch3D to OpenCV
796
+
797
+ self.w2c = w2c
798
+ self.cam_params = torch.zeros(6)
799
+ self.world_view_transform = w2c.transpose(0, 1).cuda()
800
+ self.full_proj_transform = self.world_view_transform @ self.projection_matrix
801
+ self.camera_center = self.world_view_transform.inverse()[3, :3]
802
+ else:
803
+ R = self.R.T # note the transpose here
804
+ T = self.T
805
+ upper = torch.cat([R, T[:, None]], dim=1) # Upper 3x4 part of the matrix
806
+ lower = torch.tensor([[0, 0, 0, 1]], device=R.device, dtype=R.dtype) # Last row
807
+ w2c = torch.cat([upper, lower], dim=0)
808
+
809
+ w2c[:2] *= -1 # PyTorch3D to OpenCV
810
+
811
+ self.w2c = w2c
812
+ self.cam_params = torch.randn(6) * 1e-6
813
+ self.cam_params.requires_grad_()
814
+
815
+ self.world_view_transform = w2c.transpose(0, 1).cuda()
816
+ self.full_proj_transform = self.world_view_transform @ self.projection_matrix
817
+ self.camera_center = self.world_view_transform.inverse()[3, :3]
818
+
819
+ @property
820
+ def perspective(self):
821
+ P = torch.zeros(4, 4)
822
+
823
+ z_sign = -1.0
824
+
825
+ P[0, 0] = 2 * self.fx / 256
826
+ P[1, 1] = -2 * self.fy / 256
827
+ P[0, 2] = -(2 * (self.cx / 256) - 1)
828
+ P[1, 2] = -(2 * (self.cy / 256) - 1)
829
+ P[2, 2] = z_sign * self.zfar / (self.zfar - self.znear)
830
+ P[3, 2] = z_sign
831
+ P[2, 3] = -(self.zfar * self.znear) / (self.zfar - self.znear)
832
+ return P.numpy()
833
+
834
+ @property
835
+ def c2w(self):
836
+ if self.opt_pose:
837
+ w2c = self.w2c @ SE3.exp(self.cam_params.detach()).as_matrix()
838
+ w2c[1:3] *= -1 # OpenCV to OpenGL
839
+ else:
840
+ R = self.R.T # note the transpose here
841
+ T = self.T
842
+ upper = torch.cat([R, T[:, None]], dim=1) # Upper 3x4 part of the matrix
843
+ lower = torch.tensor([[0, 0, 0, 1]], device=R.device, dtype=R.dtype) # Last row
844
+ w2c = torch.cat([upper, lower], dim=0)
845
+ w2c[:2, :] *= -1 # PyTorch3D to OpenCV
846
+ w2c[1:3, :] *= -1 # OpenCV to OpenGL
847
+
848
+ return torch.inverse(w2c).numpy()
849
+
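+ # Convention chain used by CustomCamera (descriptive only): the stored R, T define a
+ # PyTorch3D-style world-to-camera transform; flipping the first two rows converts it to OpenCV,
+ # and flipping rows 1:3 converts OpenCV to OpenGL/NeRF. With opt_pose enabled, the learnable
+ # 6-vector cam_params goes through liegroups' SE3.exp and is right-multiplied onto w2c as a
+ # residual, so near-zero parameters leave the initial pose essentially unchanged.
+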
850
+ @property
851
+ def focal_length(self):
852
+ return np.array([self.fx, self.fy])
853
+
854
+ @property
855
+ def rotation(self):
856
+ w2c = self.w2c @ SE3.exp(self.cam_params.detach()).as_matrix()
857
+ w2c[:2] *= -1
858
+ return w2c[:3, :3].T
859
+
860
+ @property
861
+ def translation(self):
862
+ w2c = self.w2c @ SE3.exp(self.cam_params.detach()).as_matrix()
863
+ w2c[:2] *= -1
864
+ return w2c[:3, 3]
865
+
866
+
867
+ class Renderer:
868
+ def __init__(self, sh_degree=3, white_background=True, radius=1):
869
+
870
+ self.sh_degree = sh_degree
871
+ self.white_background = white_background
872
+ self.radius = radius
873
+ self.enable_dino = None
874
+
875
+ self.gaussians = GaussianModel(sh_degree)
876
+
877
+ self.bg_color = torch.tensor(
878
+ [1, 1, 1] if white_background else [0, 0, 0],
879
+ dtype=torch.float32,
880
+ device="cuda",
881
+ )
882
+
883
+ def initialize(self, input=None, num_pts=5000, radius=0.5, cameras=None, imgs=None, masks=None, point_maps=None, mode='sphere'):
884
+ # load checkpoint
885
+ if input is None and mode in ['sphere', "carve", "inverse_carve"]:
886
+ # init from random point cloud
887
+
888
+ if mode == 'sphere':
889
+ phis = np.random.random((num_pts,)) * 2 * np.pi
890
+ costheta = np.random.random((num_pts,)) * 2 - 1
891
+ thetas = np.arccos(costheta)
892
+ mu = np.random.random((num_pts,))
893
+ radius = radius * np.cbrt(mu)
894
+ x = radius * np.sin(thetas) * np.cos(phis)
895
+ y = radius * np.sin(thetas) * np.sin(phis)
896
+ z = radius * np.cos(thetas)
897
+ xyz = np.stack((x, y, z), axis=1)
898
+
899
+ elif mode == "carve":
900
+ try:
901
+ xyz = sample_points_from_voxel(cameras, masks, radius, N=num_pts).cpu().numpy()
902
+ except RuntimeError:
903
+ radius = 0.3
904
+ phis = np.random.random((num_pts,)) * 2 * np.pi
905
+ costheta = np.random.random((num_pts,)) * 2 - 1
906
+ thetas = np.arccos(costheta)
907
+ mu = np.random.random((num_pts,))
908
+ radius = radius * np.cbrt(mu)
909
+ x = radius * np.sin(thetas) * np.cos(phis)
910
+ y = radius * np.sin(thetas) * np.sin(phis)
911
+ z = radius * np.cos(thetas)
912
+ xyz = np.stack((x, y, z), axis=1)
913
+
914
+ elif mode == "inverse_carve":
915
+ xyz = sample_points_from_voxel(cameras, masks, radius, N=num_pts, inverse=True).cpu().numpy()
916
+
917
+ shs = np.random.random((num_pts, 3)) / 255.0
918
+ pcd = BasicPointCloud(
919
+ points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))
920
+ )
921
+ self.gaussians.create_from_pcd(pcd, 10)
922
+
923
+ elif input is None and mode == "dust3r":
924
+ num_points = sum([np.count_nonzero(masks[i]) for i in range(len(masks))])
925
+ xyz = np.zeros((num_points, 3))
926
+ colors = np.zeros((num_points, 3))
927
+
928
+ # Iterate through data and add points to xyz and colors arrays
929
+ index = 0
930
+ for i in range(len(point_maps)):
931
+ rgb = imgs[i].reshape(-1, 3)
932
+ point_map = point_maps[i].reshape(-1, 3).detach().cpu().numpy()
933
+ for j, include_point in enumerate((masks[i] > 0.5).flatten()):
934
+ if include_point == 1:
935
+ xyz[index] = point_map[j]
936
+ colors[index] = rgb[j]
937
+ index += 1
938
+
939
+ # Check if index matches expected number of points
940
+ assert index == num_points, "Number of points does not match expected count"
941
+
942
+ pcd = BasicPointCloud(
943
+ points=xyz, colors=colors, normals=np.zeros((num_points, 3))
944
+ )
945
+ self.gaussians.create_from_pcd(pcd, 10)
946
+
947
+ elif isinstance(input, BasicPointCloud):
948
+ # load from a provided pcd
949
+ self.gaussians.create_from_pcd(input, 1)
950
+ else:
951
+ # load from saved ply
952
+ self.gaussians.load_ply(input)
953
+
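+ # Illustrative usage (a rough sketch; the cam_params dict layout must match what CustomCamera
+ # expects, and paths/arguments here are placeholders):
+ #   renderer = Renderer(sh_degree=3, white_background=True)
+ #   renderer.initialize(num_pts=5000, mode='sphere')   # random initialization inside a sphere
+ #   cam = CustomCamera(cam_params, index=0, opt_pose=False)
+ #   out = renderer.render(cam)
+ #   rgb, depth, alpha = out["image"], out["depth"], out["alpha"]
+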
954
+ def render(
955
+ self,
956
+ viewpoint_camera,
957
+ scaling_modifier=1.0,
958
+ bg_color=None,
959
+ override_color=None,
960
+ compute_cov3D_python=False,
961
+ convert_SHs_python=False,
962
+ ):
963
+ if self.enable_dino:
964
+ from diff_gaussian_rasterization_feature import (
965
+ GaussianRasterizationSettings,
966
+ GaussianRasterizer,
967
+ )
968
+ else:
969
+ from diff_gaussian_rasterization import (
970
+ GaussianRasterizationSettings,
971
+ GaussianRasterizer,
972
+ )
973
+
974
+ if viewpoint_camera.opt_pose:
975
+ w2c = viewpoint_camera.w2c @ SE3.exp(viewpoint_camera.cam_params).as_matrix()
976
+ w2c = w2c.to("cuda")
977
+
978
+ viewpoint_camera.world_view_transform = w2c.transpose(0, 1)
979
+ viewpoint_camera.full_proj_transform = viewpoint_camera.world_view_transform @ viewpoint_camera.projection_matrix
980
+ viewpoint_camera.camera_center = viewpoint_camera.world_view_transform.inverse()[3, :3]
981
+
982
+ # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
983
+ screenspace_points = (
984
+ torch.zeros_like(
985
+ self.gaussians.get_xyz,
986
+ dtype=self.gaussians.get_xyz.dtype,
987
+ requires_grad=True,
988
+ device="cuda",
989
+ )
990
+ + 0
991
+ )
992
+ try:
993
+ screenspace_points.retain_grad()
994
+ except:
995
+ pass
996
+
997
+ # Set up rasterization configuration
998
+ tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
999
+ tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
1000
+
1001
+ raster_settings = GaussianRasterizationSettings(
1002
+ image_height=int(viewpoint_camera.image_height),
1003
+ image_width=int(viewpoint_camera.image_width),
1004
+ tanfovx=tanfovx,
1005
+ tanfovy=tanfovy,
1006
+ bg=self.bg_color if bg_color is None else bg_color,
1007
+ scale_modifier=scaling_modifier,
1008
+ viewmatrix=viewpoint_camera.world_view_transform,
1009
+ perspectivematrix=viewpoint_camera.projection_matrix,
1010
+ projmatrix=viewpoint_camera.full_proj_transform,
1011
+ sh_degree=self.gaussians.active_sh_degree,
1012
+ campos=viewpoint_camera.camera_center,
1013
+ prefiltered=False,
1014
+ debug=False,
1015
+ )
1016
+
1017
+ rasterizer = GaussianRasterizer(raster_settings=raster_settings)
1018
+
1019
+ means3D = self.gaussians.get_xyz
1020
+ means2D = screenspace_points
1021
+ opacity = self.gaussians.get_opacity
1022
+
1023
+ # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
1024
+ # scaling / rotation by the rasterizer.
1025
+ scales = None
1026
+ rotations = None
1027
+ cov3D_precomp = None
1028
+ if compute_cov3D_python:
1029
+ cov3D_precomp = self.gaussians.get_covariance(scaling_modifier)
1030
+ else:
1031
+ scales = self.gaussians.get_scaling
1032
+ rotations = self.gaussians.get_rotation
1033
+
1034
+ # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
1035
+ # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
1036
+ shs = None
1037
+ colors_precomp = None
1038
+ if override_color is None:
1039
+ if convert_SHs_python:
1040
+ shs_view = self.gaussians.get_features.transpose(1, 2).view(
1041
+ -1, 3, (self.gaussians.max_sh_degree + 1) ** 2
1042
+ )
1043
+ dir_pp = self.gaussians.get_xyz - viewpoint_camera.camera_center.repeat(
1044
+ self.gaussians.get_features.shape[0], 1
1045
+ )
1046
+ dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
1047
+ sh2rgb = eval_sh(
1048
+ self.gaussians.active_sh_degree, shs_view, dir_pp_normalized
1049
+ )
1050
+ colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
1051
+ else:
1052
+ shs = self.gaussians.get_features
1053
+ else:
1054
+ colors_precomp = override_color
1055
+
1056
+ if self.enable_dino:
1057
+ shs, semantic_feature = shs
1058
+
1059
+ rendered_image, rendered_feature, radii, rendered_depth, rendered_alpha = rasterizer(
1060
+ means3D=means3D,
1061
+ means2D=means2D,
1062
+ shs=shs,
1063
+ semantic_feature=semantic_feature,
1064
+ colors_precomp=colors_precomp,
1065
+ opacities=opacity,
1066
+ scales=scales,
1067
+ rotations=rotations,
1068
+ cov3D_precomp=cov3D_precomp,
1069
+ viewmat=viewpoint_camera.world_view_transform,
1070
+ )
1071
+
1072
+ else:
1073
+ # Rasterize visible Gaussians to image, obtain their radii (on screen).
1074
+ rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
1075
+ means3D=means3D,
1076
+ means2D=means2D,
1077
+ shs=shs,
1078
+ colors_precomp=colors_precomp,
1079
+ opacities=opacity,
1080
+ scales=scales,
1081
+ rotations=rotations,
1082
+ cov3D_precomp=cov3D_precomp,
1083
+ viewmat=viewpoint_camera.world_view_transform,
1084
+ )
1085
+
1086
+ rendered_image = rendered_image.clamp(0, 1)
1087
+
1088
+ # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
1089
+ # They will be excluded from value updates used in the splitting criteria.
1090
+ ret = {
1091
+ "image": rendered_image,
1092
+ "depth": rendered_depth,
1093
+ "alpha": rendered_alpha,
1094
+ "viewspace_points": screenspace_points,
1095
+ "visibility_filter": radii > 0,
1096
+ "radii": radii,
1097
+ }
1098
+
1099
+ if self.enable_dino:
1100
+ ret["feature"] = rendered_feature
1101
+
1102
+ return ret
sparseags/render_utils/util.py ADDED
@@ -0,0 +1,510 @@
1
+ import os
2
+ import cv2
3
+ import gc
4
+ import copy
5
+ import tqdm
6
+ import torchvision
7
+ import shutil
8
+ import argparse
9
+ import numpy as np
10
+ from PIL import Image
11
+ from torchvision.utils import save_image
12
+ from omegaconf import OmegaConf
13
+ import matplotlib.pyplot as plt
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+
18
+ from kiui.lpips import LPIPS
19
+ from liegroups.torch import SE3
20
+
21
+ import sys
22
+ sys.path.append('./')
23
+
24
+ from sparseags.render_utils.gs_renderer import CustomCamera
25
+ from sparseags.mesh_utils.mesh_renderer import Renderer
26
+ from sparseags.cam_utils import OrbitCamera, mat2latlon
27
+
28
+
29
+ def safe_normalize(x):
30
+ return x / x.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-8)
31
+
32
+
33
+ def look_at(campos, target, opengl=True):
34
+ if not opengl:
35
+ forward_vector = safe_normalize(target - campos)
36
+ up_vector = torch.tensor([0, 1, 0], dtype=campos.dtype, device=campos.device).expand_as(forward_vector)
37
+ right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
38
+ up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))
39
+ else:
40
+ forward_vector = safe_normalize(campos - target)
41
+ up_vector = torch.tensor([0, 1, 0], dtype=campos.dtype, device=campos.device).expand_as(forward_vector)
42
+ right_vector = safe_normalize(torch.cross(up_vector, forward_vector, dim=-1))
43
+ up_vector = safe_normalize(torch.cross(forward_vector, right_vector, dim=-1))
44
+ R = torch.stack([right_vector, up_vector, forward_vector], dim=-1)
45
+ return R
46
+
47
+
48
+ def orbit_camera(elevation, azimuth, radius=1.0, is_degree=True, target=None, opengl=True):
49
+ """Converts elevation & azimuth to a batch of camera pose matrices."""
50
+ if is_degree:
51
+ elevation = torch.deg2rad(elevation)
52
+ azimuth = torch.deg2rad(azimuth)
53
+ x = radius * torch.cos(elevation) * torch.sin(azimuth)
54
+ y = -radius * torch.sin(elevation)
55
+ z = radius * torch.cos(elevation) * torch.cos(azimuth)
56
+ if target is None:
57
+ target = torch.zeros(3, dtype=torch.float32, device=elevation.device)
58
+ campos = torch.stack([x, y, z], dim=-1) + target
59
+ R = look_at(campos, target.unsqueeze(0).expand_as(campos), opengl)
60
+ T = torch.eye(4, dtype=torch.float32, device=elevation.device).unsqueeze(0).expand(campos.shape[0], -1, -1).clone()
61
+ T[:, :3, :3] = R
62
+ T[:, :3, 3] = campos
63
+ return T
64
+
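+ # Quick sanity example for orbit_camera (follows directly from the math above): with
+ # elevation=0, azimuth=0, radius=2 and the target at the origin, campos = (0, 0, 2) and the
+ # OpenGL branch of look_at gives right=(1, 0, 0), up=(0, 1, 0), forward=(0, 0, 1), i.e. a camera
+ # on the +z axis looking back at the origin with +y up.
+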
65
+
66
+ def render_and_compare(camera_data, mesh_path, save_path, num_views=8):
67
+ parser = argparse.ArgumentParser()
68
+ parser.add_argument('--object', type=str, help="path to mesh (obj, ply, glb, ...)")
69
+ parser.add_argument('--path', type=str, help="path to mesh (obj, ply, glb, ...)")
70
+ parser.add_argument('--front_dir', type=str, default='+z', help="mesh front-facing dir")
71
+ parser.add_argument('--mode', default='albedo', type=str, choices=['lambertian', 'albedo', 'normal', 'depth'], help="rendering mode")
72
+ parser.add_argument('--W', type=int, default=256, help="GUI width")
73
+ parser.add_argument('--H', type=int, default=256, help="GUI height")
74
+ parser.add_argument("--wogui", type=bool, default=True, help="disable all dpg GUI")
75
+ parser.add_argument("--force_cuda_rast", action='store_true', help="force to use RasterizeCudaContext.")
76
+ parser.add_argument("--config", default='configs/navi.yaml', help="path to the yaml config file")
77
+ parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center")
78
+ parser.add_argument('--fovy', type=float, default=49.1, help="default GUI camera fovy")
79
+ args, extras = parser.parse_known_args()
80
+
81
+ # override default config from cli
82
+ opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
83
+ data = camera_data
84
+
85
+ opt.mesh = mesh_path
86
+ opt.trainable_texture = False
87
+ renderer = Renderer(opt).to(torch.device("cuda"))
88
+ target = renderer.mesh.v.mean(dim=0)
89
+
90
+ cameras = [CustomCamera(cam_params) for cam_params in data.values()]
91
+ # cams = [(cam.c2w, cam.perspective, cam.focal_length) for cam in cameras]
92
+ img_paths = [v["filepath"] for k, v in data.items()]
93
+ flags = [int(v["flag"]) for k, v in data.items()]
94
+
95
+ cam_centers = [mat2latlon(cam.camera_center - target) for idx, cam in enumerate(cameras) if flags[idx]]
96
+ ref_polars = [float(cam[0]) for cam in cam_centers]
97
+ ref_azimuths = [float(cam[1]) for cam in cam_centers]
98
+ ref_radii = [float(cam[2]) for cam in cam_centers]
99
+
100
+ base_cam = copy.copy(cameras[0])
101
+ base_cam.fx = np.array([cam.fx for idx, cam in enumerate(cameras) if flags[idx]], dtype=np.float32).mean()
102
+ base_cam.fy = np.array([cam.fy for idx, cam in enumerate(cameras) if flags[idx]], dtype=np.float32).mean()
103
+ base_cam.cx = 128
104
+ base_cam.cy = 128
105
+
106
+ lpips_loss = LPIPS(net='vgg').cuda()
107
+ elevation_range = (max([min(ref_polars) - 20, -89.9]), min([max(ref_polars) + 20, 89.9]))
108
+ azimuth_range = (-180, 180)
109
+ radius_range = (min(ref_radii) - 0.2, max(ref_radii) + 0.2)
110
+
111
+ elevation_steps = torch.arange(elevation_range[0], elevation_range[1], 15, dtype=torch.float32)
112
+ azimuth_steps = torch.arange(azimuth_range[0], azimuth_range[1], 15, dtype=torch.float32)
113
+ radius_steps = torch.arange(radius_range[0], radius_range[1], 0.2, dtype=torch.float32)
114
+ elevation_grid, azimuth_grid, radius_grid = torch.meshgrid(elevation_steps, azimuth_steps, radius_steps, indexing='ij')
115
+ pose_grid = torch.stack((elevation_grid.flatten(), azimuth_grid.flatten(), radius_grid.flatten()), dim=1)
116
+
117
+ poses = orbit_camera(pose_grid[:, 0], pose_grid[:, 1], pose_grid[:, 2], target=target.cpu())
118
+ print("Number of hypotheses:", poses.shape[0])
119
+ s1_steps = 128
120
+ s2_steps = 256
121
+ beta = 0.25
122
+ chunk_size = 512
123
+
124
+ for i in tqdm.tqdm(range(num_views)):
125
+ if flags[i]:
126
+ continue
127
+
128
+ pose_grid = torch.stack((elevation_grid.flatten(), azimuth_grid.flatten(), radius_grid.flatten()), dim=1)
129
+
130
+ poses = orbit_camera(pose_grid[:, 0], pose_grid[:, 1], pose_grid[:, 2], target=target.cpu())
131
+
132
+ img_path = img_paths[i]
133
+ base_cam.fx = cameras[i].fx
134
+ base_cam.fy = cameras[i].fy
135
+ perspectives = torch.from_numpy(base_cam.perspective).expand(pose_grid.shape[0], -1, -1)
136
+
137
+ learnable_cam_params = torch.randn(pose_grid.shape[0], 6) * 1e-6
138
+ learnable_cam_params.requires_grad_()
139
+
140
+ loss_MSE_grid = np.zeros(pose_grid.shape[0])
141
+ loss_LPIPS_grid = np.zeros(pose_grid.shape[0])
142
+ loss = 0
143
+
144
+ gt_img = Image.open(img_path)
145
+ if gt_img.mode == 'RGBA':
146
+ gt_img = np.asarray(gt_img, dtype=np.uint8).copy()
147
+ gt_mask = (gt_img[..., 3:] > 128).astype(np.float32)
148
+ gt_img[gt_img[:, :, -1] <= 255*0.9] = [255., 255., 255., 255.] # thresholding background
149
+ gt_img = gt_img[:, :, :3]
150
+
151
+ gt_tensor = torch.from_numpy(gt_img).float().unsqueeze(0).cuda() / 255.
152
+ gt_mask_tensor = torch.from_numpy(gt_mask).float().unsqueeze(0).cuda()
153
+
154
+ num_batches = pose_grid.shape[0] // chunk_size + int(pose_grid.shape[0]%chunk_size > 0)
155
+
156
+ # Render images for visualization
157
+ vis_img = torch.zeros(pose_grid.shape[0], 256, 256, 3)
158
+ for j in tqdm.tqdm(range(num_batches)):
159
+ batch_poses = poses[j*chunk_size:(j+1)*chunk_size]
160
+ batch_perspectives = perspectives[j*chunk_size:(j+1)*chunk_size]
161
+ with torch.no_grad():
162
+ out = renderer.render_batch(batch_poses, batch_perspectives, 256, 256, ssaa=1) # (500, 256, 256, 3)
163
+ # batch_image = (out["image"].detach().cpu().numpy() * 255).astype(np.uint8)
164
+ batch_image = out["image"].detach().cpu()
165
+ vis_img[j*chunk_size:(j+1)*chunk_size] = batch_image
166
+
167
+ l = [{'params': learnable_cam_params, 'lr': 5e-3, "name": "cam_params"}]
168
+ optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
169
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
170
+
171
+ init_lr = optimizer.param_groups[0]['lr']
172
+ for j in tqdm.tqdm(range(num_batches)):
173
+ batch_poses = poses[j*chunk_size:(j+1)*chunk_size]
174
+ batch_perspectives = perspectives[j*chunk_size:(j+1)*chunk_size]
175
+ optimizer.param_groups[0]['lr'] = init_lr
176
+ for k in tqdm.tqdm(range(s1_steps)):
177
+ batch_residuals = SE3.exp(learnable_cam_params[j*chunk_size:(j+1)*chunk_size]).as_matrix() # [5760, 4, 4]
178
+ batch_poses_opt = torch.bmm(batch_poses, batch_residuals)
179
+ out = renderer.render_batch(batch_poses_opt, batch_perspectives, 256, 256, ssaa=1) # (500, 256, 256, 3)
180
+ pred_tensor = out["image"]
181
+ valid_mask = (out["alpha"] > 0) & (out["viewcos"] > 0.5) # (500, 256, 256, 1)
182
+
183
+ if k == s1_steps - 1:
184
+ loss = F.mse_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='none').mean(dim=(1, 2, 3))
185
+ loss_MSE_grid[j*chunk_size:(j+1)*chunk_size] = loss.detach().cpu().numpy()
186
+ loss = loss.mean()
187
+
188
+ else:
189
+ loss = F.mse_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='mean')
190
+
191
+ loss.backward()
192
+ optimizer.step()
193
+ optimizer.zero_grad()
194
+ scheduler.step()
195
+
196
+ # Render optimized images for visualization
197
+ # vis_img_optimized = torch.zeros(pose_grid.shape[0], 256, 256, 3)
198
+ # for j in tqdm.tqdm(range(num_batches)):
199
+ # batch_poses = poses[j*chunk_size:(j+1)*chunk_size]
200
+ # batch_perspectives = perspectives[j*chunk_size:(j+1)*chunk_size]
201
+ # batch_residuals = SE3.exp(learnable_cam_params[j*chunk_size:(j+1)*chunk_size]).as_matrix() # [5760, 4, 4]
202
+ # batch_poses_opt = torch.bmm(batch_poses, batch_residuals)
203
+ # with torch.no_grad():
204
+ # out = renderer.render_batch(batch_poses_opt, batch_perspectives, 256, 256, ssaa=1) # (500, 256, 256, 3)
205
+ # # batch_image = (out["image"].detach().cpu().numpy() * 255).astype(np.uint8)
206
+ # batch_image = out["image"].detach().cpu()
207
+ # vis_img_optimized[j*chunk_size:(j+1)*chunk_size] = batch_image
208
+
209
+ # indices = np.argsort(loss_MSE_grid)
210
+ # padding = (pose_grid.shape[0] // 10 + int(pose_grid.shape[0]%10 > 0)) * 10 - pose_grid.shape[0]
211
+ # grid = vis_img[indices].permute(0, 3, 1, 2).contiguous()
212
+ # padded_gird = torch.cat([grid, torch.ones(padding, 3, 256, 256)], dim=0)
213
+ # padded_gird = padded_gird.view((padding + pose_grid.shape[0]) // 10, 10, 3, 256, 256).permute(2, 0, 3, 1, 4)
214
+ # padded_gird = padded_gird.reshape(3, -1, 2560)
215
+ # output_path = os.path.join(save_path, f'vis1_candidates_{i}.png')
216
+ # save_image(padded_gird, output_path)
217
+
218
+ # grid = vis_img_optimized[indices].permute(0, 3, 1, 2).contiguous()
219
+ # padded_gird = torch.cat([grid, torch.ones(padding, 3, 256, 256)], dim=0)
220
+ # padded_gird = padded_gird.view((padding + pose_grid.shape[0]) // 10, 10, 3, 256, 256).permute(2, 0, 3, 1, 4)
221
+ # padded_gird = padded_gird.reshape(3, -1, 2560)
222
+ # output_path = os.path.join(save_path, f'vis1_optimized_candidates_{i}.png')
223
+ # save_image(padded_gird, output_path)
224
+
225
+ beta = 0.1
226
+ indices = np.argsort(loss_MSE_grid)[:max(int(loss_MSE_grid.shape[0] * beta), 64)]
227
+ batch_poses = poses[indices]
228
+ batch_residuals = SE3.exp(learnable_cam_params[indices].detach()).as_matrix() # [5760, 4, 4]
229
+ poses = torch.bmm(batch_poses, batch_residuals) # [216, 4, 4]
230
+ poses = poses.repeat(4, 1, 1)
231
+
232
+ learnable_cam_params = torch.randn(poses.shape[0], 6) * 1e-1
233
+ learnable_cam_params.requires_grad_()
234
+
235
+ optimizer.param_groups = []
236
+ optimizer.add_param_group({'params': learnable_cam_params})
237
+
238
+ perspectives = torch.from_numpy(cameras[i].perspective).expand(poses.shape[0], -1, -1)
239
+ loss_MSE_grid = np.zeros(poses.shape[0])
240
+
241
+ num_batches = poses.shape[0] // chunk_size + int(poses.shape[0]%chunk_size > 0)
242
+ for j in tqdm.tqdm(range(num_batches)):
243
+ batch_poses = poses[j*chunk_size:(j+1)*chunk_size]
244
+ batch_perspectives = perspectives[j*chunk_size:(j+1)*chunk_size]
245
+ optimizer.param_groups[0]['lr'] = 1e-3
246
+ for k in tqdm.tqdm(range(s2_steps)):
247
+ batch_residuals = SE3.exp(learnable_cam_params[j*chunk_size:(j+1)*chunk_size]).as_matrix() # [5760, 4, 4]
248
+ batch_poses_opt = torch.bmm(batch_poses, batch_residuals)
249
+ out = renderer.render_batch(batch_poses_opt, batch_perspectives, 256, 256, ssaa=1) # (500, 256, 256, 3)
250
+ pred_tensor = out["image"]
251
+ valid_mask = (out["alpha"] > 0) & (out["viewcos"] > 0.5) # (500, 256, 256, 1)
252
+ # batch_image = (out["image"].detach().cpu().numpy() * 255).astype(np.uint8)
253
+ # del batch_pose, batch_perspective
254
+
255
+ if k == s2_steps - 1:
256
+ loss = F.mse_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='none').mean(dim=(1, 2, 3))
257
+ # loss += F.mse_loss(valid_mask, gt_mask_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='none').mean(dim=(1, 2, 3))
258
+ loss_MSE_grid[j*chunk_size:(j+1)*chunk_size] = loss.detach().cpu().numpy()
259
+ loss = loss.mean()
260
+
261
+ else:
262
+ loss = F.mse_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='mean')
263
+
264
+ loss.backward()
265
+ optimizer.step()
266
+ optimizer.zero_grad()
267
+ scheduler.step()
268
+
269
+ beta = 0.1
270
+ indices = np.argsort(loss_MSE_grid)[:max(int(loss_MSE_grid.shape[0] * beta), 64)]
271
+ batch_poses = poses[indices]
272
+ batch_residuals = SE3.exp(learnable_cam_params[indices].detach()).as_matrix() # [5760, 4, 4]
273
+ poses = torch.bmm(batch_poses, batch_residuals) # [216, 4, 4]
274
+ poses = poses.repeat(4, 1, 1)
275
+
276
+ learnable_cam_params = torch.randn(poses.shape[0], 6) * 1e-2
277
+ learnable_cam_params.requires_grad_()
278
+
279
+ optimizer.param_groups = []
280
+ optimizer.add_param_group({'params': learnable_cam_params})
281
+
282
+ perspectives = torch.from_numpy(cameras[i].perspective).expand(poses.shape[0], -1, -1)
283
+ loss_MSE_grid = np.zeros(poses.shape[0])
284
+
285
+ num_batches = poses.shape[0] // chunk_size + int(poses.shape[0]%chunk_size > 0)
286
+ for j in tqdm.tqdm(range(num_batches)):
287
+ batch_poses = poses[j*chunk_size:(j+1)*chunk_size]
288
+ batch_perspectives = perspectives[j*chunk_size:(j+1)*chunk_size]
289
+ optimizer.param_groups[0]['lr'] = 1e-3
290
+ for k in tqdm.tqdm(range(s2_steps)):
291
+ batch_residuals = SE3.exp(learnable_cam_params[j*chunk_size:(j+1)*chunk_size]).as_matrix() # [5760, 4, 4]
292
+ batch_poses_opt = torch.bmm(batch_poses, batch_residuals)
293
+ out = renderer.render_batch(batch_poses_opt, batch_perspectives, 256, 256, ssaa=1) # (500, 256, 256, 3)
294
+ pred_tensor = out["image"]
295
+ valid_mask = (out["alpha"] > 0) & (out["viewcos"] > 0.5) # (500, 256, 256, 1)
296
+
297
+ if k == s2_steps - 1:
298
+ loss = F.mse_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='none').mean(dim=(1, 2, 3))
299
+ # loss += F.mse_loss(valid_mask, gt_mask_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='none').mean(dim=(1, 2, 3))
300
+ loss_MSE_grid[j*chunk_size:(j+1)*chunk_size] = loss.detach().cpu().numpy()
301
+ loss = loss.mean()
302
+
303
+ else:
304
+ loss = F.mse_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='mean')
305
+
306
+ loss.backward()
307
+ optimizer.step()
308
+ optimizer.zero_grad()
309
+ scheduler.step()
310
+
311
+ pose_grid = poses
312
+ loss_LPIPS_grid = np.zeros(poses.shape[0])
313
+
314
+ chunk_size = 64
315
+ gt_tensor = gt_tensor.permute(0, 3, 1, 2).contiguous()
316
+ vis_img_opt = np.zeros((pose_grid.shape[0], 256, 256, 3), dtype=np.uint8)
317
+ num_batches = pose_grid.shape[0] // chunk_size + int(pose_grid.shape[0]%chunk_size > 0)
318
+ for j in tqdm.tqdm(range(num_batches)):
319
+ batch_poses = poses[j*chunk_size:(j+1)*chunk_size]
320
+ batch_residuals = SE3.exp(learnable_cam_params[j*chunk_size:(j+1)*chunk_size]).as_matrix() # [5760, 4, 4]
321
+ batch_poses_opt = torch.bmm(batch_poses, batch_residuals)
322
+ batch_perspectives = perspectives[j*chunk_size:(j+1)*chunk_size]
323
+ with torch.no_grad():
324
+ out = renderer.render_batch(batch_poses_opt, batch_perspectives, 256, 256, ssaa=1) # (500, 256, 256, 3)
325
+ batch_image = (out["image"].detach().cpu().numpy() * 255).astype(np.uint8)
326
+ vis_img_opt[j*chunk_size:(j+1)*chunk_size] = batch_image
327
+
328
+ pred_tensor = out["image"].permute(0, 3, 1, 2).contiguous()
329
+ with torch.no_grad():
330
+ loss_LPIPS_grid[j*chunk_size:(j+1)*chunk_size] = lpips_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1)).squeeze().cpu().numpy()
331
+
332
+ # indices_of_smallest = np.argsort(loss_MSE_grid)[:15]
333
+ indices1 = np.argsort(loss_MSE_grid)
334
+ indices2 = np.argsort(loss_LPIPS_grid)
335
+
336
+ ranks1 = np.zeros_like(loss_MSE_grid)
337
+ ranks2 = np.zeros_like(loss_LPIPS_grid)
338
+
339
+ ranks1[indices1] = np.arange(1, loss_MSE_grid.size + 1)
340
+ ranks2[indices2] = np.arange(1, loss_LPIPS_grid.size + 1)
341
+
342
+ total_ranks = ranks1 + ranks2
343
+ indices_of_smallest = np.argsort(total_ranks)[:15]
344
+
345
+ index = indices_of_smallest[0]
346
+ residual = SE3.exp(learnable_cam_params[index].detach()).as_matrix() # [5760, 4, 4]
347
+ c2w = poses[index] @ residual
348
+ w2c = torch.inverse(c2w)
349
+
350
+ w2c[1:3, :] *= -1 # OpenCV to OpenGL
351
+ w2c[:2, :] *= -1 # PyTorch3D to OpenCV
352
+
353
+ data[list(data.keys())[i]]["R"] = w2c[:3, :3].T.tolist()
354
+ data[list(data.keys())[i]]["T"] = w2c[:3, 3].tolist()
355
+
356
+ num_frames = 16
357
+ cmap = plt.get_cmap("hot")
358
+ num_rows = 2
359
+ num_cols = 8
360
+ # plt.subplots_adjust(top=0.2)
361
+ figsize = (num_cols * 2, num_rows * 2.4)
362
+ fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize)
363
+ fig.suptitle("Input Image vs. Top 15 Similar Renderings", fontsize=16, y=0.93)
364
+ plt.subplots_adjust(top=0.9)
365
+ axs = axs.flatten()
366
+ for idx in range(num_rows * num_cols):
367
+ if idx < num_frames:
368
+ if idx == 0:
369
+ axs[idx].imshow(gt_img.reshape(256, 256, 3))
370
+ axs[idx].set_xlabel(f'Input Image', fontsize=10)
371
+ else:
372
+ axs[idx].imshow(vis_img_opt[indices_of_smallest[idx-1]].reshape(256, 256, 3))
373
+ loss_text = f"MSE: {loss_MSE_grid[indices_of_smallest[idx-1]]:.4f}_{int(ranks1[indices_of_smallest[idx-1]]):d}\nLPIPS: {loss_LPIPS_grid[indices_of_smallest[idx-1]]:.4f}_{int(ranks2[indices_of_smallest[idx-1]]):d}"
374
+ axs[idx].text(0.05, 0.95, loss_text, color='black', fontsize=8,
375
+ ha='left', va='top', transform=axs[idx].transAxes)
376
+ for s in ["bottom", "top", "left", "right"]:
377
+ if idx == 0:
378
+ axs[idx].spines[s].set_color("green")
379
+ else:
380
+ axs[idx].spines[s].set_color(cmap(0.8 * idx / (num_frames)))
381
+ axs[idx].spines[s].set_linewidth(5)
382
+ axs[idx].set_xticks([])
383
+ axs[idx].set_yticks([])
384
+
385
+ # if i >= args.all_views:
386
+ # axs[i].set_xlabel(f'MSE: {mse_losses[i%args.all_views]:.4f}\nLPIPS: {lpips_losses[i%args.all_views]:.4f}', fontsize=10)
387
+ else:
388
+ axs[idx].axis("off")
389
+ plt.tight_layout()
390
+
391
+ output_path = os.path.join(save_path, f'vis_{i}_render_and_compare.png')
392
+ plt.savefig(output_path) # Save the figure to a file
393
+ plt.close(fig)
394
+ print(f"Visualization file written to {output_path}")
395
+
396
+ del lpips_loss, renderer, learnable_cam_params
397
+ gc.collect()
398
+ torch.cuda.empty_cache()
399
+
400
+ return data
401
+
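+ # render_and_compare, in brief (descriptive summary): for every view whose flag is 0 it
+ # (1) builds a grid of camera hypotheses over elevation/azimuth/radius around the reliable
+ # views, (2) optimizes a per-hypothesis SE(3) residual against the input image with an MSE
+ # loss, (3) keeps the best ~10% (at least 64) of hypotheses, re-perturbs and refines them in two
+ # further rounds, and (4) ranks the survivors by the sum of their MSE and LPIPS ranks, writing
+ # the winner's R and T back into the camera dictionary.
+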
402
+
+ def align_to_mesh(camera_data, mesh_path, save_path, num_views=8):
+ """For each view with flag == 0 (an outlier camera), sample random SE(3) pose perturbations, refine each by render-and-compare MSE against the ground-truth image, and write the best pose back into the camera data."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--object', type=str, help="path to mesh (obj, ply, glb, ...)")
+ parser.add_argument('--path', type=str, help="path to mesh (obj, ply, glb, ...)")
+ parser.add_argument('--front_dir', type=str, default='+z', help="mesh front-facing dir")
+ parser.add_argument('--mode', default='albedo', type=str, choices=['lambertian', 'albedo', 'normal', 'depth'], help="rendering mode")
+ parser.add_argument('--W', type=int, default=256, help="GUI width")
+ parser.add_argument('--H', type=int, default=256, help="GUI height")
+ parser.add_argument("--wogui", type=bool, default=True, help="disable all dpg GUI")
+ parser.add_argument("--force_cuda_rast", action='store_true', help="force to use RasterizeCudaContext.")
+ parser.add_argument("--config", default='configs/navi.yaml', help="path to the yaml config file")
+ parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center")
+ parser.add_argument('--fovy', type=float, default=49.1, help="default GUI camera fovy")
+ args, extras = parser.parse_known_args()
+
+ # override default config from cli
+ opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
+ data = camera_data
+
+ opt.mesh = mesh_path
+ opt.trainable_texture = False
+ renderer = Renderer(opt).to(torch.device("cuda"))
+
+ cameras = [CustomCamera(cam_params) for cam_params in data.values()]
+ # cams = [(cam.c2w, cam.perspective, cam.focal_length) for cam in cameras]
+ img_paths = [v["filepath"] for k, v in data.items()]
+ flags = [int(v["flag"]) for k, v in data.items()]
+
+ s1_steps = 128
+ num_hypotheses = 64
+ chunk_size = 512
+ print("Number of hypotheses:", num_hypotheses)
+
+ for i in tqdm.tqdm(range(num_views)):
+ if flags[i]:
+ continue
+
+ loss_MSE_grid = np.zeros(num_hypotheses)
+ vis_img_opt = torch.zeros(num_hypotheses, 256, 256, 3)
+ poses = torch.from_numpy(cameras[i].c2w).expand(num_hypotheses, -1, -1)
+ perspectives = torch.from_numpy(cameras[i].perspective).expand(num_hypotheses, -1, -1)
+
+ learnable_cam_params = torch.randn(num_hypotheses, 6) * 1e-3
+ learnable_cam_params.requires_grad_()
+
+ img_path = img_paths[i]
+ gt_img = Image.open(img_path)
+ if gt_img.mode == 'RGBA':
+ gt_img = np.asarray(gt_img, dtype=np.uint8).copy()
+ gt_mask = (gt_img[..., 3:] > 128).astype(np.float32)
+ gt_img[gt_img[:, :, -1] <= 255*0.9] = [255., 255., 255., 255.] # thresholding background
+ gt_img = gt_img[:, :, :3]
+
+ gt_tensor = torch.from_numpy(gt_img).float().unsqueeze(0).cuda() / 255.
+ gt_mask_tensor = torch.from_numpy(gt_mask).float().unsqueeze(0).cuda()
+
+ num_batches = num_hypotheses // chunk_size + int(num_hypotheses%chunk_size > 0)
+
+ l = [{'params': learnable_cam_params, 'lr': 5e-3, "name": "cam_params"}]
+ optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
+
+ init_lr = optimizer.param_groups[0]['lr']
+ for j in tqdm.tqdm(range(num_batches)):
+ batch_poses = poses[j*chunk_size:(j+1)*chunk_size]
+ batch_perspectives = perspectives[j*chunk_size:(j+1)*chunk_size]
+ optimizer.param_groups[0]['lr'] = init_lr
+ for k in tqdm.tqdm(range(s1_steps)):
+ batch_residuals = SE3.exp(learnable_cam_params[j*chunk_size:(j+1)*chunk_size]).as_matrix() # [B, 4, 4]
+ batch_poses_opt = torch.bmm(batch_poses, batch_residuals)
+ out = renderer.render_batch(batch_poses_opt, batch_perspectives, 256, 256, ssaa=1) # [B, 256, 256, 3]
+ pred_tensor = out["image"]
+
+ if k == s1_steps - 1:
+ loss = F.mse_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='none').mean(dim=(1, 2, 3))
+ # loss += F.mse_loss(valid_mask, gt_mask_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='none').mean(dim=(1, 2, 3))
+ loss_MSE_grid[j*chunk_size:(j+1)*chunk_size] = loss.detach().cpu().numpy()
+ batch_image = pred_tensor.detach().cpu()
+ vis_img_opt[j*chunk_size:(j+1)*chunk_size] = batch_image
+ loss = loss.mean()
+
+ else:
+ loss = F.mse_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='mean')
+
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ scheduler.step()
+
+ indices = np.argsort(loss_MSE_grid)
+ residual = SE3.exp(learnable_cam_params[indices[0]].detach()).as_matrix() # [4, 4]
+ c2w = torch.from_numpy(cameras[i].c2w) @ residual
+ w2c = torch.inverse(c2w)
+
+ w2c[1:3, :] *= -1 # OpenCV to OpenGL
+ w2c[:2, :] *= -1 # PyTorch3D to OpenCV
+
+ data[list(data.keys())[i]]["R"] = w2c[:3, :3].T.tolist()
+ data[list(data.keys())[i]]["T"] = w2c[:3, 3].tolist()
+
+ grid = vis_img_opt[indices].permute(0, 3, 1, 2).contiguous()
+ grid = grid.view(8, 8, 3, 256, 256).permute(2, 0, 3, 1, 4)
+ grid = grid.reshape(3, -1, int(256*8))
+ output_path = os.path.join(save_path, f'vis_aligned_candidates_{i}.png')
+ save_image(grid, output_path)
+
+ return data
+
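For reference, the 8x8 candidate mosaic written out above could equally be assembled with torchvision's make_grid. A minimal sketch, assuming the default 64 hypotheses rendered at 256x256 and an illustrative output name:

import torch
from torchvision.utils import make_grid, save_image

candidates = torch.rand(64, 256, 256, 3)            # stand-in for vis_img_opt[indices]
mosaic = make_grid(candidates.permute(0, 3, 1, 2),  # -> [N, C, H, W]
                   nrow=8, padding=0)               # -> [3, 8 * 256, 8 * 256]
save_image(mosaic, "vis_aligned_candidates_example.png")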
sparseags/sh_utils.py ADDED
@@ -0,0 +1,118 @@
+ # Copyright 2021 The PlenOctree Authors.
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions are met:
+ #
+ # 1. Redistributions of source code must retain the above copyright notice,
+ # this list of conditions and the following disclaimer.
+ #
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
+ # this list of conditions and the following disclaimer in the documentation
+ # and/or other materials provided with the distribution.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ # POSSIBILITY OF SUCH DAMAGE.
+
+ import torch
+
+ C0 = 0.28209479177387814
+ C1 = 0.4886025119029199
+ C2 = [
+ 1.0925484305920792,
+ -1.0925484305920792,
+ 0.31539156525252005,
+ -1.0925484305920792,
+ 0.5462742152960396
+ ]
+ C3 = [
+ -0.5900435899266435,
+ 2.890611442640554,
+ -0.4570457994644658,
+ 0.3731763325901154,
+ -0.4570457994644658,
+ 1.445305721320277,
+ -0.5900435899266435
+ ]
+ C4 = [
+ 2.5033429417967046,
+ -1.7701307697799304,
+ 0.9461746957575601,
+ -0.6690465435572892,
+ 0.10578554691520431,
+ -0.6690465435572892,
+ 0.47308734787878004,
+ -1.7701307697799304,
+ 0.6258357354491761,
+ ]
+
+
+ def eval_sh(deg, sh, dirs):
+ """
+ Evaluate spherical harmonics at unit directions
+ using hardcoded SH polynomials.
+ Works with torch/np/jnp.
+ ... Can be 0 or more batch dimensions.
+ Args:
+ deg: int SH deg. Currently, 0-4 supported
+ sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
+ dirs: jnp.ndarray unit directions [..., 3]
+ Returns:
+ [..., C]
+ """
+ assert deg <= 4 and deg >= 0
+ coeff = (deg + 1) ** 2
+ assert sh.shape[-1] >= coeff
+
+ result = C0 * sh[..., 0]
+ if deg > 0:
+ x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
+ result = (result -
+ C1 * y * sh[..., 1] +
+ C1 * z * sh[..., 2] -
+ C1 * x * sh[..., 3])
+
+ if deg > 1:
+ xx, yy, zz = x * x, y * y, z * z
+ xy, yz, xz = x * y, y * z, x * z
+ result = (result +
+ C2[0] * xy * sh[..., 4] +
+ C2[1] * yz * sh[..., 5] +
+ C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
+ C2[3] * xz * sh[..., 7] +
+ C2[4] * (xx - yy) * sh[..., 8])
+
+ if deg > 2:
+ result = (result +
+ C3[0] * y * (3 * xx - yy) * sh[..., 9] +
+ C3[1] * xy * z * sh[..., 10] +
+ C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
+ C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
+ C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
+ C3[5] * z * (xx - yy) * sh[..., 14] +
+ C3[6] * x * (xx - 3 * yy) * sh[..., 15])
+
+ if deg > 3:
+ result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
+ C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
+ C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
+ C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
+ C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
+ C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
+ C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
+ C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
+ C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
+ return result
+
+ def RGB2SH(rgb):
+ return (rgb - 0.5) / C0
+
+ def SH2RGB(sh):
+ return sh * C0 + 0.5
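A minimal usage sketch for the helpers above (shapes follow the docstring; the tensor values here are placeholders):

import torch
from sparseags.sh_utils import eval_sh, RGB2SH

N, deg = 4, 2
sh = torch.zeros(N, 3, (deg + 1) ** 2)                # [..., C, (deg + 1) ** 2]
sh[..., 0] = RGB2SH(torch.rand(N, 3))                 # store a base color in the DC band
dirs = torch.nn.functional.normalize(torch.randn(N, 3), dim=-1)  # unit view directions
rgb = (eval_sh(deg, sh, dirs) + 0.5).clamp(0, 1)      # [N, 3]; the +0.5 undoes the RGB2SH offset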
sparseags/visual_utils.py ADDED
@@ -0,0 +1,243 @@
+ import os
+ import re
+ import cv2
+ import csv
+ import json
+ import math
+ import tqdm
+ import shutil
+ import argparse
+ import numpy as np
+ from PIL import Image
+ from omegaconf import OmegaConf
+ import matplotlib.pyplot as plt
+
+ import torch
+ import torch.nn.functional as F
+ import nvdiffrast.torch as dr
+
+ from kiui.mesh import Mesh
+ from kiui.cam import OrbitCamera
+ from kiui.op import safe_normalize
+ from kiui.lpips import LPIPS
+
+ import sys
+ from sparseags.mesh_utils.mesh_renderer import Renderer
+ from sparseags.cam_utils import orbit_camera, OrbitCamera
+ from sparseags.render_utils.gs_renderer import CustomCamera
+
+
+ class GUI:
+ """Minimal nvdiffrast-based mesh renderer/viewer (albedo, depth, normal, or lambertian shading); runs headless when wogui is set."""
+ def __init__(self, opt):
+ self.opt = opt
+ self.W = opt.W
+ self.H = opt.H
+ self.wogui = opt.wogui # disable gui and run in cmd
+ self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
+ self.bg_color = torch.ones(3, dtype=torch.float32).cuda() # default white bg
+ # self.bg_color = torch.zeros(3, dtype=torch.float32).cuda() # black bg
+
+ self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
+ self.need_update = True # camera moved, should reset accumulation
+ self.light_dir = np.array([0, 0])
+ self.ambient_ratio = 0.5
+
+ # auto-rotate
+ self.auto_rotate_cam = False
+ self.auto_rotate_light = False
+
+ self.mode = opt.mode
+ self.render_modes = ['albedo', 'depth', 'normal', 'lambertian']
+
+ # load mesh
+ self.mesh = Mesh.load(opt.mesh, front_dir=opt.front_dir)
+
+ if not opt.force_cuda_rast and (self.wogui or os.name == 'nt'):
+ self.glctx = dr.RasterizeGLContext()
+ else:
+ self.glctx = dr.RasterizeCudaContext()
+
+ def step(self):
+
+ if not self.need_update:
+ return
+
+ starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
+ starter.record()
+
+ # do MVP for vertices
+ pose = torch.from_numpy(self.cam.pose.astype(np.float32)).cuda()
+ proj = torch.from_numpy(self.cam.perspective.astype(np.float32)).cuda()
+
+ v_cam = torch.matmul(F.pad(self.mesh.v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
+ v_clip = v_cam @ proj.T
+
+ rast, rast_db = dr.rasterize(self.glctx, v_clip, self.mesh.f, (self.H, self.W))
+
+ alpha = (rast[..., 3:] > 0).float()
+ alpha = dr.antialias(alpha, rast, v_clip, self.mesh.f).squeeze(0).clamp(0, 1) # [H, W, 1]
+
+ if self.mode == 'depth':
+ depth, _ = dr.interpolate(-v_cam[..., [2]], rast, self.mesh.f) # [1, H, W, 1]
+ depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-20)
+ buffer = depth.squeeze(0).detach().cpu().numpy().repeat(3, -1) # [H, W, 3]
+ else:
+ # use vertex color if exists
+ if self.mesh.vc is not None:
+ albedo, _ = dr.interpolate(self.mesh.vc.unsqueeze(0).contiguous(), rast, self.mesh.f)
+ # use texture image
+ else:
+ texc, _ = dr.interpolate(self.mesh.vt.unsqueeze(0).contiguous(), rast, self.mesh.ft)
+ albedo = dr.texture(self.mesh.albedo.unsqueeze(0), texc, filter_mode='linear') # [1, H, W, 3]
+
+ albedo = torch.where(rast[..., 3:] > 0, albedo, torch.tensor(0).to(albedo.device)) # remove background
+ albedo = dr.antialias(albedo, rast, v_clip, self.mesh.f).clamp(0, 1) # [1, H, W, 3]
+ if self.mode == 'albedo':
+ albedo = albedo * alpha + self.bg_color * (1 - alpha)
+ buffer = albedo[0].detach().cpu().numpy()
+ else:
+ normal, _ = dr.interpolate(self.mesh.vn.unsqueeze(0).contiguous(), rast, self.mesh.fn)
+ normal = safe_normalize(normal)
+ if self.mode == 'normal':
+ normal_image = (normal[0] + 1) / 2
+ normal_image = torch.where(rast[..., 3:] > 0, normal_image, torch.tensor(1).to(normal_image.device)) # remove background
+ buffer = normal_image.detach().cpu().numpy()
+ elif self.mode == 'lambertian':
+ light_d = np.deg2rad(self.light_dir)
+ light_d = np.array([
+ np.cos(light_d[0]) * np.sin(light_d[1]),
+ -np.sin(light_d[0]),
+ np.cos(light_d[0]) * np.cos(light_d[1]),
+ ], dtype=np.float32)
+ light_d = torch.from_numpy(light_d).to(albedo.device)
+ lambertian = self.ambient_ratio + (1 - self.ambient_ratio) * (normal @ light_d).float().clamp(min=0)
+ albedo = (albedo * lambertian.unsqueeze(-1)) * alpha + self.bg_color * (1 - alpha)
+ buffer = albedo[0].detach().cpu().numpy()
+
+ ender.record()
+ torch.cuda.synchronize()
+ t = starter.elapsed_time(ender)
+
+ self.render_buffer = buffer
+ self.need_update = False
+
+ if self.auto_rotate_cam:
+ self.cam.orbit(5, 0)
+ self.need_update = True
+
+ if self.auto_rotate_light:
+ self.light_dir[1] += 3
+ self.need_update = True
+
+
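A minimal headless-use sketch of the GUI class above (the mesh path and option values are placeholders mirroring the argparse defaults used in this repo; a CUDA machine with nvdiffrast and kiui installed is assumed):

from omegaconf import OmegaConf
from sparseags.visual_utils import GUI

opt = OmegaConf.create(dict(
    mesh="path/to/mesh.obj",   # placeholder path to an existing mesh
    front_dir="+z", mode="albedo",
    W=256, H=256, wogui=True,
    force_cuda_rast=True,      # use the CUDA rasterizer; no OpenGL/EGL context needed
    radius=3, fovy=49.1,
))
gui = GUI(opt)
gui.step()                     # renders one frame into gui.render_buffer (H x W x 3 floats)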
+ def vis_output(camera_data, mesh_path=None, save_path=None, num_views=8):
+ """Render the mesh from every input camera, compare against the ground-truth images (MSE and LPIPS, normalized by object coverage), save a side-by-side visualization to save_path, and return the per-view LPIPS and MSE arrays."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--front_dir', type=str, default='+z', help="mesh front-facing dir")
+ parser.add_argument('--mode', default='albedo', type=str, choices=['lambertian', 'albedo', 'normal', 'depth'], help="rendering mode")
+ parser.add_argument('--W', type=int, default=256, help="GUI width")
+ parser.add_argument('--H', type=int, default=256, help="GUI height")
+ parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center")
+ parser.add_argument('--fovy', type=float, default=49.1, help="default GUI camera fovy")
+ parser.add_argument("--wogui", type=bool, default=True, help="disable all dpg GUI")
+ parser.add_argument("--force_cuda_rast", action='store_true', help="force to use RasterizeCudaContext.")
+ parser.add_argument('--elevation', type=int, default=0, help="rendering elevation")
+ parser.add_argument('--save_video', type=str, default=None, help="path to save rendered video")
+ parser.add_argument('--idx', type=int, default=0, help="index")
+ parser.add_argument('--config', default='configs/navi.yaml', type=str, help='path to the yaml config file')
+ args, extras = parser.parse_known_args()
+
+ # override default config from cli
+ opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
+ data = camera_data
+
+ cameras = [CustomCamera(cam_params) for cam_params in data.values()]
+ cams = [(cam.c2w, cam.perspective, cam.focal_length) for cam in cameras]
+ img_paths = [v["filepath"] for k, v in data.items()]
+
+ opt.mesh = mesh_path
+ opt.trainable_texture = False
+ renderer = Renderer(opt).to(torch.device("cuda"))
+
+ lpips_loss = LPIPS(net='vgg').cuda()
+ mse_losses = []
+ lpips_losses = []
+ flags = [int(v["flag"]) for k, v in data.items()]
+ images = np.zeros((2, num_views, 256, 256, 3), dtype=np.uint8)
+
+ for i in tqdm.tqdm(range(len(cams))):
+
+ img_path = img_paths[i]
+
+ img = Image.open(img_path)
+ if img.mode == 'RGBA':
+ img = np.asarray(img, dtype=np.uint8).copy()
+ img[img[:, :, -1] <= 255*0.9] = [255., 255., 255., 255.] # thresholding background
+ img = img[:, :, :3]
+
+ gt_tensor = torch.from_numpy(img).permute(2, 0, 1).float().unsqueeze(0).cuda() / 255.0
+
+ images[0, i] = img
+
+ with torch.no_grad():
+ out = renderer.render(*cams[i][:2], 256, 256, ssaa=1)
+
+ # rgb loss
+ image = (out["image"].detach().cpu().numpy() * 255).astype(np.uint8)
+ pred_tensor = out["image"].permute(2, 0, 1).float().unsqueeze(0).cuda()
+ # obj_scale = ((out["alpha"] > 0) & (out["viewcos"] > 0.5)).detach().sum().float()
+ obj_scale = (out["alpha"] > 0).detach().sum().float()
+ obj_scale /= 256 ** 2
+
+ images[1, i] = image
+ with torch.no_grad():
+ mse_losses.append(F.mse_loss(pred_tensor, gt_tensor).squeeze().cpu().numpy() / obj_scale.item())
+ lpips_losses.append(lpips_loss(pred_tensor, gt_tensor).squeeze().cpu().numpy() / obj_scale.item())
+
+ mean_mse = np.mean(np.array(mse_losses)[:num_views])
+ mean_lpips = np.mean(np.array(lpips_losses)[:num_views])
+
+ num_frames = 2 * num_views
+ cmap = plt.get_cmap("hsv")
+ num_rows = 2
+ num_cols = num_views
+ plt.subplots_adjust(top=0.2)
+ figsize = (num_cols * 2, num_rows * 2.2)
+ fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize)
+ fig.suptitle(f"Avg MSE: {mean_mse:.4f}, Avg LPIPS: {mean_lpips:.4f}", fontsize=16, y=0.97)
+ axs = axs.flatten()
+ for i in range(num_rows * num_cols):
+ if i < num_frames:
+ axs[i].imshow(images.reshape(-1, 256, 256, 3)[i])
+ for s in ["bottom", "top", "left", "right"]:
+ if i % num_views <= num_views - 1:
+ if not flags[i%num_views]:
+ axs[i].spines[s].set_color("red")
+ else:
+ axs[i].spines[s].set_color("green")
+ else:
+ axs[i].spines[s].set_color(cmap(i / (num_frames)))
+ axs[i].spines[s].set_linewidth(5)
+ axs[i].set_xticks([])
+ axs[i].set_yticks([])
+
+ if i >= num_views:
+ axs[i].set_xlabel(f'MSE: {mse_losses[i%num_views]:.4f}\nLPIPS: {lpips_losses[i%num_views]:.4f}', fontsize=10)
+ else:
+ axs[i].axis("off")
+ plt.tight_layout()
+ plt.savefig(save_path)
+ plt.close(fig)
+ print(f"Visualization file written to {save_path}")
+
+ out_dir = save_path.replace('vis.png', 'reprojections')
+ os.makedirs(out_dir, exist_ok=True)
+
+ for i in range(num_views):
+ gt = Image.fromarray(images[0, i])
+ pred = Image.fromarray(images[1, i])
+ gt.save(os.path.join(out_dir, f"gt_{i}.png"))
+ pred.save(os.path.join(out_dir, f"pred_{i}.png"))
+
+ return np.array(lpips_losses), np.array(mse_losses)
+
+
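The per-view numbers returned by vis_output are MSE/LPIPS divided by the fraction of pixels the rendered object covers, presumably so that views where the object is small are not scored as artificially good. A self-contained sketch of that normalization (the helper name is illustrative, not part of the repo):

import torch
import torch.nn.functional as F

def coverage_normalized_mse(pred, gt, alpha, res=256):
    # Fraction of the res x res frame covered by the rendered object
    obj_scale = (alpha > 0).float().sum() / res ** 2
    return (F.mse_loss(pred, gt) / obj_scale.clamp(min=1e-8)).item()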