init commit
- sparseags/cam_utils.py +472 -0
- sparseags/dust3r_utils.py +66 -0
- sparseags/guidance_utils/zero123.py +666 -0
- sparseags/guidance_utils/zero123_6d_utils.py +389 -0
- sparseags/main_stage1.py +669 -0
- sparseags/main_stage2.py +410 -0
- sparseags/mesh_utils/grid_put.py +301 -0
- sparseags/mesh_utils/mesh.py +638 -0
- sparseags/mesh_utils/mesh_renderer.py +268 -0
- sparseags/mesh_utils/mesh_utils.py +147 -0
- sparseags/render_utils/gs_renderer.py +1102 -0
- sparseags/render_utils/util.py +510 -0
- sparseags/sh_utils.py +118 -0
- sparseags/visual_utils.py +243 -0
sparseags/cam_utils.py
ADDED
@@ -0,0 +1,472 @@
import numpy as np
from scipy.spatial.transform import Rotation as R

import math
import torch
import torch.nn.functional as F
from pytorch3d.transforms import Rotate, Translate


def intersect_skew_line_groups(p, r, mask):
    # p, r both of shape (B, N, n_intersected_lines, 3)
    # mask of shape (B, N, n_intersected_lines)
    p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
    if p_intersect is None:
        return None, None, None, None
    _, p_line_intersect = point_line_distance(
        p, r, p_intersect[..., None, :].expand_as(p)
    )
    intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(
        dim=-1
    )
    return p_intersect, p_line_intersect, intersect_dist_squared, r


def intersect_skew_lines_high_dim(p, r, mask=None):
    # Implements https://en.wikipedia.org/wiki/Skew_lines in more than two dimensions
    dim = p.shape[-1]
    # make sure the heading vectors are l2-normed
    if mask is None:
        mask = torch.ones_like(p[..., 0])
    r = torch.nn.functional.normalize(r, dim=-1)

    eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
    I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
    sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)

    # I_eps = torch.zeros_like(I_min_cov.sum(dim=-3)) + 1e-10
    # p_intersect = torch.pinverse(I_min_cov.sum(dim=-3) + I_eps).matmul(sum_proj)[..., 0]
    p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]

    # I_min_cov.sum(dim=-3): torch.Size([1, 1, 3, 3])
    # sum_proj: torch.Size([1, 1, 3, 1])
    # p_intersect = np.linalg.lstsq(I_min_cov.sum(dim=-3).numpy(), sum_proj.numpy(), rcond=None)[0]

    if torch.any(torch.isnan(p_intersect)):
        print(p_intersect)
        return None, None
    return p_intersect, r


def point_line_distance(p1, r1, p2):
    df = p2 - p1
    proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
    line_pt_nearest = p2 - proj_vector
    d = (proj_vector).norm(dim=-1)
    return d, line_pt_nearest


def compute_optical_axis_intersection(cameras, in_ndc=True):
    centers = cameras.get_camera_center()
    principal_points = cameras.principal_point

    one_vec = torch.ones((len(cameras), 1), device=centers.device)
    optical_axis = torch.cat((principal_points, one_vec), -1)

    # optical_axis = torch.cat(
    #     (principal_points, cameras.focal_length[:, 0].unsqueeze(1)), -1
    # )

    pp = cameras.unproject_points(optical_axis, from_ndc=in_ndc, world_coordinates=True)
    pp2 = torch.diagonal(pp, dim1=0, dim2=1).T

    directions = pp2 - centers
    centers = centers.unsqueeze(0).unsqueeze(0)
    directions = directions.unsqueeze(0).unsqueeze(0)

    p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
        p=centers, r=directions, mask=None
    )

    if p_intersect is None:
        dist = None
    else:
        p_intersect = p_intersect.squeeze().unsqueeze(0)
        dist = (p_intersect - centers).norm(dim=-1)

    return p_intersect, dist, p_line_intersect, pp2, r


def normalize_cameras_with_up_axis(cameras, sequence_name, scale=1.0, in_ndc=True):
    """
    Normalizes cameras such that the optical axes point to the origin and the average
    distance to the origin is 1.

    Args:
        cameras (List[camera]).
    """

    # Let distance from first camera to origin be unit
    new_cameras = cameras.clone()
    new_transform = new_cameras.get_world_to_view_transform()

    p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
        cameras,
        in_ndc=in_ndc
    )
    t = Translate(p_intersect)

    # scale = dist.squeeze()[0]
    scale = dist.squeeze().mean()

    # Degenerate case
    if scale == 0:
        print(cameras.T)
        print(new_transform.get_matrix()[:, 3, :3])
        return -1
    assert scale != 0

    new_transform = t.compose(new_transform)
    new_cameras.R = new_transform.get_matrix()[:, :3, :3]
    new_cameras.T = new_transform.get_matrix()[:, 3, :3] / scale * 1.85

    needs_checking = False

    # ===== Rotation normalization
    # Estimate the world 'up' direction assuming that yaw is small
    # and running SVD on the x-vectors of the cameras
    x_vectors = new_cameras.R.transpose(1, 2)[:, 0, :].clone()
    x_vectors -= x_vectors.mean(dim=0, keepdim=True)
    U, S, Vh = torch.linalg.svd(x_vectors)
    V = Vh.mH
    # the vector with the smallest variation is the normal to
    # the plane of x-vectors (assume this to be the up direction)
    if S[0] / S[1] > S[1] / S[2]:
        print('Warning: unexpected singular values in sequence {}: {}'.format(sequence_name, S))
        needs_checking = True
        # return None, None, None, None, None
    estimated_world_up = V[:, 2:]
    # check all cameras have the same y-direction
    for camera_idx in range(len(new_cameras.T)):
        if torch.sign(torch.dot(estimated_world_up[:, 0],
                                new_cameras.R[0].transpose(0, 1)[1, :])) != torch.sign(torch.dot(estimated_world_up[:, 0],
                                new_cameras.R[camera_idx].transpose(0, 1)[1, :])):
            print("Some cameras appear to be flipped in sequence {}".format(sequence_name))
            needs_checking = True
            # return None, None, None, None, None
    flip = torch.sign(torch.dot(estimated_world_up[:, 0], new_cameras.R[0].transpose(0, 1)[1, :])) < 0
    if flip:
        estimated_world_up = V[:, 2:] * -1
    # build the target coordinate basis using the estimated world up
    target_coordinate_basis = torch.cat([V[:, :1],
                                         estimated_world_up,
                                         torch.linalg.cross(V[:, :1], estimated_world_up, dim=0)],
                                        dim=1)
    new_cameras.R = torch.matmul(target_coordinate_basis.T, new_cameras.R)
    return new_cameras, p_intersect, p_line_intersect, pp, r, needs_checking


def dot(x, y):
    if isinstance(x, np.ndarray):
        return np.sum(x * y, -1, keepdims=True)
    else:
        return torch.sum(x * y, -1, keepdim=True)


def length(x, eps=1e-20):
    if isinstance(x, np.ndarray):
        return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
    else:
        return torch.sqrt(torch.clamp(dot(x, x), min=eps))


def safe_normalize(x, eps=1e-20):
    return x / length(x, eps)


def look_at(campos, target, opengl=True):
    # campos: [N, 3], camera/eye position
    # target: [N, 3], object to look at
    # return: [N, 3, 3], rotation matrix
    if not opengl:
        # camera forward aligns with -z
        forward_vector = safe_normalize(target - campos)
        up_vector = np.array([0, 1, 0], dtype=np.float32)
        right_vector = safe_normalize(np.cross(forward_vector, up_vector))
        up_vector = safe_normalize(np.cross(right_vector, forward_vector))
    else:
        # camera forward aligns with +z
        forward_vector = safe_normalize(campos - target)
        up_vector = np.array([0, 1, 0], dtype=np.float32)
        right_vector = safe_normalize(np.cross(up_vector, forward_vector))
        up_vector = safe_normalize(np.cross(forward_vector, right_vector))
    R = np.stack([right_vector, up_vector, forward_vector], axis=1)
    return R


# elevation & azimuth to pose (cam2world) matrix
def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True):
    # radius: scalar
    # elevation: scalar, in (-90, 90), from +y to -y is (-90, 90)
    # azimuth: scalar, in (-180, 180), from +z to +x is (0, 90)
    # return: [4, 4], camera pose matrix
    if is_degree:
        elevation = np.deg2rad(elevation)
        azimuth = np.deg2rad(azimuth)
    x = radius * np.cos(elevation) * np.sin(azimuth)
    y = - radius * np.sin(elevation)
    z = radius * np.cos(elevation) * np.cos(azimuth)
    if target is None:
        target = np.zeros([3], dtype=np.float32)
    campos = np.array([x, y, z]) + target  # [3]
    T = np.eye(4, dtype=np.float32)
    T[:3, :3] = look_at(campos, target, opengl)
    T[:3, 3] = campos

    return T


def mat2latlon(T):
    if not isinstance(T, np.ndarray):
        xyz = T.cpu().detach().numpy()
    else:
        xyz = T.copy()
    r = np.linalg.norm(xyz)
    xyz = xyz / r
    theta = -np.arcsin(xyz[1])
    azi = np.arctan2(xyz[0], xyz[2])
    return np.rad2deg(theta), np.rad2deg(azi), r


def extract_camera_properties(camera_to_world_matrix):
    # Camera position is the translation part of the matrix
    camera_position = camera_to_world_matrix[:3, 3]

    # Extracting the forward direction vector (third column of rotation matrix)
    forward = camera_to_world_matrix[:3, 2]

    return camera_position, forward


def compute_angular_error_batch(rotation1, rotation2):
    R_rel = np.einsum("Bij,Bjk->Bik", rotation1.transpose(0, 2, 1), rotation2)
    t = (np.trace(R_rel, axis1=1, axis2=2) - 1) / 2
    theta = np.arccos(np.clip(t, -1, 1))
    return theta * 180 / np.pi


def find_mask_center_and_translate(image, mask):
    """
    Calculate the center of the mask and the translation needed to move
    the mask center to the image center.

    Args:
    - image (torch.Tensor): Input image tensor of shape (N, C, H, W)
    - mask (torch.Tensor): Mask tensor of shape (N, 1, H, W)

    Returns:
    - Translation (delta_x, delta_y) of the mask center relative to the image center
    """
    _, _, h, w = image.shape

    # Calculate the center of mass of the mask
    # Note: mask should be a binary mask of the same spatial dimensions as the image
    y_coords, x_coords = torch.meshgrid(torch.arange(0, h), torch.arange(0, w), indexing='ij')
    total_mass = mask.sum(dim=[2, 3], keepdim=True)
    x_center = (mask * x_coords.to(image.device)).sum(dim=[2, 3], keepdim=True) / total_mass
    y_center = (mask * y_coords.to(image.device)).sum(dim=[2, 3], keepdim=True) / total_mass

    # Calculate the translation needed to move the mask center to the image center
    image_center_x, image_center_y = w // 2, h // 2
    delta_x = x_center.squeeze() - image_center_x
    delta_y = y_center.squeeze() - image_center_y

    return torch.stack([delta_x, delta_y])


def create_voxel_grid(length, resolution=64):
    """
    Creates a voxel grid spanning [-length, length] along each axis.
    length: Half-extent of the cube covered by the grid.
    resolution: The number of divisions along each axis.
    Returns a 4D tensor of voxel-center positions in homogeneous coordinates.
    """
    x = torch.linspace(-length, length, resolution)
    y = torch.linspace(-length, length, resolution)
    z = torch.linspace(-length, length, resolution)

    xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij')
    voxels = torch.stack([xx, yy, zz, torch.ones_like(xx)], dim=-1)  # Homogeneous coordinates
    return voxels


def project_voxels_to_image(voxels, camera):
    """
    Projects voxel centers into the camera's image plane.
    voxels: 4D tensor of voxel grid in homogeneous coordinates.
    camera: Camera providing `projection_matrix` and `world_view_transform`.
    Returns a tensor of projected 2D points in image coordinates.
    """
    device = voxels.device

    # Flatten voxels to shape (N, 4) for matrix multiplication
    N = voxels.nelement() // 4  # Total number of voxels
    voxels_flat = voxels.reshape(-1, 4).t()  # Shape (4, N)

    # # Apply extrinsic parameters (rotation and translation)
    # transformed_voxels = R @ voxels_flat[:3, :] + t[:, None]

    # # Apply intrinsic parameters
    # projected_voxels = K @ transformed_voxels

    projected_voxels = camera.projection_matrix.transpose(0, 1) @ camera.world_view_transform.transpose(0, 1) @ voxels_flat

    # Convert from homogeneous coordinates to 2D
    projected_voxels_2d = (projected_voxels[:2, :] / projected_voxels[3, :]).t()
    # Reshape to grid dimensions and map NDC [-1, 1] to pixel coordinates [0, 255]
    projected_voxels_2d = (projected_voxels_2d.reshape(*voxels.shape[:-1], 2) + 1.) * 255 * 0.5

    return projected_voxels_2d


def carve_voxels(voxel_grid, projected_points, mask):
    """
    Updates the voxel grid based on the comparison with the mask.
    voxel_grid: 3D tensor representing the voxel grid.
    projected_points: Projected 2D points in image coordinates.
    mask: Binary mask image.
    """
    # Convert projected points to indices in the mask
    indices_x = torch.clamp(projected_points[..., 0], 0, mask.shape[1] - 1).long()
    indices_y = torch.clamp(projected_points[..., 1], 0, mask.shape[0] - 1).long()

    # Check if projected points are within the object in the mask
    in_object = mask[indices_y, indices_x]

    # Carve out voxels where the projection does not fall within the object
    voxel_grid[in_object == 0] = 0


def sample_points_from_voxel(cameras, masks, length=1, resolution=64, N=5000, inverse=False, device="cuda"):
    """
    Carve a voxel grid with the given masks and randomly sample N points
    from the remaining solid region.

    Args:
    - cameras: Cameras used to carve the grid.
    - masks (np.ndarray): Binary object masks, one per camera.
    - length (float): Half-extent of the voxel grid.
    - resolution (int): Number of divisions along each axis.
    - N (int): The number of points to sample.
    - inverse (bool): If True, sample from carved-out (empty) voxels instead.

    Returns:
    - sampled_points (torch.Tensor): A tensor of shape (N, 3) representing the sampled 3D coordinates.
    """
    voxel_grid = create_voxel_grid(length, resolution).to(device)
    voxel_grid_indicator = torch.ones(resolution, resolution, resolution)

    masks = torch.from_numpy(masks).to(device).squeeze()

    for idx, cam in enumerate(cameras):
        projected_points = project_voxels_to_image(voxel_grid, cam)
        carve_voxels(voxel_grid_indicator, projected_points, masks[idx])

    voxel_grid_indicator = voxel_grid_indicator.reshape(resolution, resolution, resolution)

    # Identify the indices of solid voxels
    if inverse:
        solid_indices = torch.nonzero(voxel_grid_indicator == 0)
    else:
        solid_indices = torch.nonzero(voxel_grid_indicator == 1)

    # Randomly select N indices from the solid indices
    if N <= solid_indices.size(0):
        # Randomly select N indices from the solid indices if there are enough solid voxels
        sampled_indices = solid_indices[torch.randperm(solid_indices.size(0))[:N]]
    else:
        # If there are not enough solid voxels, sample with replacement
        sampled_indices = solid_indices[torch.randint(0, solid_indices.size(0), (N,))]

    # Convert indices to coordinates
    # Note: the grid spans [-length, length] along each dimension.
    sampled_points = sampled_indices.float() / (voxel_grid.size(0) - 1) * 2 * length - length

    return sampled_points


class OrbitCamera:
    def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100):
        self.W = W
        self.H = H
        self.radius = r  # camera distance from center
        self.fovy = np.deg2rad(fovy)  # deg 2 rad
        self.near = near
        self.far = far
        self.center = np.array([0, 0, 0], dtype=np.float32)  # look at this point
        self.rot = R.from_matrix(np.eye(3))
        self.up = np.array([0, 1, 0], dtype=np.float32)  # need to be normalized!

    @property
    def fovx(self):
        return 2 * np.arctan(np.tan(self.fovy / 2) * self.W / self.H)

    @property
    def campos(self):
        return self.pose[:3, 3]

    # pose (c2w)
    @property
    def pose(self):
        # first move camera to radius
        res = np.eye(4, dtype=np.float32)
        res[2, 3] = self.radius  # opengl convention...
        # rotate
        rot = np.eye(4, dtype=np.float32)
        rot[:3, :3] = self.rot.as_matrix()
        res = rot @ res
        # translate
        res[:3, 3] -= self.center
        return res

    # view (w2c)
    @property
    def view(self):
        return np.linalg.inv(self.pose)

    # projection (perspective)
    @property
    def perspective(self):
        y = np.tan(self.fovy / 2)
        aspect = self.W / self.H
        return np.array(
            [
                [1 / (y * aspect), 0, 0, 0],
                [0, -1 / y, 0, 0],
                [
                    0,
                    0,
                    -(self.far + self.near) / (self.far - self.near),
                    -(2 * self.far * self.near) / (self.far - self.near),
                ],
                [0, 0, -1, 0],
            ],
            dtype=np.float32,
        )

    # intrinsics
    @property
    def intrinsics(self):
        focal = self.H / (2 * np.tan(self.fovy / 2))
        return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32)

    @property
    def mvp(self):
        return self.perspective @ np.linalg.inv(self.pose)  # [4, 4]

    def orbit(self, dx, dy):
        # rotate along camera up/side axis!
        side = self.rot.as_matrix()[:3, 0]
        rotvec_x = self.up * np.radians(-0.05 * dx)
        rotvec_y = side * np.radians(-0.05 * dy)
        self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot

    def scale(self, delta):
        self.radius *= 1.1 ** (-delta)

    def pan(self, dx, dy, dz=0):
        # pan in camera coordinate system (careful on the sensitivity!)
        self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([-dx, -dy, dz])
sparseags/dust3r_utils.py
ADDED
@@ -0,0 +1,66 @@
import torch
from pytorch3d.renderer import PerspectiveCameras

import sys
sys.path.append('./')
from sparseags.cam_utils import normalize_cameras_with_up_axis

sys.path[0] = sys.path[0] + '/dust3r'
from dust3r.inference import inference
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode


def infer_dust3r(dust3r_model, file_names, device='cuda'):
    batch_size = 1
    schedule = 'cosine'
    lr = 0.01
    niter = 300

    images = load_images(file_names, size=224)
    pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
    output = inference(pairs, dust3r_model, device, batch_size=batch_size)

    scene = global_aligner(output, optimize_pp=True, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
    loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)

    # retrieve useful values from scene:
    imgs = scene.imgs
    cams2world = scene.get_im_poses()
    w2c = torch.linalg.inv(cams2world)
    pps = scene.get_principal_points() * 256 / 224
    focals = scene.get_focals() * 256 / 224

    w2c[:, :2] *= -1  # OpenCV to PyTorch3D
    Rs = w2c[:, :3, :3].transpose(1, 2)
    Ts = w2c[:, :3, 3]

    cameras = PerspectiveCameras(
        focal_length=focals,
        principal_point=pps,
        in_ndc=False,
        R=Rs,
        T=Ts,
    )
    normalized_cameras, _, _, _, _, needs_checking = normalize_cameras_with_up_axis(cameras, None, in_ndc=False)

    if normalized_cameras is None:
        print("Something seems to be wrong with the estimated cameras...")
        return 0

    data = {}
    base_names = [file_name.split('/')[-1].split('.')[0] for file_name in file_names]
    file_names = [file_name.replace('source', 'processed').replace('.png', '_rgba.png') for file_name in file_names]

    for idx, base_name in enumerate(base_names):
        data[base_name] = {}
        data[base_name]["R"] = normalized_cameras.R[idx].cpu().tolist()
        data[base_name]["T"] = normalized_cameras.T[idx].cpu().tolist()
        data[base_name]["needs_checking"] = needs_checking
        data[base_name]["principal_point"] = normalized_cameras.principal_point[idx].cpu().tolist()
        data[base_name]["focal_length"] = normalized_cameras.focal_length[idx].cpu().tolist()
        data[base_name]["flag"] = 1
        data[base_name]["filepath"] = file_names[idx]

    return data
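A hedged usage sketch for `infer_dust3r`; the checkpoint path and file paths are assumptions (any local DUSt3R weights loadable through the bundled dust3r package should work, and the input paths should contain 'source' so the rgba-path substitution applies):

    import json
    from dust3r.model import AsymmetricCroCo3DStereo  # from the bundled dust3r package
    from sparseags.dust3r_utils import infer_dust3r

    # hypothetical checkpoint path -- substitute your local DUSt3R weights
    model = AsymmetricCroCo3DStereo.from_pretrained("checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt").to("cuda")
    files = ["data/name/source/00.png", "data/name/source/01.png", "data/name/source/02.png"]
    camera_data = infer_dust3r(model, files)
    with open("cameras.json", "w") as f:
        json.dump(camera_data, f, indent=2)  # per-view R, T, intrinsics, and flags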
sparseags/guidance_utils/zero123.py
ADDED
@@ -0,0 +1,666 @@
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
import math
|
17 |
+
import warnings
|
18 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
19 |
+
|
20 |
+
import PIL
|
21 |
+
import torch
|
22 |
+
import torchvision.transforms.functional as TF
|
23 |
+
from diffusers.configuration_utils import ConfigMixin, FrozenDict, register_to_config
|
24 |
+
from diffusers.image_processor import VaeImageProcessor
|
25 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
26 |
+
from diffusers.models.modeling_utils import ModelMixin
|
27 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
28 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
29 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
30 |
+
StableDiffusionSafetyChecker,
|
31 |
+
)
|
32 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
33 |
+
from diffusers.utils import deprecate, is_accelerate_available, logging
|
34 |
+
from diffusers.utils.torch_utils import randn_tensor
|
35 |
+
from packaging import version
|
36 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
37 |
+
|
38 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
39 |
+
|
40 |
+
|
41 |
+
class CLIPCameraProjection(ModelMixin, ConfigMixin):
|
42 |
+
"""
|
43 |
+
A Projection layer for CLIP embedding and camera embedding.
|
44 |
+
|
45 |
+
Parameters:
|
46 |
+
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `clip_embed`
|
47 |
+
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
|
48 |
+
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
|
49 |
+
additional_embeddings`.
|
50 |
+
"""
|
51 |
+
|
52 |
+
@register_to_config
|
53 |
+
def __init__(self, embedding_dim: int = 768, additional_embeddings: int = 4):
|
54 |
+
super().__init__()
|
55 |
+
self.embedding_dim = embedding_dim
|
56 |
+
self.additional_embeddings = additional_embeddings
|
57 |
+
|
58 |
+
self.input_dim = self.embedding_dim + self.additional_embeddings
|
59 |
+
self.output_dim = self.embedding_dim
|
60 |
+
|
61 |
+
self.proj = torch.nn.Linear(self.input_dim, self.output_dim)
|
62 |
+
|
63 |
+
def forward(
|
64 |
+
self,
|
65 |
+
embedding: torch.FloatTensor,
|
66 |
+
):
|
67 |
+
"""
|
68 |
+
The [`PriorTransformer`] forward method.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
hidden_states (`torch.FloatTensor` of shape `(batch_size, input_dim)`):
|
72 |
+
The currently input embeddings.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
The output embedding projection (`torch.FloatTensor` of shape `(batch_size, output_dim)`).
|
76 |
+
"""
|
77 |
+
proj_embedding = self.proj(embedding)
|
78 |
+
return proj_embedding
|
79 |
+
|
80 |
+
|
81 |
+
class Zero123Pipeline(DiffusionPipeline):
|
82 |
+
r"""
|
83 |
+
Pipeline to generate variations from an input image using Stable Diffusion.
|
84 |
+
|
85 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
86 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
87 |
+
|
88 |
+
Args:
|
89 |
+
vae ([`AutoencoderKL`]):
|
90 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
91 |
+
image_encoder ([`CLIPVisionModelWithProjection`]):
|
92 |
+
Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of
|
93 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
|
94 |
+
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
95 |
+
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
96 |
+
scheduler ([`SchedulerMixin`]):
|
97 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
98 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
99 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
100 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
101 |
+
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
102 |
+
feature_extractor ([`CLIPImageProcessor`]):
|
103 |
+
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
104 |
+
"""
|
105 |
+
# TODO: feature_extractor is required to encode images (if they are in PIL format),
|
106 |
+
# we should give a descriptive message if the pipeline doesn't have one.
|
107 |
+
_optional_components = ["safety_checker"]
|
108 |
+
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
vae: AutoencoderKL,
|
112 |
+
image_encoder: CLIPVisionModelWithProjection,
|
113 |
+
unet: UNet2DConditionModel,
|
114 |
+
scheduler: KarrasDiffusionSchedulers,
|
115 |
+
safety_checker: StableDiffusionSafetyChecker,
|
116 |
+
feature_extractor: CLIPImageProcessor,
|
117 |
+
clip_camera_projection: CLIPCameraProjection,
|
118 |
+
requires_safety_checker: bool = True,
|
119 |
+
):
|
120 |
+
super().__init__()
|
121 |
+
|
122 |
+
if safety_checker is None and requires_safety_checker:
|
123 |
+
logger.warn(
|
124 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
125 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
126 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
127 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
128 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
129 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
130 |
+
)
|
131 |
+
|
132 |
+
if safety_checker is not None and feature_extractor is None:
|
133 |
+
raise ValueError(
|
134 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
135 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
136 |
+
)
|
137 |
+
|
138 |
+
is_unet_version_less_0_9_0 = hasattr(
|
139 |
+
unet.config, "_diffusers_version"
|
140 |
+
) and version.parse(
|
141 |
+
version.parse(unet.config._diffusers_version).base_version
|
142 |
+
) < version.parse(
|
143 |
+
"0.9.0.dev0"
|
144 |
+
)
|
145 |
+
is_unet_sample_size_less_64 = (
|
146 |
+
hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
147 |
+
)
|
148 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
149 |
+
deprecation_message = (
|
150 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
151 |
+
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
152 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
153 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
154 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
155 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
156 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
157 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
158 |
+
" the `unet/config.json` file"
|
159 |
+
)
|
160 |
+
deprecate(
|
161 |
+
"sample_size<64", "1.0.0", deprecation_message, standard_warn=False
|
162 |
+
)
|
163 |
+
new_config = dict(unet.config)
|
164 |
+
new_config["sample_size"] = 64
|
165 |
+
unet._internal_dict = FrozenDict(new_config)
|
166 |
+
|
167 |
+
self.register_modules(
|
168 |
+
vae=vae,
|
169 |
+
image_encoder=image_encoder,
|
170 |
+
unet=unet,
|
171 |
+
scheduler=scheduler,
|
172 |
+
safety_checker=safety_checker,
|
173 |
+
feature_extractor=feature_extractor,
|
174 |
+
clip_camera_projection=clip_camera_projection,
|
175 |
+
)
|
176 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
177 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
178 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
179 |
+
|
180 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
181 |
+
r"""
|
182 |
+
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
183 |
+
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
184 |
+
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
185 |
+
"""
|
186 |
+
if is_accelerate_available():
|
187 |
+
from accelerate import cpu_offload
|
188 |
+
else:
|
189 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
190 |
+
|
191 |
+
device = torch.device(f"cuda:{gpu_id}")
|
192 |
+
|
193 |
+
for cpu_offloaded_model in [
|
194 |
+
self.unet,
|
195 |
+
self.image_encoder,
|
196 |
+
self.vae,
|
197 |
+
self.safety_checker,
|
198 |
+
]:
|
199 |
+
if cpu_offloaded_model is not None:
|
200 |
+
cpu_offload(cpu_offloaded_model, device)
|
201 |
+
|
202 |
+
@property
|
203 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
204 |
+
def _execution_device(self):
|
205 |
+
r"""
|
206 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
207 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
208 |
+
hooks.
|
209 |
+
"""
|
210 |
+
if not hasattr(self.unet, "_hf_hook"):
|
211 |
+
return self.device
|
212 |
+
for module in self.unet.modules():
|
213 |
+
if (
|
214 |
+
hasattr(module, "_hf_hook")
|
215 |
+
and hasattr(module._hf_hook, "execution_device")
|
216 |
+
and module._hf_hook.execution_device is not None
|
217 |
+
):
|
218 |
+
return torch.device(module._hf_hook.execution_device)
|
219 |
+
return self.device
|
220 |
+
|
221 |
+
def _encode_image(
|
222 |
+
self,
|
223 |
+
image,
|
224 |
+
elevation,
|
225 |
+
azimuth,
|
226 |
+
distance,
|
227 |
+
device,
|
228 |
+
num_images_per_prompt,
|
229 |
+
do_classifier_free_guidance,
|
230 |
+
clip_image_embeddings=None,
|
231 |
+
image_camera_embeddings=None,
|
232 |
+
):
|
233 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
234 |
+
|
235 |
+
if image_camera_embeddings is None:
|
236 |
+
if image is None:
|
237 |
+
assert clip_image_embeddings is not None
|
238 |
+
image_embeddings = clip_image_embeddings.to(device=device, dtype=dtype)
|
239 |
+
else:
|
240 |
+
if not isinstance(image, torch.Tensor):
|
241 |
+
image = self.feature_extractor(
|
242 |
+
images=image, return_tensors="pt"
|
243 |
+
).pixel_values
|
244 |
+
|
245 |
+
image = image.to(device=device, dtype=dtype)
|
246 |
+
image_embeddings = self.image_encoder(image).image_embeds
|
247 |
+
image_embeddings = image_embeddings.unsqueeze(1)
|
248 |
+
|
249 |
+
bs_embed, seq_len, _ = image_embeddings.shape
|
250 |
+
|
251 |
+
if isinstance(elevation, float):
|
252 |
+
elevation = torch.as_tensor(
|
253 |
+
[elevation] * bs_embed, dtype=dtype, device=device
|
254 |
+
)
|
255 |
+
if isinstance(azimuth, float):
|
256 |
+
azimuth = torch.as_tensor(
|
257 |
+
[azimuth] * bs_embed, dtype=dtype, device=device
|
258 |
+
)
|
259 |
+
if isinstance(distance, float):
|
260 |
+
distance = torch.as_tensor(
|
261 |
+
[distance] * bs_embed, dtype=dtype, device=device
|
262 |
+
)
|
263 |
+
|
264 |
+
camera_embeddings = torch.stack(
|
265 |
+
[
|
266 |
+
torch.deg2rad(elevation),
|
267 |
+
torch.sin(torch.deg2rad(azimuth)),
|
268 |
+
torch.cos(torch.deg2rad(azimuth)),
|
269 |
+
distance,
|
270 |
+
],
|
271 |
+
dim=-1,
|
272 |
+
)[:, None, :]
|
273 |
+
|
274 |
+
image_embeddings = torch.cat([image_embeddings, camera_embeddings], dim=-1)
|
275 |
+
|
276 |
+
# project (image, camera) embeddings to the same dimension as clip embeddings
|
277 |
+
image_embeddings = self.clip_camera_projection(image_embeddings)
|
278 |
+
else:
|
279 |
+
image_embeddings = image_camera_embeddings.to(device=device, dtype=dtype)
|
280 |
+
bs_embed, seq_len, _ = image_embeddings.shape
|
281 |
+
|
282 |
+
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
283 |
+
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
|
284 |
+
image_embeddings = image_embeddings.view(
|
285 |
+
bs_embed * num_images_per_prompt, seq_len, -1
|
286 |
+
)
|
287 |
+
|
288 |
+
if do_classifier_free_guidance:
|
289 |
+
negative_prompt_embeds = torch.zeros_like(image_embeddings)
|
290 |
+
|
291 |
+
# For classifier free guidance, we need to do two forward passes.
|
292 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
293 |
+
# to avoid doing two forward passes
|
294 |
+
image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
|
295 |
+
|
296 |
+
return image_embeddings
|
297 |
+
|
298 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
299 |
+
def run_safety_checker(self, image, device, dtype):
|
300 |
+
if self.safety_checker is None:
|
301 |
+
has_nsfw_concept = None
|
302 |
+
else:
|
303 |
+
if torch.is_tensor(image):
|
304 |
+
feature_extractor_input = self.image_processor.postprocess(
|
305 |
+
image, output_type="pil"
|
306 |
+
)
|
307 |
+
else:
|
308 |
+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
309 |
+
safety_checker_input = self.feature_extractor(
|
310 |
+
feature_extractor_input, return_tensors="pt"
|
311 |
+
).to(device)
|
312 |
+
image, has_nsfw_concept = self.safety_checker(
|
313 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
314 |
+
)
|
315 |
+
return image, has_nsfw_concept
|
316 |
+
|
317 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
318 |
+
def decode_latents(self, latents):
|
319 |
+
warnings.warn(
|
320 |
+
"The decode_latents method is deprecated and will be removed in a future version. Please"
|
321 |
+
" use VaeImageProcessor instead",
|
322 |
+
FutureWarning,
|
323 |
+
)
|
324 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
325 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
326 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
327 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
328 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
329 |
+
return image
|
330 |
+
|
331 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
332 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
333 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
334 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
335 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
336 |
+
# and should be between [0, 1]
|
337 |
+
|
338 |
+
accepts_eta = "eta" in set(
|
339 |
+
inspect.signature(self.scheduler.step).parameters.keys()
|
340 |
+
)
|
341 |
+
extra_step_kwargs = {}
|
342 |
+
if accepts_eta:
|
343 |
+
extra_step_kwargs["eta"] = eta
|
344 |
+
|
345 |
+
# check if the scheduler accepts generator
|
346 |
+
accepts_generator = "generator" in set(
|
347 |
+
inspect.signature(self.scheduler.step).parameters.keys()
|
348 |
+
)
|
349 |
+
if accepts_generator:
|
350 |
+
extra_step_kwargs["generator"] = generator
|
351 |
+
return extra_step_kwargs
|
352 |
+
|
353 |
+
def check_inputs(self, image, height, width, callback_steps):
|
354 |
+
# TODO: check image size or adjust image size to (height, width)
|
355 |
+
|
356 |
+
if height % 8 != 0 or width % 8 != 0:
|
357 |
+
raise ValueError(
|
358 |
+
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
|
359 |
+
)
|
360 |
+
|
361 |
+
if (callback_steps is None) or (
|
362 |
+
callback_steps is not None
|
363 |
+
and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
364 |
+
):
|
365 |
+
raise ValueError(
|
366 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
367 |
+
f" {type(callback_steps)}."
|
368 |
+
)
|
369 |
+
|
370 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
371 |
+
def prepare_latents(
|
372 |
+
self,
|
373 |
+
batch_size,
|
374 |
+
num_channels_latents,
|
375 |
+
height,
|
376 |
+
width,
|
377 |
+
dtype,
|
378 |
+
device,
|
379 |
+
generator,
|
380 |
+
latents=None,
|
381 |
+
):
|
382 |
+
shape = (
|
383 |
+
batch_size,
|
384 |
+
num_channels_latents,
|
385 |
+
height // self.vae_scale_factor,
|
386 |
+
width // self.vae_scale_factor,
|
387 |
+
)
|
388 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
389 |
+
raise ValueError(
|
390 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
391 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
392 |
+
)
|
393 |
+
|
394 |
+
if latents is None:
|
395 |
+
latents = randn_tensor(
|
396 |
+
shape, generator=generator, device=device, dtype=dtype
|
397 |
+
)
|
398 |
+
else:
|
399 |
+
latents = latents.to(device)
|
400 |
+
|
401 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
402 |
+
latents = latents * self.scheduler.init_noise_sigma
|
403 |
+
return latents
|
404 |
+
|
405 |
+
def _get_latent_model_input(
|
406 |
+
self,
|
407 |
+
latents: torch.FloatTensor,
|
408 |
+
image: Optional[
|
409 |
+
Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]
|
410 |
+
],
|
411 |
+
num_images_per_prompt: int,
|
412 |
+
do_classifier_free_guidance: bool,
|
413 |
+
image_latents: Optional[torch.FloatTensor] = None,
|
414 |
+
):
|
415 |
+
if isinstance(image, PIL.Image.Image):
|
416 |
+
image_pt = TF.to_tensor(image).unsqueeze(0).to(latents)
|
417 |
+
elif isinstance(image, list):
|
418 |
+
image_pt = torch.stack([TF.to_tensor(img) for img in image], dim=0).to(
|
419 |
+
latents
|
420 |
+
)
|
421 |
+
elif isinstance(image, torch.Tensor):
|
422 |
+
image_pt = image
|
423 |
+
else:
|
424 |
+
image_pt = None
|
425 |
+
|
426 |
+
if image_pt is None:
|
427 |
+
assert image_latents is not None
|
428 |
+
image_pt = image_latents.repeat_interleave(num_images_per_prompt, dim=0)
|
429 |
+
else:
|
430 |
+
image_pt = image_pt * 2.0 - 1.0 # scale to [-1, 1]
|
431 |
+
# FIXME: encoded latents should be multiplied with self.vae.config.scaling_factor
|
432 |
+
# but zero123 was not trained this way
|
433 |
+
image_pt = self.vae.encode(image_pt).latent_dist.mode()
|
434 |
+
image_pt = image_pt.repeat_interleave(num_images_per_prompt, dim=0)
|
435 |
+
if do_classifier_free_guidance:
|
436 |
+
latent_model_input = torch.cat(
|
437 |
+
[
|
438 |
+
torch.cat([latents, latents], dim=0),
|
439 |
+
torch.cat([torch.zeros_like(image_pt), image_pt], dim=0),
|
440 |
+
],
|
441 |
+
dim=1,
|
442 |
+
)
|
443 |
+
else:
|
444 |
+
latent_model_input = torch.cat([latents, image_pt], dim=1)
|
445 |
+
|
446 |
+
return latent_model_input
|
447 |
+
|
448 |
+
@torch.no_grad()
|
449 |
+
def __call__(
|
450 |
+
self,
|
451 |
+
image: Optional[
|
452 |
+
Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]
|
453 |
+
] = None,
|
454 |
+
elevation: Optional[Union[float, torch.FloatTensor]] = None,
|
455 |
+
azimuth: Optional[Union[float, torch.FloatTensor]] = None,
|
456 |
+
distance: Optional[Union[float, torch.FloatTensor]] = None,
|
457 |
+
height: Optional[int] = None,
|
458 |
+
width: Optional[int] = None,
|
459 |
+
num_inference_steps: int = 50,
|
460 |
+
guidance_scale: float = 3.0,
|
461 |
+
num_images_per_prompt: int = 1,
|
462 |
+
eta: float = 0.0,
|
463 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
464 |
+
latents: Optional[torch.FloatTensor] = None,
|
465 |
+
clip_image_embeddings: Optional[torch.FloatTensor] = None,
|
466 |
+
image_camera_embeddings: Optional[torch.FloatTensor] = None,
|
467 |
+
image_latents: Optional[torch.FloatTensor] = None,
|
468 |
+
output_type: Optional[str] = "pil",
|
469 |
+
return_dict: bool = True,
|
470 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
471 |
+
callback_steps: int = 1,
|
472 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
473 |
+
):
|
474 |
+
r"""
|
475 |
+
Function invoked when calling the pipeline for generation.
|
476 |
+
|
477 |
+
Args:
|
478 |
+
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
|
479 |
+
The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
|
480 |
+
configuration of
|
481 |
+
[this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
|
482 |
+
`CLIPImageProcessor`
|
483 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
484 |
+
The height in pixels of the generated image.
|
485 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
486 |
+
The width in pixels of the generated image.
|
487 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
488 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
489 |
+
expense of slower inference.
|
490 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
491 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
492 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
493 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
494 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
495 |
+
usually at the expense of lower image quality.
|
496 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
497 |
+
The number of images to generate per prompt.
|
498 |
+
eta (`float`, *optional*, defaults to 0.0):
|
499 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
500 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
501 |
+
generator (`torch.Generator`, *optional*):
|
502 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
503 |
+
to make generation deterministic.
|
504 |
+
latents (`torch.FloatTensor`, *optional*):
|
505 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
506 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
507 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
508 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
509 |
+
The output format of the generate image. Choose between
|
510 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
511 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
512 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
513 |
+
plain tuple.
|
514 |
+
callback (`Callable`, *optional*):
|
515 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
516 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
517 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
518 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
519 |
+
called at every step.
|
520 |
+
|
521 |
+
Returns:
|
522 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
523 |
+
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        # TODO: check input elevation, azimuth, and distance
        # TODO: check image, clip_image_embeddings, image_latents
        self.check_inputs(image, height, width, callback_steps)

        # 2. Define call parameters
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image, list):
            batch_size = len(image)
        elif isinstance(image, torch.Tensor):
            batch_size = image.shape[0]
        else:
            assert image_latents is not None
            assert (
                clip_image_embeddings is not None or image_camera_embeddings is not None
            )
            batch_size = image_latents.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input image
        if isinstance(image, PIL.Image.Image) or isinstance(image, list):
            pil_image = image
        elif isinstance(image, torch.Tensor):
            pil_image = [TF.to_pil_image(image[i]) for i in range(image.shape[0])]
        else:
            pil_image = None
        image_embeddings = self._encode_image(
            pil_image,
            elevation,
            azimuth,
            distance,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            clip_image_embeddings,
            image_camera_embeddings,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        # num_channels_latents = self.unet.config.in_channels
        num_channels_latents = 4  # FIXME: hard-coded
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier-free guidance
                latent_model_input = self._get_latent_model_input(
                    latents,
                    image,
                    num_images_per_prompt,
                    do_classifier_free_guidance,
                    image_latents,
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_embeddings,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs
                ).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor, return_dict=False
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, image_embeddings.dtype
            )
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize
        )

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )
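A minimal usage sketch for the pipeline above (illustrative, not part of the commit). It assumes the `__call__` arguments seen in the body (a conditioning `image` plus `elevation`/`azimuth`/`distance`; their exact types are a guess) and that the "ashawkey/zero123-xl-diffusers" weights referenced later in this commit are reachable:

import torch
from PIL import Image
from sparseags.guidance_utils.zero123 import Zero123Pipeline

pipe = Zero123Pipeline.from_pretrained(
    "ashawkey/zero123-xl-diffusers",
    torch_dtype=torch.float16,
    trust_remote_code=True,
).to("cuda")

cond = Image.open("input.png")  # hypothetical conditioning view
out = pipe(
    image=cond,
    elevation=[30.0],   # relative camera offsets; list-of-float is an assumption
    azimuth=[45.0],
    distance=[0.0],
    num_inference_steps=50,
    guidance_scale=3.0,
)
out.images[0].save("novel_view.png")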
sparseags/guidance_utils/zero123_6d_utils.py
ADDED
@@ -0,0 +1,389 @@
from diffusers import DDIMScheduler
import torchvision.transforms.functional as TF

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
from torchvision.utils import save_image
from torchvision import transforms
import torch.nn.functional as F
from einops import rearrange

import sys
sys.path.append('./')

from sparseags.guidance_utils.zero123 import Zero123Pipeline


# maps LDM-style UNet attention keys to their diffusers equivalents
name_mapping = {
    "model.diffusion_model.input_blocks.1.1.": "down_blocks.0.attentions.0.",
    "model.diffusion_model.input_blocks.2.1.": "down_blocks.0.attentions.1.",
    "model.diffusion_model.input_blocks.4.1.": "down_blocks.1.attentions.0.",
    "model.diffusion_model.input_blocks.5.1.": "down_blocks.1.attentions.1.",
    "model.diffusion_model.input_blocks.7.1.": "down_blocks.2.attentions.0.",
    "model.diffusion_model.input_blocks.8.1.": "down_blocks.2.attentions.1.",
    "model.diffusion_model.middle_block.1.": "mid_block.attentions.0.",
    "model.diffusion_model.output_blocks.3.1.": "up_blocks.1.attentions.0.",
    "model.diffusion_model.output_blocks.4.1.": "up_blocks.1.attentions.1.",
    "model.diffusion_model.output_blocks.5.1.": "up_blocks.1.attentions.2.",
    "model.diffusion_model.output_blocks.6.1.": "up_blocks.2.attentions.0.",
    "model.diffusion_model.output_blocks.7.1.": "up_blocks.2.attentions.1.",
    "model.diffusion_model.output_blocks.8.1.": "up_blocks.2.attentions.2.",
    "model.diffusion_model.output_blocks.9.1.": "up_blocks.3.attentions.0.",
    "model.diffusion_model.output_blocks.10.1.": "up_blocks.3.attentions.1.",
    "model.diffusion_model.output_blocks.11.1.": "up_blocks.3.attentions.2.",
}


class Zero123(nn.Module):
    def __init__(self, device, fp16=True, t_range=[0.02, 0.98], model_key="ashawkey/zero123-xl-diffusers"):
        super().__init__()

        self.device = device
        self.fp16 = fp16
        self.dtype = torch.float16 if fp16 else torch.float32

        self.pipe = Zero123Pipeline.from_pretrained(
            model_key,
            trust_remote_code=True,
            torch_dtype=self.dtype,
        ).to(self.device)

        # load weights from the checkpoint
        ckpt_path = "checkpoints/zero123_6dof_23k.ckpt"
        print(f'[INFO] loading checkpoint from {ckpt_path} ...')
        old_state = torch.load(ckpt_path)
        pretrained_weights = old_state['state_dict']['cc_projection.weight']
        pretrained_biases = old_state['state_dict']['cc_projection.bias']
        linear_layer = torch.nn.Linear(768 + 18, 768)
        linear_layer.weight.data = pretrained_weights
        linear_layer.bias.data = pretrained_biases
        self.pipe.clip_camera_projection.proj = linear_layer.to(dtype=self.dtype, device=self.device)

        for name in list(old_state['state_dict'].keys()):
            for k, v in name_mapping.items():
                if k in name:
                    old_state['state_dict'][name.replace(k, v)] = old_state['state_dict'][name].to(dtype=self.dtype, device=self.device)

        m, u = self.pipe.unet.load_state_dict(old_state['state_dict'], strict=False)

        # stable-zero123 has a different camera embedding
        self.use_stable_zero123 = 'stable' in model_key

        self.pipe.image_encoder.eval()
        self.pipe.vae.eval()
        self.pipe.unet.eval()
        self.pipe.clip_camera_projection.eval()

        self.vae = self.pipe.vae
        self.unet = self.pipe.unet

        self.pipe.set_progress_bar_config(disable=True)

        self.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.num_train_timesteps = self.scheduler.config.num_train_timesteps

        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

        self.embeddings = None

    @torch.no_grad()
    def get_img_embeds(self, x):
        # x: image tensor in [0, 1]
        x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False)
        x_pil = [TF.to_pil_image(image) for image in x]
        x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype)
        c = self.pipe.image_encoder(x_clip).image_embeds
        v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor
        self.embeddings = [c, v]

    def get_cam_embeddings(self, polar, azimuth, radius, default_elevation=0):
        if self.use_stable_zero123:
            T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), np.deg2rad([90 + default_elevation] * len(polar))], axis=-1)
        else:
            # original zero123 camera embedding
            T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1)
        T = torch.from_numpy(T).unsqueeze(1).to(dtype=self.dtype, device=self.device)  # [8, 1, 4]
        return T

    def get_cam_embeddings_6D(self, target_RT, cond_RT):
        T_target = torch.from_numpy(target_RT["c2w"])
        focal_len_target = torch.from_numpy(target_RT["focal_length"])

        T_cond = torch.from_numpy(cond_RT["c2w"])
        focal_len_cond = torch.from_numpy(cond_RT["focal_length"])

        focal_len = focal_len_target / focal_len_cond

        d_T = torch.linalg.inv(T_target) @ T_cond
        d_T = torch.cat([d_T.flatten(), torch.log(focal_len)])
        return d_T.unsqueeze(0).unsqueeze(0).to(dtype=self.dtype, device=self.device)
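A quick dimensionality check for the relative-pose embedding built by get_cam_embeddings_6D above (illustrative snippet, not in the file; the toy matrices and focal lengths are made up):

import numpy as np
import torch

target_RT = {
    "c2w": np.array([[0., -1., 0., 0.],
                     [1.,  0., 0., 0.],
                     [0.,  0., 1., 2.],
                     [0.,  0., 0., 1.]], dtype=np.float32),  # target camera-to-world
    "focal_length": np.array([560., 560.], dtype=np.float32),
}
cond_RT = {
    "c2w": np.eye(4, dtype=np.float32),                      # condition camera-to-world
    "focal_length": np.array([500., 500.], dtype=np.float32),
}

d_T = torch.linalg.inv(torch.from_numpy(target_RT["c2w"])) @ torch.from_numpy(cond_RT["c2w"])
d_T = torch.cat([
    d_T.flatten(),  # 16 numbers: the full relative SE(3) transform
    torch.log(torch.from_numpy(target_RT["focal_length"] / cond_RT["focal_length"])),  # 2 numbers
])
print(d_T.shape)  # torch.Size([18]) -- matches the Linear(768 + 18, 768) projection in __init__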
    @torch.no_grad()
    def refine(self, pred_rgb, cam_embed,
               guidance_scale=5, steps=50, strength=0.8, idx=None
               ):

        ######## Slight modification ########
        if pred_rgb is not None:
            batch_size = pred_rgb.shape[0]
        else:
            batch_size = 1

        self.scheduler.set_timesteps(steps)

        if strength == 0:
            init_step = 0
            latents = torch.randn((1, 4, 32, 32), device=self.device, dtype=self.dtype)
        else:
            init_step = int(steps * strength)
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))
            latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])

        T = cam_embed
        if idx is not None:
            cc_emb = torch.cat([self.embeddings[0][idx].repeat(batch_size, 1, 1), T], dim=-1)
        else:
            cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
        cc_emb = self.pipe.clip_camera_projection(cc_emb)
        cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

        if idx is not None:
            vae_emb = self.embeddings[1][idx].repeat(batch_size, 1, 1, 1)
        else:
            vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1)
        vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)

        for i, t in enumerate(self.scheduler.timesteps[init_step:]):

            x_in = torch.cat([latents] * 2)
            t_in = torch.cat([t.view(1)]).to(self.device)

            noise_pred = self.unet(
                torch.cat([x_in, vae_emb], dim=1),
                t_in.to(self.unet.dtype),
                encoder_hidden_states=cc_emb,
            ).sample

            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        imgs = self.decode_latents(latents)  # [1, 3, 256, 256]
        return imgs

    def train_step(self, pred_rgb, polar, azimuth, radius, step_ratio=None, guidance_scale=5, as_latent=False):
        # pred_rgb: tensor [1, 3, H, W] in [0, 1]

        batch_size = pred_rgb.shape[0]

        if as_latent:
            latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
        else:
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))

        if step_ratio is not None:
            # dreamtime-like
            # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
            t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
            t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
        else:
            t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

        w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)

            x_in = torch.cat([latents_noisy] * 2)
            t_in = torch.cat([t] * 2)

            T = self.get_cam_embeddings(polar, azimuth, radius)
            cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
            cc_emb = self.pipe.clip_camera_projection(cc_emb)
            cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

            vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1)
            vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)

            noise_pred = self.unet(
                torch.cat([x_in, vae_emb], dim=1),
                t_in.to(self.unet.dtype),
                encoder_hidden_states=cc_emb,
            ).sample

        noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        target = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum')

        return loss
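The MSE-with-detached-target at the end of train_step is the usual score-distillation (SDS) trick: autograd of this loss reproduces the gradient w * (noise_pred - noise) without differentiating through the UNet. A standalone sanity check of that identity (illustrative only, not in the file):

import torch
import torch.nn.functional as F

latents = torch.randn(1, 4, 32, 32, requires_grad=True)
grad = torch.randn_like(latents)  # stands in for w * (noise_pred - noise)

target = (latents - grad).detach()
loss = 0.5 * F.mse_loss(latents, target, reduction='sum')
loss.backward()

# d/dlatents of 0.5 * ||latents - target||^2 is (latents - target) = grad
assert torch.allclose(latents.grad, grad, atol=1e-5)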
    def angle_between(self, sph_v1, sph_v2):
        def sph2cart(sv):
            r, theta, phi = sv[0], sv[1], sv[2]
            # The polar representation is different from Stable-DreamFusion
            return torch.tensor([r * torch.cos(theta) * torch.cos(phi), r * torch.cos(theta) * torch.sin(phi), r * torch.sin(theta)])

        def unit_vector(v):
            return v / torch.linalg.norm(v)

        def angle_between_2_sph(sv1, sv2):
            v1, v2 = sph2cart(sv1), sph2cart(sv2)
            v1_u, v2_u = unit_vector(v1), unit_vector(v2)
            return torch.arccos(torch.clip(torch.dot(v1_u, v2_u), -1.0, 1.0))

        angles = torch.empty(len(sph_v1), len(sph_v2))
        for i, sv1 in enumerate(sph_v1):
            for j, sv2 in enumerate(sph_v2):
                angles[i][j] = angle_between_2_sph(sv1, sv2)
        return angles

    def batch_train_step(self, pred_rgb, target_RT, cond_cams, step_ratio=None, guidance_scale=5, as_latent=False, step=None):
        # pred_rgb: tensor [1, 3, H, W] in [0, 1]

        batch_size = pred_rgb.shape[0]

        if as_latent:
            latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
        else:
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))

        if step_ratio is not None:
            # dreamtime-like
            # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
            t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
            t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
        else:
            t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

        w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)

            x_in = torch.cat([latents_noisy] * 2 * self.num_views)
            t_in = torch.cat([t] * 2 * self.num_views)

            cc_embs = []
            vae_embs = []
            noise_preds = []
            for idx in range(self.num_views):
                cond_RT = {
                    "c2w": cond_cams[idx].c2w,
                    "focal_length": cond_cams[idx].focal_length,
                }
                T = self.get_cam_embeddings_6D(target_RT, cond_RT)
                cc_emb = torch.cat([self.embeddings[0][idx].repeat(batch_size, 1, 1), T], dim=-1)
                cc_emb = self.pipe.clip_camera_projection(cc_emb)
                cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

                vae_emb = self.embeddings[1][idx].repeat(batch_size, 1, 1, 1)
                vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)

                cc_embs.append(cc_emb)
                vae_embs.append(vae_emb)

            cc_emb = torch.cat(cc_embs, dim=0)
            vae_emb = torch.cat(vae_embs, dim=0)
            noise_pred = self.unet(
                torch.cat([x_in, vae_emb], dim=1),
                t_in.to(self.unet.dtype),
                encoder_hidden_states=cc_emb,
            ).sample

        noise_pred_chunks = noise_pred.chunk(self.num_views)
        for idx in range(self.num_views):
            noise_pred_cond, noise_pred_uncond = noise_pred_chunks[idx][0], noise_pred_chunks[idx][1]
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            noise_preds.append(noise_pred)

        noise_pred = torch.stack(noise_preds).sum(dim=0) / len(noise_preds)  # average the guided prediction over all views

        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        target = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum')

        return loss

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents

        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def encode_imgs(self, imgs, mode=False):
        # imgs: [B, 3, H, W]

        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        if mode:
            latents = posterior.mode()
        else:
            latents = posterior.sample()
        latents = latents * self.vae.config.scaling_factor

        return latents


def process_im(im, bg_remover=None, device="cuda"):
    # fixed up to be self-contained: the original referenced `self` and a global
    # `device` although this is a module-level function
    import rembg  # local import; rembg is not used elsewhere in this module
    if im.shape[-1] == 3:
        if bg_remover is None:
            bg_remover = rembg.new_session()
        im = rembg.remove(im, session=bg_remover)

    im = im.astype(np.float32) / 255.0

    input_mask = im[..., 3:]
    input_img = im[..., :3] * input_mask + (1 - input_mask)
    input_img = input_img[..., ::-1].copy()  # bgr to rgb
    image = torch.from_numpy(input_img).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)
    image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)

    return image


def get_T_6d(target_RT, cond_RT, use_objaverse):
    if use_objaverse:
        new_row = torch.tensor([[0., 0., 0., 1.]])

        T_target = torch.from_numpy(target_RT)  # world-to-cam matrix
        T_target = torch.cat((T_target, new_row), dim=0)
        T_target = torch.linalg.inv(T_target)  # cam-to-world matrix
        T_target[:3, :] = T_target[[1, 2, 0]]

        T_cond = torch.from_numpy(cond_RT)
        T_cond = torch.cat((T_cond, new_row), dim=0)
        T_cond = torch.linalg.inv(T_cond)
        T_cond[:3, :] = T_cond[[1, 2, 0]]

        focal_len = torch.tensor([1., 1.])

    else:
        T_target = torch.from_numpy(target_RT["c2w"])
        focal_len_target = torch.from_numpy(target_RT["focal_length"])

        T_cond = torch.from_numpy(cond_RT["c2w"])
        focal_len_cond = torch.from_numpy(cond_RT["focal_length"])

        focal_len = focal_len_target / focal_len_cond

    d_T = torch.linalg.inv(T_target) @ T_cond
    d_T = torch.cat([d_T.flatten(), torch.log(focal_len)])
    return d_T
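Usage sketch for get_T_6d in its Objaverse branch (illustrative; the 3x4 world-to-camera matrix below is a toy value):

import numpy as np
import torch

# identical target and condition poses -> the relative transform is the identity
w2c = np.hstack([np.eye(3, dtype=np.float32),
                 np.array([[0.], [0.], [2.]], dtype=np.float32)])  # hypothetical 3x4 w2c

d_T = get_T_6d(w2c, w2c, use_objaverse=True)
print(d_T.shape)  # torch.Size([18])
print(d_T[-2:])   # tensor([0., 0.]) -- log of the unit focal-length ratio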
sparseags/main_stage1.py
ADDED
@@ -0,0 +1,669 @@
import os
import cv2
import sys
import json
import time
import tqdm
import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F

import rembg
from liegroups.torch import SE3

sys.path.append('./')

from sparseags.cam_utils import orbit_camera, OrbitCamera, mat2latlon, find_mask_center_and_translate
from sparseags.render_utils.gs_renderer import Renderer, Camera, FoVCamera, CustomCamera
from sparseags.mesh_utils.grid_put import mipmap_linear_grid_put_2d
from sparseags.mesh_utils.mesh import Mesh, safe_normalize


class GUI:
    def __init__(self, opt):
        self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.gui = opt.gui  # enable gui
        self.W = opt.W
        self.H = opt.H

        self.mode = "image"
        self.seed = 0

        self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
        self.need_update = True  # update buffer_image

        # models
        self.device = torch.device("cuda")
        self.bg_remover = None

        self.guidance_sd = None
        self.guidance_zero123 = None
        self.guidance_dino = None

        self.enable_sd = False
        self.enable_zero123 = False
        self.enable_dino = False

        # renderer
        self.renderer = Renderer(sh_degree=self.opt.sh_degree)
        self.renderer.enable_dino = self.opt.lambda_dino > 0
        self.renderer.gaussians.enable_dino = self.opt.lambda_dino > 0
        self.renderer.gaussians.dino_feat_dim = 36
        self.gaussain_scale_factor = 1

        # input image
        self.input_img = None
        self.input_mask = None
        self.input_img_torch = None
        self.input_mask_torch = None

        # training stuff
        self.training = False
        self.optimizer = None
        self.step = 0
        self.train_steps = 1  # steps per rendering loop

        # load input data
        self.load_input(self.opt.camera_path, self.opt.order_path)

        self.cam = OrbitCamera(opt.W, opt.H, r=3, fovy=opt.fovy)

        # override if a checkpoint is provided
        if self.opt.load is not None:
            self.renderer.initialize(self.opt.load)
        else:
            # initialize gaussians to a blob
            self.renderer.initialize(num_pts=self.opt.num_pts, radius=0.3, mode='sphere')  # 0.5 for radius 3

        # initialize gaussians to a carved voxel
        # self.renderer.initialize(num_pts=self.opt.num_pts, radius=0.5, cameras=self.cams, masks=self.input_mask, mode='carve')  # 0.5

    def seed_everything(self):
        try:
            seed = int(self.seed)
        except:
            seed = np.random.randint(0, 1000000)

        os.environ["PYTHONHASHSEED"] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True

        self.last_seed = seed

    def prepare_train(self):
        self.step = 0

        # setup training
        self.renderer.gaussians.training_setup(self.opt)
        # do progressive sh-level
        self.renderer.gaussians.active_sh_degree = 0
        self.optimizer = self.renderer.gaussians.optimizer

        # note: self.prompt is never set in this stage, so lambda_sd must be 0 here;
        # short-circuit evaluation keeps the line below from raising in that case
        self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != ""
        self.enable_zero123 = self.opt.lambda_zero123 > 0 and self.input_img is not None
        self.enable_dino = self.opt.lambda_dino > 0

        # lazy load guidance model
        if self.guidance_zero123 is None and self.enable_zero123:
            print(f"[INFO] loading zero123...")
            from sparseags.guidance_utils.zero123_6d_utils import Zero123
            self.guidance_zero123 = Zero123(self.device, model_key='ashawkey/zero123-xl-diffusers')
            print(f"[INFO] loaded zero123!")
            self.guidance_zero123.opt = self.opt
            self.guidance_zero123.num_views = self.num_views

        # input image
        if self.input_img is not None:
            import torchvision.transforms as transforms
            from PIL import Image
            self.input_img_torch = torch.from_numpy(self.input_img).permute(0, 3, 1, 2).to(self.device)
            self.input_mask_torch = torch.from_numpy(self.input_mask).permute(0, 3, 1, 2).to(self.device)

        # prepare embeddings
        with torch.no_grad():
            if self.enable_zero123:
                self.guidance_zero123.get_img_embeds(self.input_img_torch)
    def train_step(self):
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        for _ in range(self.train_steps):

            self.step += 1
            step_ratio = min(1, self.step / self.opt.iters)

            # update lr
            self.renderer.gaussians.update_learning_rate(self.step)

            loss = 0

            ### known view
            for choice in range(self.num_views):
                # For multiview training
                cur_cam = self.cams[choice]

                bg_size = self.renderer.gaussians.dino_feat_dim if self.enable_dino else 3
                bg_color = torch.ones(
                    bg_size,
                    dtype=torch.float32,
                    device="cuda",
                )
                out = self.renderer.render(cur_cam, bg_color=bg_color)

                # rgb loss
                image = out["image"]
                loss = loss + 10000 * step_ratio * F.mse_loss(image, self.input_img_torch[choice])

                # mask loss
                mask = out["alpha"]
                loss = loss + 1000 * step_ratio * F.mse_loss(mask, self.input_mask_torch[choice])

                # dino loss
                if self.enable_dino:
                    feature = out["feature"]
                    loss = loss + 1000 * step_ratio * F.mse_loss(feature, self.guidance_dino.embeddings[choice])

            ### novel view (manual batch)
            render_resolution = 128 if step_ratio < 0.3 else (256 if step_ratio < 0.6 else 512)
            images = []
            masks = []
            vers, hors, radii = [], [], []
            # avoid too large elevation (> 80 or < -80)
            min_ver = max(-60 + np.array(self.opt.ref_polars).min(), -80)  # +-30 for CO3D
            max_ver = min(60 + np.array(self.opt.ref_polars).max(), 80)

            for _ in range(self.opt.batch_size):
                # render random view
                ver = np.random.randint(min_ver, max_ver) - self.opt.ref_polars[0]
                hor = np.random.randint(-180, 180)
                radius = 0

                vers.append(ver)
                hors.append(hor)
                radii.append(radius)

                pose = orbit_camera(
                    self.opt.ref_polars[0] + ver,
                    self.opt.ref_azimuths[0] + hor,
                    np.array(self.opt.ref_radii).mean() + radius,
                )

                # Azimuth:
                # [-180, -135): -4, [-135, -90): -3, [-90, -45): -2, [-45, 0): -1,
                # [0, 45): 0, [45, 90): 1, [90, 135): 2, [135, 180): 3.
                # Elevation: [0, 90): 0, [-90, 0): 1.
                idx_ver, idx_hor = int((self.opt.ref_polars[0] + ver) < 0), hor // 45

                flag = 0
                cx, cy = self.pp_pools[idx_ver, idx_hor + 4].tolist()
                cnt = 0
                fx, fy = self.fx, self.fy

                # in each iter we modify cx, cy, fx, fy to make sure the rendered object is at the center and has a reasonable size
                while not flag:

                    if cnt >= 10:
                        # print(f"[ERROR] Something might be wrong here!")
                        break

                    flag_principal_point, flag_focal_length = 0, 0

                    # we modified the field of view. Otherwise, the rendered object will be too small
                    # cur_cam = FoVCamera(pose, render_resolution, render_resolution, self.fovy, self.fovx, self.cam.near, self.cam.far)
                    cur_cam = Camera(pose, render_resolution, render_resolution, fx, fy, cx, cy, self.cam.near, self.cam.far)

                    bg_size = self.renderer.gaussians.dino_feat_dim if self.enable_dino else 3
                    bg_color = torch.ones(bg_size, dtype=torch.float32, device="cuda") if np.random.rand() > self.opt.invert_bg_prob else torch.zeros(bg_size, dtype=torch.float32, device="cuda")
                    out = self.renderer.render(cur_cam, bg_color=bg_color)

                    image = out["image"].unsqueeze(0)
                    mask = out["alpha"].unsqueeze(0)
                    delta_xy = find_mask_center_and_translate(image.detach(), mask.detach()) / render_resolution * 256

                    # (1) check if the principal points are appropriate
                    if delta_xy[0].abs() > 10 or delta_xy[1].abs() > 10:
                        cx -= delta_xy[0]
                        cy -= delta_xy[1]
                        self.pp_pools[idx_ver, idx_hor + 4] = torch.tensor([cx, cy])  # update pp_pools
                    else:
                        flag_principal_point = 1

                    num_pixs_mask = (mask > 0.5).float().sum().item()
                    target_num_pixs = render_resolution ** 2 / (1.2 ** 2)

                    mask_to_compute = (mask > 0.5).squeeze().detach().cpu().numpy()
                    y_indices, x_indices = np.where(mask_to_compute > 0)

                    if len(x_indices) == 0 or len(y_indices) == 0:
                        # there's no object in the mask; try again
                        cnt += 1  # fixed: `continue` used to skip the increment below, which could loop forever
                        continue

                    # find the bounding box coordinates
                    x1, y1 = np.min(x_indices), np.min(y_indices)
                    x2, y2 = np.max(x_indices), np.max(y_indices)

                    bbox = np.array([x1, y1, x2, y2])
                    extents = (bbox[2:] - bbox[:2]).max()
                    num_pixs_mask = extents ** 2

                    # (2) check if the focal lengths are appropriate
                    if abs(num_pixs_mask - target_num_pixs) > 0.05 * render_resolution ** 2:
                        if num_pixs_mask == 0:
                            pass
                        else:
                            fx = fx * np.sqrt(target_num_pixs / num_pixs_mask)
                            fy = fy * np.sqrt(target_num_pixs / num_pixs_mask)
                    else:
                        flag_focal_length = 1

                    if flag_principal_point * flag_focal_length == 1:
                        flag = 1

                    cnt += 1

                images.append(image)
                masks.append(mask)

            images = torch.cat(images, dim=0)

            if self.enable_zero123:
                target_RT = {
                    "c2w": pose,
                    "focal_length": np.array([fx, fy]),  # fixed: np.array(fx, fy) passed fy as the dtype argument
                }
                loss = loss + self.opt.lambda_zero123 * self.guidance_zero123.batch_train_step(images, target_RT, self.cams, step_ratio=step_ratio if self.opt.anneal_timestep else None)

            if self.enable_dino:
                loss_dino = self.guidance_dino.train_step(
                    images,
                    out["feature"],
                    step_ratio=step_ratio if self.opt.anneal_timestep else None
                )
                loss = loss + self.opt.lambda_dino * loss_dino

            # optimize step
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            latlons = [mat2latlon(cam.c2w[:3, 3]) for cam in self.cams]
            if self.opt.opt_cam:
                for i, cam in enumerate(self.cams):
                    w2c = cam.w2c @ SE3.exp(cam.cam_params.detach()).as_matrix()
                    w2c[:2, :3] *= -1
                    w2c[:2, 3] *= -1
                    self.camera_tracks[i].append(w2c.tolist())
            self.opt.ref_polars = [float(cam[0]) for cam in latlons]
            self.opt.ref_azimuths = [float(cam[1]) for cam in latlons]
            self.opt.ref_radii = [float(cam[2]) for cam in latlons]

            # densify and prune
            if self.step >= self.opt.density_start_iter and self.step <= self.opt.density_end_iter:
                viewspace_point_tensor, visibility_filter, radii = out["viewspace_points"], out["visibility_filter"], out["radii"]
                self.renderer.gaussians.max_radii2D[visibility_filter] = torch.max(self.renderer.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
                self.renderer.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)

                if self.step % self.opt.densification_interval == 0:
                    self.renderer.gaussians.densify_and_prune(self.opt.densify_grad_threshold, min_opacity=0.01, extent=4, max_screen_size=1)

                # if self.step % self.opt.opacity_reset_interval == 0:
                #     self.renderer.gaussians.reset_opacity()

            if self.step % 100 == 0 and self.renderer.gaussians.max_sh_degree != 0:
                self.renderer.gaussians.oneupSHdegree()

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        self.need_update = True
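A quick illustration of the view-bucket indexing used for the principal-point pool above: the sign of the absolute elevation picks the row, and floor division of the azimuth by 45 degrees (shifted by +4 into [0, 8)) picks the column. Illustrative snippet, not in the file:

for ver, hor in [(30, -180), (30, -1), (-10, 0), (-10, 179)]:
    idx_ver, idx_hor = int(ver < 0), hor // 45
    print(ver, hor, "->", idx_ver, idx_hor + 4)
# 30 -180 -> 0 0
# 30 -1 -> 0 3
# -10 0 -> 1 4
# -10 179 -> 1 7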
    def load_input(self, camera_path, order_path=None):
        # load image
        print(f'[INFO] load data from {camera_path}...')

        if order_path is not None:
            with open(order_path, 'r') as f:
                indices = json.load(f)
        else:
            indices = None

        with open(camera_path, 'r') as f:
            data = json.load(f)

        self.cam_params = {}
        for k, v in data.items():
            if indices is None:
                self.cam_params[k] = data[k]
            else:
                if int(k) in indices or k in indices:
                    self.cam_params[k] = data[k]

        if self.opt.all_views:
            for k, v in self.cam_params.items():
                self.cam_params[k]['opt_cam'] = 1
                self.cam_params[k]['flag'] = 1
        else:
            for k, v in self.cam_params.items():
                if int(self.cam_params[k]['flag']):
                    self.cam_params[k]['opt_cam'] = 1
                else:
                    self.cam_params[k]['opt_cam'] = 0

        img_paths = [v["filepath"] for k, v in self.cam_params.items() if v["flag"]]
        self.num_views = len(img_paths)
        print(f"[INFO] Number of views: {self.num_views}")

        for filepath in img_paths:
            print(filepath)

        images, masks = [], []

        for i in range(self.num_views):
            img = cv2.imread(img_paths[i], cv2.IMREAD_UNCHANGED)
            if img.shape[-1] == 3:
                if self.bg_remover is None:
                    self.bg_remover = rembg.new_session()
                img = rembg.remove(img, session=self.bg_remover)

            img = img.astype(np.float32) / 255.0

            # Non-integer cropping creates non-zero mask values
            input_mask = (img[..., 3:] > 0.5).astype(np.float32)

            # white bg
            input_img = img[..., :3] * input_mask + (1 - input_mask)
            # bgr to rgb
            input_img = input_img[..., ::-1].copy()

            images.append(input_img)
            masks.append(input_mask)

        images = np.stack(images, axis=0)
        masks = np.stack(masks, axis=0)
        self.input_img = images[:self.num_views]
        self.input_mask = masks[:self.num_views]
        self.all_input_images = images

        self.cams = [CustomCamera(v, index=int(k), opt_pose=self.opt.opt_cam and v['opt_cam']) for k, v in self.cam_params.items() if v["flag"]]
        cam_centers = [mat2latlon(cam.camera_center) for cam in self.cams]
        self.opt.ref_polars = [float(cam[0]) for cam in cam_centers]
        self.opt.ref_azimuths = [float(cam[1]) for cam in cam_centers]
        self.opt.ref_radii = [float(cam[2]) for cam in cam_centers]
        self.fx = np.array([cam.fx for cam in self.cams], dtype=np.float32).mean()
        self.fy = np.array([cam.fy for cam in self.cams], dtype=np.float32).mean()
        self.cx = 128
        self.cy = 128
        if self.opt.opt_cam:
            self.camera_tracks = {}
            for i, cam in enumerate(self.cams):
                self.camera_tracks[i] = []

        # Azimuth Mapping: [-180, -135): -4, [-135, -90): -3, [-90, -45): -2, [-45, 0): -1,
        # [0, 45): 0, [45, 90): 1, [90, 135): 2, [135, 180): 3.
        # Elevation Mapping: [0, 90): 0, [-90, 0): 1.

        # Principal Point Pool: Tensor (2, 8, 2), where:
        # - 2: Elevation groups, 8: Azimuth intervals, 2: x, y coordinates (init to 128).

        # we created a "pool" for principal points
        # we use these principal points to render image to make sure object is at the center
        self.pp_pools = torch.full((2, 8, 2), 128)
        if self.opt.opt_cam:
            self.renderer.gaussians.cam_params = [cam.cam_params for cam in self.cams[:] if cam.opt_pose]
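For reference, a hypothetical single entry of the cameras JSON consumed by load_input() (only the keys actually read or written in this file -- "filepath", "flag", "opt_cam", "R", "T" -- are grounded; the values and any other keys are made up for illustration):

cam_params = {
    "000": {
        "filepath": "data/example/000.png",  # RGBA image; the alpha channel becomes the mask
        "flag": 1,                           # 1 = this view participates in training
        "R": [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
        "T": [0.0, 0.0, 2.0],
    }
}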
    @torch.no_grad()
    def save_video(self, post_fix=None):
        xyz = self.renderer.gaussians._xyz
        center = self.renderer.gaussians._xyz.mean(dim=0)
        squared_distances = torch.sum((xyz - center) ** 2, dim=1)
        max_distance_squared = torch.max(squared_distances)
        radius = torch.sqrt(max_distance_squared) + 1.0
        radius = radius.detach().cpu().numpy()

        render_resolution = 256
        images = []
        frame_rate = 30
        image_size = (render_resolution, render_resolution)  # Size of each image
        video_path = self.opt.save_path + f'_rendered_video_{post_fix}.mp4'

        azimuth = np.arange(0, 360, 3, dtype=np.int32)

        for azi in tqdm.tqdm(azimuth):
            target = center.detach().cpu().numpy()
            pose = orbit_camera(-30, azi, radius, target=target)
            cur_cam = FoVCamera(
                pose,
                render_resolution,
                render_resolution,
                self.cam.fovy,
                self.cam.fovx,
                self.cam.near,
                self.cam.far,
            )

            out = self.renderer.render(cur_cam)
            img = out["image"].detach().cpu().numpy()  # [3, H, W] in [0, 1]
            img = np.transpose(img, (1, 2, 0))
            image = (img * 255).astype(np.uint8)
            images.append(image)

        images = np.stack(images, axis=0)
        # ~4 seconds, 120 frames at 30 fps
        import imageio
        imageio.mimwrite(video_path, images, fps=30, quality=8, macro_block_size=1)

    @torch.no_grad()
    def save_model(self, mode='geo', texture_size=1024):
        os.makedirs(self.opt.outdir, exist_ok=True)
        if mode == 'geo':
            path = os.path.join(self.opt.outdir, self.opt.save_path + '_mesh.ply')
            mesh = self.renderer.gaussians.extract_mesh(path, self.opt.density_thresh)
            mesh.write_ply(path)

        elif mode == 'geo+tex':
            path = os.path.join(self.opt.outdir, self.opt.save_path + '_mesh.' + self.opt.mesh_format)
            mesh = self.renderer.gaussians.extract_mesh(path, self.opt.density_thresh)

            # perform texture extraction
            print(f"[INFO] unwrap uv...")
            h = w = texture_size
            mesh.auto_uv()
            mesh.auto_normal()

            albedo = torch.zeros((h, w, 3), device=self.device, dtype=torch.float32)
            cnt = torch.zeros((h, w, 1), device=self.device, dtype=torch.float32)
            if self.enable_dino:
                feature = torch.zeros((h, w, self.renderer.gaussians.dino_feat_dim), device=self.device, dtype=torch.float32)

            # self.prepare_train() # tmp fix for not loading 0123
            # vers = [0]
            # hors = [0]
            vers = [0] * 8 + [-45] * 8 + [45] * 8 + [-89.9, 89.9]
            hors = [0, 45, -45, 90, -90, 135, -135, 180] * 3 + [0, 0]

            render_resolution = 512

            import nvdiffrast.torch as dr

            if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
                glctx = dr.RasterizeGLContext()
            else:
                glctx = dr.RasterizeCudaContext()

            for ver, hor in zip(vers, hors):
                # render image
                pose = orbit_camera(ver, hor, self.cam.radius)

                cur_cam = FoVCamera(
                    pose,
                    render_resolution,
                    render_resolution,
                    self.cam.fovy,
                    self.cam.fovx,
                    self.cam.near,
                    self.cam.far,
                )

                cur_out = self.renderer.render(cur_cam)

                rgbs = cur_out["image"].unsqueeze(0)  # [1, 3, H, W] in [0, 1]
                if self.enable_dino:
                    features = cur_out["feature"].unsqueeze(0)  # [1, 384, 512, 512]

                # enhance texture quality with zero123 [not working well]
                # if self.opt.guidance_model == 'zero123':
                #     rgbs = self.guidance.refine(rgbs, [ver], [hor], [0])
                #     import kiui
                #     kiui.vis.plot_image(rgbs)

                # get coordinate in texture image
                pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)
                proj = torch.from_numpy(self.cam.perspective.astype(np.float32)).to(self.device)

                v_cam = torch.matmul(F.pad(mesh.v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
                v_clip = v_cam @ proj.T
                rast, rast_db = dr.rasterize(glctx, v_clip, mesh.f, (render_resolution, render_resolution))

                depth, _ = dr.interpolate(-v_cam[..., [2]], rast, mesh.f)  # [1, H, W, 1]
                depth = depth.squeeze(0)  # [H, W, 1]

                alpha = (rast[0, ..., 3:] > 0).float()

                uvs, _ = dr.interpolate(mesh.vt.unsqueeze(0), rast, mesh.ft)  # [1, 512, 512, 2] in [0, 1]

                # use normal to produce a back-project mask
                normal, _ = dr.interpolate(mesh.vn.unsqueeze(0).contiguous(), rast, mesh.fn)
                normal = safe_normalize(normal[0])

                # rotated normal (where [0, 0, 1] always faces camera)
                rot_normal = normal @ pose[:3, :3]
                viewcos = rot_normal[..., [2]]

                mask = (alpha > 0) & (viewcos > 0.5)  # [H, W, 1]
                mask = mask.view(-1)

                uvs = uvs.view(-1, 2).clamp(0, 1)[mask]
                rgbs = rgbs.view(3, -1).permute(1, 0)[mask].contiguous()

                # update texture image
                cur_albedo, cur_cnt = mipmap_linear_grid_put_2d(
                    h, w,
                    uvs[..., [1, 0]] * 2 - 1,
                    rgbs,
                    min_resolution=256,
                    return_count=True,
                )

                if self.enable_dino:
                    features = features.view(features.shape[1], -1).permute(1, 0)[mask].contiguous()
                    cur_feature, _ = mipmap_linear_grid_put_2d(
                        h, w,
                        uvs[..., [1, 0]] * 2 - 1,
                        features,
                        min_resolution=256,
                        return_count=True,
                    )

                # albedo += cur_albedo
                # cnt += cur_cnt
                mask = cnt.squeeze(-1) < 0.1
                albedo[mask] += cur_albedo[mask]
                cnt[mask] += cur_cnt[mask]

                if self.enable_dino:
                    feature[mask] += cur_feature[mask]

            mask = cnt.squeeze(-1) > 0
            albedo[mask] = albedo[mask] / cnt[mask].repeat(1, 3)

            if self.enable_dino:
                feature[mask] = feature[mask] / cnt[mask].repeat(1, feature.shape[-1])

            mask = mask.view(h, w)

            albedo = albedo.detach().cpu().numpy()
            mask = mask.detach().cpu().numpy()

            if self.enable_dino:
                feature = feature.detach().cpu().numpy()

            # dilate texture
            from sklearn.neighbors import NearestNeighbors
            from scipy.ndimage import binary_dilation, binary_erosion

            inpaint_region = binary_dilation(mask, iterations=32)
            inpaint_region[mask] = 0

            search_region = mask.copy()
            not_search_region = binary_erosion(search_region, iterations=3)
            search_region[not_search_region] = 0

            search_coords = np.stack(np.nonzero(search_region), axis=-1)
            inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)

            knn = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(
                search_coords
            )
            _, indices = knn.kneighbors(inpaint_coords)

            albedo[tuple(inpaint_coords.T)] = albedo[tuple(search_coords[indices[:, 0]].T)]

            mesh.albedo = torch.from_numpy(albedo).to(self.device)
            # mesh.write(path)

            if self.enable_dino:
                feature[tuple(inpaint_coords.T)] = feature[tuple(search_coords[indices[:, 0]].T)]
                mesh.feature = torch.from_numpy(feature).to(self.device)

            mesh.write(path, self.enable_dino)

        else:
            path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply')
            self.renderer.gaussians.save_ply(path)

        print(f"[INFO] save model to {path}.")

    # no gui mode
    def train(self, iters=500):
        if iters > 0:
            self.prepare_train()
            for i in tqdm.trange(iters):
                self.train_step()
            # do a last prune
            self.renderer.gaussians.prune(min_opacity=0.01, extent=1, max_screen_size=1)
        if self.opt.opt_cam:
            for cam in self.cams:
                try:
                    self.cam_params[str(cam.index)]["R"] = cam.rotation.tolist()
                    self.cam_params[str(cam.index)]["T"] = cam.translation.tolist()
                except KeyError:
                    self.cam_params[f"{cam.index:03}"]["R"] = cam.rotation.tolist()
                    self.cam_params[f"{cam.index:03}"]["T"] = cam.translation.tolist()
            with open(self.opt.camera_path.replace(".json", "_updated.json"), "w") as file:
                json.dump(self.cam_params, file, indent=4)
        self.save_model(mode='model')
        self.save_model(mode='geo+tex')


if __name__ == "__main__":
    import argparse
    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to the yaml config file")
    args, extras = parser.parse_known_args()

    # override default config from cli
    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))

    gui = GUI(opt)

    gui.train(opt.iters)
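The entry point above loads a YAML file via OmegaConf and merges CLI overrides, e.g. `python sparseags/main_stage1.py --config <your_config>.yaml iters=500`. A hypothetical minimal config covering only the keys this script actually reads; every value below is an illustrative placeholder, not a setting from the repo:

gui: false
W: 800
H: 800
fovy: 49.1
sh_degree: 0
num_pts: 5000
iters: 500
batch_size: 4
camera_path: data/example/cameras.json
order_path: null
load: null
lambda_zero123: 1.0
lambda_sd: 0.0
lambda_dino: 0.0
invert_bg_prob: 0.5
opt_cam: true
all_views: false
anneal_timestep: true
density_start_iter: 100
density_end_iter: 300
densification_interval: 100
densify_grad_threshold: 0.01
density_thresh: 1.0
outdir: logs
save_path: example
mesh_format: obj
force_cuda_rast: false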
sparseags/main_stage2.py
ADDED
@@ -0,0 +1,410 @@
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import json
|
4 |
+
import time
|
5 |
+
import copy
|
6 |
+
import tqdm
|
7 |
+
import rembg
|
8 |
+
import trimesh
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import numpy as np
|
12 |
+
import pandas as pd
|
13 |
+
|
14 |
+
from kiui.lpips import LPIPS
|
15 |
+
|
16 |
+
import sys
|
17 |
+
sys.path.append('./')
|
18 |
+
|
19 |
+
from sparseags.cam_utils import orbit_camera, OrbitCamera, mat2latlon, find_mask_center_and_translate
|
20 |
+
from sparseags.render_utils.gs_renderer import CustomCamera
|
21 |
+
from sparseags.mesh_utils.mesh_renderer import Renderer
|
22 |
+
|
23 |
+
|
24 |
+
class GUI:
|
25 |
+
def __init__(self, opt):
|
26 |
+
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
|
27 |
+
self.gui = opt.gui # enable gui
|
28 |
+
self.W = opt.W
|
29 |
+
self.H = opt.H
|
30 |
+
|
31 |
+
self.mode = "image"
|
32 |
+
self.seed = 0
|
33 |
+
|
34 |
+
self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
|
35 |
+
self.need_update = True # update buffer_image
|
36 |
+
|
37 |
+
# models
|
38 |
+
self.device = torch.device("cuda")
|
39 |
+
self.bg_remover = None
|
40 |
+
|
41 |
+
self.guidance_sd = None
|
42 |
+
self.guidance_zero123 = None
|
43 |
+
self.guidance_dino = None
|
44 |
+
|
45 |
+
self.enable_sd = False
|
46 |
+
self.enable_zero123 = False
|
47 |
+
self.enable_dino = False
|
48 |
+
|
49 |
+
# renderer
|
50 |
+
self.renderer = Renderer(opt).to(self.device)
|
51 |
+
|
52 |
+
# input image
|
53 |
+
self.input_img = None
|
54 |
+
self.input_mask = None
|
55 |
+
self.input_img_torch = None
|
56 |
+
self.input_mask_torch = None
|
57 |
+
self.overlay_input_img = False
|
58 |
+
self.overlay_input_img_ratio = 0.5
|
59 |
+
|
60 |
+
# input text
|
61 |
+
self.prompt = ""
|
62 |
+
self.negative_prompt = ""
|
63 |
+
|
64 |
+
# training stuff
|
65 |
+
self.training = False
|
66 |
+
self.optimizer = None
|
67 |
+
self.step = 0
|
68 |
+
self.train_steps = 1 # steps per rendering loop
|
69 |
+
|
70 |
+
# load input data
|
71 |
+
self.load_input(self.opt.camera_path, self.opt.order_path)
|
72 |
+
|
73 |
+
# override prompt from cmdline
|
74 |
+
if self.opt.prompt is not None:
|
75 |
+
self.prompt = self.opt.prompt
|
76 |
+
if self.opt.negative_prompt is not None:
|
77 |
+
self.negative_prompt = self.opt.negative_prompt
|
78 |
+
|
79 |
+
def seed_everything(self):
|
80 |
+
try:
|
81 |
+
seed = int(self.seed)
|
82 |
+
except:
|
83 |
+
seed = np.random.randint(0, 1000000)
|
84 |
+
|
85 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
86 |
+
np.random.seed(seed)
|
87 |
+
torch.manual_seed(seed)
|
88 |
+
torch.cuda.manual_seed(seed)
|
89 |
+
torch.backends.cudnn.deterministic = True
|
90 |
+
torch.backends.cudnn.benchmark = True
|
91 |
+
|
92 |
+
self.last_seed = seed
|
93 |
+
|
    def prepare_train(self):

        self.step = 0

        # setup training
        self.optimizer = torch.optim.Adam(self.renderer.get_params())

        cameras = [CustomCamera(v, index=int(k)) for k, v in self.cam_params.items() if v["flag"]]
        cam_centers = [mat2latlon(cam.camera_center) for cam in cameras]
        self.opt.ref_polars = [float(cam[0]) for cam in cam_centers]
        self.opt.ref_azimuths = [float(cam[1]) for cam in cam_centers]
        self.opt.ref_radii = [float(cam[2]) for cam in cam_centers]
        self.cams = [(cam.c2w, cam.perspective, cam.focal_length) for cam in cameras]
        self.cam = copy.deepcopy(cameras[0])

        # azimuth mapping: [-180, -135): -4, [-135, -90): -3, [-90, -45): -2, [-45, 0): -1,
        # [0, 45): 0, [45, 90): 1, [90, 135): 2, [135, 180): 3.
        # elevation mapping: [0, 90): 0, [-90, 0): 1.

        # principal point pool: tensor (2, 8, 2), where
        # 2: elevation groups, 8: azimuth intervals, 2: x, y coordinates (init to 128).

        # we keep a "pool" of principal points and render with them,
        # so that the object stays at the center of the image
        self.pp_pools = torch.full((2, 8, 2), 128)

        # the intrinsics are averaged over all cameras
        self.cam.fx = np.array([cam.fx for cam in cameras], dtype=np.float32).mean()
        self.cam.fy = np.array([cam.fy for cam in cameras], dtype=np.float32).mean()
        self.cam.cx = np.array([cam.cx for cam in cameras], dtype=np.float32).mean()
        self.cam.cy = np.array([cam.cy for cam in cameras], dtype=np.float32).mean()

        self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != ""
        self.enable_zero123 = self.opt.lambda_zero123 > 0 and self.input_img is not None
        self.enable_dino = self.opt.lambda_dino > 0

        # lazy load guidance model
        if self.guidance_sd is None and self.enable_sd:
            if self.opt.mvdream:
                print(f"[INFO] loading MVDream...")
                from guidance.mvdream_utils import MVDream
                self.guidance_sd = MVDream(self.device)
                print(f"[INFO] loaded MVDream!")
            else:
                print(f"[INFO] loading SD...")
                from guidance.sd_utils import StableDiffusion
                self.guidance_sd = StableDiffusion(self.device)
                print(f"[INFO] loaded SD!")

        if self.guidance_zero123 is None and self.enable_zero123:
            print(f"[INFO] loading zero123...")
            from sparseags.guidance_utils.zero123_6d_utils import Zero123
            self.guidance_zero123 = Zero123(self.device, model_key='ashawkey/zero123-xl-diffusers')
            print(f"[INFO] loaded zero123!")

        if self.guidance_dino is None and self.enable_dino:
            print(f"[INFO] loading dino...")
            from guidance.dino_utils import Dino
            self.guidance_dino = Dino(self.device, n_components=36, model_key="dinov2_vits14")
            self.guidance_dino.fit_pca(self.all_input_images)
            print(f"[INFO] loaded dino!")

        # input image
        if self.input_img is not None:
            self.input_img_torch = torch.from_numpy(self.input_img).permute(0, 3, 1, 2).to(self.device)
            self.input_img_torch = F.interpolate(self.input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)

            self.input_mask_torch = torch.from_numpy(self.input_mask).permute(0, 3, 1, 2).to(self.device)
            self.input_mask_torch = F.interpolate(self.input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)
            self.input_img_torch_channel_last = self.input_img_torch.permute(0, 2, 3, 1).contiguous()

        # prepare embeddings
        with torch.no_grad():

            if self.enable_sd:
                self.guidance_sd.get_text_embeds([self.prompt], [self.negative_prompt])

            if self.enable_zero123:
                self.guidance_zero123.get_img_embeds(self.input_img_torch)

            if self.enable_dino:
                self.guidance_dino.embeddings = self.guidance_dino.get_dino_embeds(self.input_img_torch, upscale=True, reduced=True, learned_up=True)  # [8, 18, 18, 36]

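    # Note on the principal point pool above: pp_pools[idx_ver, idx_hor + 4]
    # is addressed with idx_ver = int(elevation < 0) (two elevation groups)
    # and idx_hor = azimuth // 45 (eight 45-degree azimuth buckets, shifted
    # by +4 into 0..7). For example, (elevation=30, azimuth=-170) lands in
    # cell (0, 0) and (elevation=-10, azimuth=170) in cell (1, 7).
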
    def train_step(self):
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        for _ in range(self.train_steps):

            self.step += 1
            step_ratio = min(1, self.step / self.opt.iters_refine)

            loss = 0

            ### known view
            for choice in range(self.num_views):
                ssaa = min(2.0, max(0.125, 2 * np.random.random()))
                out = self.renderer.render(*self.cams[choice][:2], self.opt.ref_size, self.opt.ref_size, ssaa=ssaa)

                # rgb loss
                image = out["image"]  # [H, W, 3] in [0, 1]
                valid_mask = (out["alpha"] > 0).detach()
                loss = loss + F.mse_loss(image * valid_mask, self.input_img_torch_channel_last[choice] * valid_mask)

                if self.enable_dino:
                    feature = out["feature"]
                    loss = loss + F.mse_loss(feature * valid_mask, self.guidance_dino.embeddings[choice] * valid_mask)

            ### novel view (manual batch)
            render_resolution = 512
            images = []
            vers, hors, radii = [], [], []
            # avoid too large elevation (> 80 or < -80), and make sure it always covers [-30, 30]
            # min_ver = max(min(-30, -30 - self.opt.elevation), -80 - self.opt.elevation)
            # max_ver = min(max(30, 30 - self.opt.elevation), 80 - self.opt.elevation)
            # min_ver = max(min(-30, -30 + np.array(self.opt.ref_polars).min()), -80)
            # max_ver = min(max(30, 30 + np.array(self.opt.ref_polars).max()), 80)
            min_ver = max(-30 + np.array(self.opt.ref_polars).min(), -80)
            max_ver = min(30 + np.array(self.opt.ref_polars).max(), 80)

            for _ in range(self.opt.batch_size):

                # render random view
                ver = np.random.randint(min_ver, max_ver) - self.opt.ref_polars[0]
                hor = np.random.randint(-180, 180)
                radius = 0

                vers.append(ver)
                hors.append(hor)
                radii.append(radius)

                pose = orbit_camera(self.opt.ref_polars[0] + ver, self.opt.ref_azimuths[0] + hor, np.array(self.opt.ref_radii).mean() + radius)

                # random render resolution
                ssaa = min(2.0, max(0.125, 2 * np.random.random()))

                # azimuth:
                # [-180, -135): -4, [-135, -90): -3, [-90, -45): -2, [-45, 0): -1,
                # [0, 45): 0, [45, 90): 1, [90, 135): 2, [135, 180): 3.
                # elevation: [0, 90): 0, [-90, 0): 1.
                idx_ver, idx_hor = int((self.opt.ref_polars[0] + ver) < 0), hor // 45

                flag = 0
                cx, cy = self.pp_pools[idx_ver, idx_hor + 4].tolist()
                cnt = 0

                while not flag:

                    self.cam.cx = cx
                    self.cam.cy = cy

                    if cnt >= 5:
                        print(f"[ERROR] Something must be wrong here!")
                        break

                    # we modified the field of view; otherwise the rendered object would be too small
                    out = self.renderer.render(pose, self.cam.perspective, render_resolution, render_resolution, ssaa=ssaa)

                    image = out["image"]
                    image = image.permute(2, 0, 1).contiguous().unsqueeze(0)
                    mask = out["alpha"] > 0
                    mask = mask.permute(2, 0, 1).contiguous().unsqueeze(0)
                    delta_xy = find_mask_center_and_translate(image.detach(), mask.detach()) / render_resolution * 256

                    if delta_xy[0].abs() > 10 or delta_xy[1].abs() > 10:
                        cx -= delta_xy[0]
                        cy -= delta_xy[1]
                        self.pp_pools[idx_ver, idx_hor + 4] = torch.tensor([cx, cy])  # update pp_pools
                        cnt += 1
                    else:
                        flag = 1

                images.append(image)

            images = torch.cat(images, dim=0)

            # guidance loss
            strength = step_ratio * 0.15 + 0.8
            if self.enable_zero123:
                # polar, azimuth and radius are all deltas w.r.t. the default view
                v1 = torch.stack([torch.tensor([radius]) + self.opt.ref_radii[0], torch.deg2rad(torch.tensor([ver]) + self.opt.ref_polars[0]), torch.deg2rad(torch.tensor([hor]) + self.opt.ref_azimuths[0])], dim=-1)
                v2 = torch.stack([torch.tensor(self.opt.ref_radii), torch.deg2rad(torch.tensor(self.opt.ref_polars)), torch.deg2rad(torch.tensor(self.opt.ref_azimuths))], dim=-1)
                angles = torch.rad2deg(self.guidance_zero123.angle_between(v1, v2)).to(self.device)
                choice = torch.argmin(angles.squeeze()).item()

                cond_RT = {
                    "c2w": self.cams[choice][0],
                    "focal_length": self.cams[choice][-1],
                }
                target_RT = {
                    "c2w": pose,
                    "focal_length": np.array([self.cam.fx, self.cam.fy]),
                }
                cam_embed = self.guidance_zero123.get_cam_embeddings_6D(target_RT, cond_RT)

                # additionally pass an idx parameter to choose the correct viewpoint
                refined_images = self.guidance_zero123.refine(images, cam_embed, strength=strength, idx=choice).float()
                refined_images = F.interpolate(refined_images, (render_resolution, render_resolution), mode="bilinear", align_corners=False)
                loss = loss + self.opt.lambda_zero123 * F.mse_loss(images, refined_images)

            if self.enable_dino:
                loss_dino = self.guidance_dino.train_step(
                    images,
                    out["feature"].permute(2, 0, 1).contiguous(),
                    step_ratio=step_ratio if self.opt.anneal_timestep else None
                )
                loss = loss + self.opt.lambda_dino * loss_dino

            # optimize step
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        self.need_update = True

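    # Note on the zero123 branch above: v1 stacks (radius, polar, azimuth) for
    # the sampled novel view and v2 does the same for every reference view;
    # the conditioning image is the reference whose viewing direction is
    # angularly closest (argmin over angle_between), so the 6-DoF camera
    # embedding is always computed against the most similar known view.
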
    def load_input(self, camera_path, order_path=None):
        # load image
        print(f'[INFO] load data from {camera_path}...')

        if order_path is not None:
            with open(order_path, 'r') as f:
                indices = json.load(f)
        else:
            indices = None

        with open(camera_path, 'r') as f:
            data = json.load(f)

        self.cam_params = {}
        for k, v in data.items():
            if indices is None:
                self.cam_params[k] = data[k]
            else:
                if int(k) in indices or k in indices:
                    self.cam_params[k] = data[k]

            if self.opt.all_views:
                v['flag'] = 1

        img_paths = [v["filepath"] for k, v in self.cam_params.items() if v["flag"]]
        self.num_views = len(img_paths)
        print(f"[INFO] Number of views: {self.num_views}")

        for filepath in img_paths:
            print(filepath)

        images, masks = [], []

        for i in range(len(img_paths)):
            img = cv2.imread(img_paths[i], cv2.IMREAD_UNCHANGED)
            if img.shape[-1] == 3:
                if self.bg_remover is None:
                    self.bg_remover = rembg.new_session()
                img = rembg.remove(img, session=self.bg_remover)

            img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)
            img = img.astype(np.float32) / 255.0

            input_mask = img[..., 3:]
            # white bg
            input_img = img[..., :3] * input_mask + (1 - input_mask)
            # bgr to rgb
            input_img = input_img[..., ::-1].copy()

            images.append(input_img)
            masks.append(input_mask)

        images = np.stack(images, axis=0)
        masks = np.stack(masks, axis=0)
        self.input_img = images[:self.num_views]
        self.input_mask = masks[:self.num_views]
        self.all_input_images = images

    def save_model(self):
        os.makedirs(self.opt.outdir, exist_ok=True)

        path = os.path.join(self.opt.outdir, self.opt.save_path + '.' + self.opt.mesh_format)
        self.renderer.export_mesh(path)

        print(f"[INFO] save model to {path}.")

    # no gui mode
    def train(self, iters=500):
        if iters > 0:
            self.prepare_train()
            for i in tqdm.trange(iters):
                self.train_step()
        # save
        self.save_model()


if __name__ == "__main__":
    import argparse
    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to the yaml config file")
    args, extras = parser.parse_known_args()

    # override default config from cli
    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))

    # auto find mesh from stage 1
    if opt.mesh is None:
        default_path = os.path.join(opt.outdir, opt.save_path + '_mesh.' + opt.mesh_format)
        if os.path.exists(default_path):
            opt.mesh = default_path
        else:
            raise ValueError(f"Cannot find mesh from {default_path}, must specify --mesh explicitly!")

    gui = GUI(opt)

    gui.train(opt.iters_refine)
sparseags/mesh_utils/grid_put.py ADDED
@@ -0,0 +1,301 @@
import torch
import torch.nn.functional as F


def stride_from_shape(shape):
    stride = [1]
    for x in reversed(shape[1:]):
        stride.append(stride[-1] * x)
    return list(reversed(stride))


def scatter_add_nd(input, indices, values):
    # input: [..., C], D dimension + C channel
    # indices: [N, D], long
    # values: [N, C]

    D = indices.shape[-1]
    C = input.shape[-1]
    size = input.shape[:-1]
    stride = stride_from_shape(size)

    assert len(size) == D

    input = input.view(-1, C)  # [HW, C]
    flatten_indices = (indices * torch.tensor(stride, dtype=torch.long, device=indices.device)).sum(-1)  # [N]

    input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)

    return input.view(*size, C)


def scatter_add_nd_with_count(input, count, indices, values, weights=None):
    # input: [..., C], D dimension + C channel
    # count: [..., 1], D dimension
    # indices: [N, D], long
    # values: [N, C]

    D = indices.shape[-1]
    C = input.shape[-1]
    size = input.shape[:-1]
    stride = stride_from_shape(size)

    assert len(size) == D

    input = input.view(-1, C)  # [HW, C]
    count = count.view(-1, 1)

    flatten_indices = (indices * torch.tensor(stride, dtype=torch.long, device=indices.device)).sum(-1)  # [N]

    if weights is None:
        weights = torch.ones_like(values[..., :1])

    input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
    count.scatter_add_(0, flatten_indices.unsqueeze(1), weights)

    return input.view(*size, C), count.view(*size, 1)


def nearest_grid_put_2d(H, W, coords, values, return_count=False):
    # coords: [N, 2], float in [-1, 1]
    # values: [N, C]

    C = values.shape[-1]

    indices = (coords * 0.5 + 0.5) * torch.tensor(
        [H - 1, W - 1], dtype=torch.float32, device=coords.device
    )
    indices = indices.round().long()  # [N, 2]

    result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype)  # [H, W, C]
    count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype)  # [H, W, 1]
    weights = torch.ones_like(values[..., :1])  # [N, 1]

    result, count = scatter_add_nd_with_count(result, count, indices, values, weights)

    if return_count:
        return result, count

    mask = (count.squeeze(-1) > 0)
    result[mask] = result[mask] / count[mask].repeat(1, C)

    return result


def linear_grid_put_2d(H, W, coords, values, return_count=False):
    # coords: [N, 2], float in [-1, 1]
    # values: [N, C]

    C = values.shape[-1]

    indices = (coords * 0.5 + 0.5) * torch.tensor(
        [H - 1, W - 1], dtype=torch.float32, device=coords.device
    )
    indices_00 = indices.floor().long()  # [N, 2]
    indices_00[:, 0].clamp_(0, H - 2)
    indices_00[:, 1].clamp_(0, W - 2)
    indices_01 = indices_00 + torch.tensor(
        [0, 1], dtype=torch.long, device=indices.device
    )
    indices_10 = indices_00 + torch.tensor(
        [1, 0], dtype=torch.long, device=indices.device
    )
    indices_11 = indices_00 + torch.tensor(
        [1, 1], dtype=torch.long, device=indices.device
    )

    h = indices[..., 0] - indices_00[..., 0].float()
    w = indices[..., 1] - indices_00[..., 1].float()
    w_00 = (1 - h) * (1 - w)
    w_01 = (1 - h) * w
    w_10 = h * (1 - w)
    w_11 = h * w

    result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype)  # [H, W, C]
    count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype)  # [H, W, 1]
    weights = torch.ones_like(values[..., :1])  # [N, 1]

    result, count = scatter_add_nd_with_count(result, count, indices_00, values * w_00.unsqueeze(1), weights * w_00.unsqueeze(1))
    result, count = scatter_add_nd_with_count(result, count, indices_01, values * w_01.unsqueeze(1), weights * w_01.unsqueeze(1))
    result, count = scatter_add_nd_with_count(result, count, indices_10, values * w_10.unsqueeze(1), weights * w_10.unsqueeze(1))
    result, count = scatter_add_nd_with_count(result, count, indices_11, values * w_11.unsqueeze(1), weights * w_11.unsqueeze(1))

    if return_count:
        return result, count

    mask = (count.squeeze(-1) > 0)
    result[mask] = result[mask] / count[mask].repeat(1, C)

    return result


def mipmap_linear_grid_put_2d(H, W, coords, values, min_resolution=32, return_count=False):
    # coords: [N, 2], float in [-1, 1]
    # values: [N, C]

    C = values.shape[-1]

    result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype)  # [H, W, C]
    count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype)  # [H, W, 1]

    cur_H, cur_W = H, W

    while min(cur_H, cur_W) > min_resolution:

        # try to fill the holes
        mask = (count.squeeze(-1) == 0)
        if not mask.any():
            break

        cur_result, cur_count = linear_grid_put_2d(cur_H, cur_W, coords, values, return_count=True)
        result[mask] = result[mask] + F.interpolate(cur_result.permute(2, 0, 1).unsqueeze(0).contiguous(), (H, W), mode='bilinear', align_corners=False).squeeze(0).permute(1, 2, 0).contiguous()[mask]
        count[mask] = count[mask] + F.interpolate(cur_count.view(1, 1, cur_H, cur_W), (H, W), mode='bilinear', align_corners=False).view(H, W, 1)[mask]
        cur_H //= 2
        cur_W //= 2

    if return_count:
        return result, count

    mask = (count.squeeze(-1) > 0)
    result[mask] = result[mask] / count[mask].repeat(1, C)

    return result


def nearest_grid_put_3d(H, W, D, coords, values, return_count=False):
    # coords: [N, 3], float in [-1, 1]
    # values: [N, C]

    C = values.shape[-1]

    indices = (coords * 0.5 + 0.5) * torch.tensor(
        [H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device
    )
    indices = indices.round().long()  # [N, 3]

    result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype)  # [H, W, D, C]
    count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype)  # [H, W, D, 1]
    weights = torch.ones_like(values[..., :1])  # [N, 1]

    result, count = scatter_add_nd_with_count(result, count, indices, values, weights)

    if return_count:
        return result, count

    mask = (count.squeeze(-1) > 0)
    result[mask] = result[mask] / count[mask].repeat(1, C)

    return result


def linear_grid_put_3d(H, W, D, coords, values, return_count=False):
    # coords: [N, 3], float in [-1, 1]
    # values: [N, C]

    C = values.shape[-1]

    indices = (coords * 0.5 + 0.5) * torch.tensor(
        [H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device
    )
    indices_000 = indices.floor().long()  # [N, 3]
    indices_000[:, 0].clamp_(0, H - 2)
    indices_000[:, 1].clamp_(0, W - 2)
    indices_000[:, 2].clamp_(0, D - 2)

    indices_001 = indices_000 + torch.tensor([0, 0, 1], dtype=torch.long, device=indices.device)
    indices_010 = indices_000 + torch.tensor([0, 1, 0], dtype=torch.long, device=indices.device)
    indices_011 = indices_000 + torch.tensor([0, 1, 1], dtype=torch.long, device=indices.device)
    indices_100 = indices_000 + torch.tensor([1, 0, 0], dtype=torch.long, device=indices.device)
    indices_101 = indices_000 + torch.tensor([1, 0, 1], dtype=torch.long, device=indices.device)
    indices_110 = indices_000 + torch.tensor([1, 1, 0], dtype=torch.long, device=indices.device)
    indices_111 = indices_000 + torch.tensor([1, 1, 1], dtype=torch.long, device=indices.device)

    h = indices[..., 0] - indices_000[..., 0].float()
    w = indices[..., 1] - indices_000[..., 1].float()
    d = indices[..., 2] - indices_000[..., 2].float()

    w_000 = (1 - h) * (1 - w) * (1 - d)
    w_001 = (1 - h) * w * (1 - d)
    w_010 = h * (1 - w) * (1 - d)
    w_011 = h * w * (1 - d)
    w_100 = (1 - h) * (1 - w) * d
    w_101 = (1 - h) * w * d
    w_110 = h * (1 - w) * d
    w_111 = h * w * d

    result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype)  # [H, W, D, C]
    count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype)  # [H, W, D, 1]
    weights = torch.ones_like(values[..., :1])  # [N, 1]

    result, count = scatter_add_nd_with_count(result, count, indices_000, values * w_000.unsqueeze(1), weights * w_000.unsqueeze(1))
    result, count = scatter_add_nd_with_count(result, count, indices_001, values * w_001.unsqueeze(1), weights * w_001.unsqueeze(1))
    result, count = scatter_add_nd_with_count(result, count, indices_010, values * w_010.unsqueeze(1), weights * w_010.unsqueeze(1))
    result, count = scatter_add_nd_with_count(result, count, indices_011, values * w_011.unsqueeze(1), weights * w_011.unsqueeze(1))
    result, count = scatter_add_nd_with_count(result, count, indices_100, values * w_100.unsqueeze(1), weights * w_100.unsqueeze(1))
    result, count = scatter_add_nd_with_count(result, count, indices_101, values * w_101.unsqueeze(1), weights * w_101.unsqueeze(1))
    result, count = scatter_add_nd_with_count(result, count, indices_110, values * w_110.unsqueeze(1), weights * w_110.unsqueeze(1))
    result, count = scatter_add_nd_with_count(result, count, indices_111, values * w_111.unsqueeze(1), weights * w_111.unsqueeze(1))

    if return_count:
        return result, count

    mask = (count.squeeze(-1) > 0)
    result[mask] = result[mask] / count[mask].repeat(1, C)

    return result


def mipmap_linear_grid_put_3d(H, W, D, coords, values, min_resolution=32, return_count=False):
    # coords: [N, 3], float in [-1, 1]
    # values: [N, C]

    C = values.shape[-1]

    result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype)  # [H, W, D, C]
    count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype)  # [H, W, D, 1]
    cur_H, cur_W, cur_D = H, W, D

    while min(min(cur_H, cur_W), cur_D) > min_resolution:

        # try to fill the holes
        mask = (count.squeeze(-1) == 0)
        if not mask.any():
            break

        cur_result, cur_count = linear_grid_put_3d(cur_H, cur_W, cur_D, coords, values, return_count=True)
        result[mask] = result[mask] + F.interpolate(cur_result.permute(3, 0, 1, 2).unsqueeze(0).contiguous(), (H, W, D), mode='trilinear', align_corners=False).squeeze(0).permute(1, 2, 3, 0).contiguous()[mask]
        count[mask] = count[mask] + F.interpolate(cur_count.view(1, 1, cur_H, cur_W, cur_D), (H, W, D), mode='trilinear', align_corners=False).view(H, W, D, 1)[mask]
        cur_H //= 2
        cur_W //= 2
        cur_D //= 2

    if return_count:
        return result, count

    mask = (count.squeeze(-1) > 0)
    result[mask] = result[mask] / count[mask].repeat(1, C)

    return result


def grid_put(shape, coords, values, mode='linear-mipmap', min_resolution=32, return_raw=False):
    # shape: [D], list/tuple
    # coords: [N, D], float in [-1, 1]
    # values: [N, C]

    D = len(shape)
    assert D in [2, 3], f'only support D == 2 or 3, but got D == {D}'

    if mode == 'nearest':
        if D == 2:
            return nearest_grid_put_2d(*shape, coords, values, return_raw)
        else:
            return nearest_grid_put_3d(*shape, coords, values, return_raw)
    elif mode == 'linear':
        if D == 2:
            return linear_grid_put_2d(*shape, coords, values, return_raw)
        else:
            return linear_grid_put_3d(*shape, coords, values, return_raw)
    elif mode == 'linear-mipmap':
        if D == 2:
            return mipmap_linear_grid_put_2d(*shape, coords, values, min_resolution, return_raw)
        else:
            return mipmap_linear_grid_put_3d(*shape, coords, values, min_resolution, return_raw)
    else:
        raise NotImplementedError(f"got mode {mode}")
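
A quick usage sketch for grid_put: splat a cloud of colored points into a 128x128 image, letting the mipmap fallback fill holes in sparsely covered cells. The tensors here are random placeholders, and the import path assumes the file lives at sparseags/mesh_utils/grid_put.py as in this commit.

import torch
from sparseags.mesh_utils.grid_put import grid_put

coords = torch.rand(2048, 2) * 2 - 1   # [N, 2] float coords in [-1, 1]
values = torch.rand(2048, 3)           # [N, 3] per-point RGB
image = grid_put((128, 128), coords, values, mode='linear-mipmap', min_resolution=32)
print(image.shape)                     # torch.Size([128, 128, 3])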
sparseags/mesh_utils/mesh.py ADDED
@@ -0,0 +1,638 @@
import os
import cv2
import torch
import trimesh
import numpy as np


def dot(x, y):
    return torch.sum(x * y, -1, keepdim=True)


def length(x, eps=1e-20):
    return torch.sqrt(torch.clamp(dot(x, x), min=eps))


def safe_normalize(x, eps=1e-20):
    return x / length(x, eps)


class Mesh:
    def __init__(
        self,
        v=None,
        f=None,
        vn=None,
        fn=None,
        vt=None,
        ft=None,
        albedo=None,
        vc=None,  # vertex color
        device=None,
    ):
        self.device = device
        self.v = v
        self.vn = vn
        self.vt = vt
        self.f = f
        self.fn = fn
        self.ft = ft
        # only support a single albedo
        self.albedo = albedo
        # support vertex color if there is no albedo
        self.vc = vc

        self.ori_center = 0
        self.ori_scale = 1

    @classmethod
    def load(cls, path=None, resize=True, renormal=True, retex=False, front_dir='+z', **kwargs):
        # assume init with kwargs
        if path is None:
            mesh = cls(**kwargs)
        # obj supports face uv
        elif path.endswith(".obj"):
            mesh = cls.load_obj(path, **kwargs)
        # trimesh only supports vertex uv, but can load more formats
        else:
            mesh = cls.load_trimesh(path, **kwargs)

        print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}")
        # auto-normalize
        if resize:
            mesh.auto_size()
        # auto-fix normal
        if renormal or mesh.vn is None:
            mesh.auto_normal()
            print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}")
        # auto-fix texcoords
        if retex or (mesh.albedo is not None and mesh.vt is None):
            mesh.auto_uv(cache_path=path)
            print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}")

        # rotate front dir to +z
        if front_dir != "+z":
            # axis switch
            if "-z" in front_dir:
                T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32)
            elif "+x" in front_dir:
                T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
            elif "-x" in front_dir:
                T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
            elif "+y" in front_dir:
                T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
            elif "-y" in front_dir:
                T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
            else:
                T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
            # rotation (how many 90 degrees)
            if '1' in front_dir:
                T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
            elif '2' in front_dir:
                T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
            elif '3' in front_dir:
                T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
            mesh.v @= T
            mesh.vn @= T

        return mesh

    # load from obj file
    @classmethod
    def load_obj(cls, path, albedo_path=None, device=None, enable_dino=False):
        assert os.path.splitext(path)[-1] == ".obj"

        mesh = cls()

        # device
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        mesh.device = device

        # load obj
        with open(path, "r") as f:
            lines = f.readlines()

        def parse_f_v(fv):
            # pass in a vertex term of a face, return (v, vt, vn) (-1 if not provided)
            # supported forms:
            # f v1 v2 v3
            # f v1/vt1 v2/vt2 v3/vt3
            # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3
            # f v1//vn1 v2//vn2 v3//vn3
            xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")]
            xs.extend([-1] * (3 - len(xs)))
            return xs[0], xs[1], xs[2]

        # NOTE: we ignore usemtl, and assume the mesh ONLY uses one material (first in mtl)
        vertices, texcoords, normals = [], [], []
        faces, tfaces, nfaces = [], [], []
        mtl_path = None
        feature_path = None

        for line in lines:
            split_line = line.split()
            # empty line
            if len(split_line) == 0:
                continue
            prefix = split_line[0].lower()
            # mtllib
            if prefix == "mtllib":
                mtl_path = split_line[1]
            # usemtl
            elif prefix == "usemtl":
                pass  # ignored
            # v/vn/vt
            elif prefix == "v":
                vertices.append([float(v) for v in split_line[1:]])
            elif prefix == "vn":
                normals.append([float(v) for v in split_line[1:]])
            elif prefix == "vt":
                val = [float(v) for v in split_line[1:]]
                texcoords.append([val[0], 1.0 - val[1]])
            elif prefix == "f":
                vs = split_line[1:]
                nv = len(vs)
                v0, t0, n0 = parse_f_v(vs[0])
                for i in range(nv - 2):  # triangulate (assume vertices are ordered)
                    v1, t1, n1 = parse_f_v(vs[i + 1])
                    v2, t2, n2 = parse_f_v(vs[i + 2])
                    faces.append([v0, v1, v2])
                    tfaces.append([t0, t1, t2])
                    nfaces.append([n0, n1, n2])

        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
        mesh.vt = (
            torch.tensor(texcoords, dtype=torch.float32, device=device)
            if len(texcoords) > 0
            else None
        )
        mesh.vn = (
            torch.tensor(normals, dtype=torch.float32, device=device)
            if len(normals) > 0
            else None
        )

        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
        mesh.ft = (
            torch.tensor(tfaces, dtype=torch.int32, device=device)
            if len(texcoords) > 0
            else None
        )
        mesh.fn = (
            torch.tensor(nfaces, dtype=torch.int32, device=device)
            if len(normals) > 0
            else None
        )

        # see if there is vertex color
        use_vertex_color = False
        if mesh.v.shape[1] == 6:
            use_vertex_color = True
            mesh.vc = mesh.v[:, 3:]
            mesh.v = mesh.v[:, :3]
            print(f"[load_obj] use vertex color: {mesh.vc.shape}")

        # try to load texture image
        if not use_vertex_color:
            # try to retrieve mtl file
            mtl_path_candidates = []
            if mtl_path is not None:
                mtl_path_candidates.append(mtl_path)
                mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path))
            mtl_path_candidates.append(path.replace(".obj", ".mtl"))

            mtl_path = None
            for candidate in mtl_path_candidates:
                if os.path.exists(candidate):
                    mtl_path = candidate
                    break

            # if albedo_path is not provided, try to retrieve it from mtl
            if mtl_path is not None and albedo_path is None:
                with open(mtl_path, "r") as f:
                    lines = f.readlines()
                for line in lines:
                    split_line = line.split()
                    # empty line
                    if len(split_line) == 0:
                        continue
                    prefix = split_line[0]
                    # NOTE: simply use the first map_Kd as albedo!
                    if "map_Kd" in prefix:
                        albedo_path = os.path.join(os.path.dirname(path), split_line[1])
                        print(f"[load_obj] use texture from: {albedo_path}")
                        # break
                    if "map_Ft" in prefix:
                        feature_path = os.path.join(os.path.dirname(path), split_line[1])
                        print(f"[load_obj] use feature from: {feature_path}")
                        break

            # still not found albedo_path, or the path doesn't exist
            if albedo_path is None or not os.path.exists(albedo_path):
                # init an empty texture
                print(f"[load_obj] init empty albedo!")
                # albedo = np.random.rand(1024, 1024, 3).astype(np.float32)
                albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])  # default color
            else:
                albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)
                albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)
                albedo = albedo.astype(np.float32) / 255
                print(f"[load_obj] load texture: {albedo.shape}")

            # import matplotlib.pyplot as plt
            # plt.imshow(albedo)
            # plt.show()
            if enable_dino and feature_path is not None and os.path.exists(feature_path):
                feature = torch.load(feature_path).to(device)
                mesh.feature = feature
                print(f"[load_obj] load feature: {feature.shape}")

            mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)

        return mesh

    @classmethod
    def load_trimesh(cls, path, device=None, enable_dino=False):
        mesh = cls()

        # device
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        mesh.device = device

        # use trimesh to load ply/glb, assume only has one single RootMesh...
        _data = trimesh.load(path)
        if isinstance(_data, trimesh.Scene):
            if len(_data.geometry) == 1:
                _mesh = list(_data.geometry.values())[0]
            else:
                # manual concat, will lose texture
                _concat = []
                for g in _data.geometry.values():
                    if isinstance(g, trimesh.Trimesh):
                        _concat.append(g)
                _mesh = trimesh.util.concatenate(_concat)
        else:
            _mesh = _data

        if _mesh.visual.kind == 'vertex':
            vertex_colors = _mesh.visual.vertex_colors
            vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255
            mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device)
            print(f"[load_trimesh] use vertex color: {mesh.vc.shape}")
        elif _mesh.visual.kind == 'texture':
            _material = _mesh.visual.material
            if isinstance(_material, trimesh.visual.material.PBRMaterial):
                texture = np.array(_material.baseColorTexture).astype(np.float32) / 255
            elif isinstance(_material, trimesh.visual.material.SimpleMaterial):
                texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255
            else:
                raise NotImplementedError(f"material type {type(_material)} not supported!")
            mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)
            print(f"[load_trimesh] load texture: {texture.shape}")
        else:
            texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])
            mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)
            print(f"[load_trimesh] failed to load texture.")

        vertices = _mesh.vertices

        try:
            texcoords = _mesh.visual.uv
            texcoords[:, 1] = 1 - texcoords[:, 1]
        except Exception as e:
            texcoords = None

        try:
            normals = _mesh.vertex_normals
        except Exception as e:
            normals = None

        # trimesh only supports vertex uv...
        faces = tfaces = nfaces = _mesh.faces

        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
        mesh.vt = (
            torch.tensor(texcoords, dtype=torch.float32, device=device)
            if texcoords is not None
            else None
        )
        mesh.vn = (
            torch.tensor(normals, dtype=torch.float32, device=device)
            if normals is not None
            else None
        )

        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
        mesh.ft = (
            torch.tensor(tfaces, dtype=torch.int32, device=device)
            if texcoords is not None
            else None
        )
        mesh.fn = (
            torch.tensor(nfaces, dtype=torch.int32, device=device)
            if normals is not None
            else None
        )

        return mesh

    # aabb
    def aabb(self):
        return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values

    # unit size
    @torch.no_grad()
    def auto_size(self):
        vmin, vmax = self.aabb()
        self.ori_center = (vmax + vmin) / 2
        self.ori_scale = 1.2 / torch.max(vmax - vmin).item()
        self.v = (self.v - self.ori_center) * self.ori_scale

    def auto_normal(self):
        i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()
        v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]

        face_normals = torch.cross(v1 - v0, v2 - v0)

        # splat face normals to vertices
        vn = torch.zeros_like(self.v)
        vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
        vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
        vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

        # normalize, replace zero (degenerate) normals with some default value
        vn = torch.where(
            dot(vn, vn) > 1e-20,
            vn,
            torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),
        )
        vn = safe_normalize(vn)

        self.vn = vn
        self.fn = self.f

    def auto_uv(self, cache_path=None, vmap=True):
        # try to load cache
        if cache_path is not None:
            cache_path = os.path.splitext(cache_path)[0] + "_uv.npz"
        if cache_path is not None and os.path.exists(cache_path):
            data = np.load(cache_path)
            vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"]
        else:
            import xatlas

            v_np = self.v.detach().cpu().numpy()
            f_np = self.f.detach().int().cpu().numpy()
            atlas = xatlas.Atlas()
            atlas.add_mesh(v_np, f_np)
            chart_options = xatlas.ChartOptions()
            # chart_options.max_iterations = 4
            atlas.generate(chart_options=chart_options)
            vmapping, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]

            # save to cache
            if cache_path is not None:
                np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping)

        vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)
        ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)
        self.vt = vt
        self.ft = ft

        if vmap:
            # remap v/f to vt/ft, so each v corresponds to a unique vt (necessary for gltf)
            vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device)
            self.align_v_to_vt(vmapping)

    def align_v_to_vt(self, vmapping=None):
        # remap v/f and vn/fn to vt/ft
        if vmapping is None:
            ft = self.ft.view(-1).long()
            f = self.f.view(-1).long()
            vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device)
            vmapping[ft] = f  # scatter, randomly choose one if index is not unique

        self.v = self.v[vmapping]
        self.f = self.ft
        # assume fn == f
        if self.vn is not None:
            self.vn = self.vn[vmapping]
            self.fn = self.ft

    def to(self, device):
        self.device = device
        for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo"]:
            tensor = getattr(self, name)
            if tensor is not None:
                setattr(self, name, tensor.to(device))
        return self

    def write(self, path, enable_dino=False):
        if path.endswith(".ply"):
            self.write_ply(path)
        elif path.endswith(".obj"):
            self.write_obj(path, enable_dino)
        elif path.endswith(".glb") or path.endswith(".gltf"):
            self.write_glb(path)
        else:
            raise NotImplementedError(f"format {path} not supported!")

    # write to ply file (only geom)
    def write_ply(self, path):

        v_np = self.v.detach().cpu().numpy()
        f_np = self.f.detach().cpu().numpy()

        _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np)
        _mesh.export(path)

    # write to gltf/glb file (geom + texture)
    def write_glb(self, path):

        assert self.vn is not None and self.vt is not None  # should be improved to support export without texture...

        # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]
        if self.v.shape[0] != self.vt.shape[0]:
            self.align_v_to_vt()

        # assume f == fn == ft

        import pygltflib

        f_np = self.f.detach().cpu().numpy().astype(np.uint32)
        v_np = self.v.detach().cpu().numpy().astype(np.float32)
        # vn_np = self.vn.detach().cpu().numpy().astype(np.float32)
        vt_np = self.vt.detach().cpu().numpy().astype(np.float32)

        albedo = self.albedo.detach().cpu().numpy()
        albedo = (albedo * 255).astype(np.uint8)
        albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)

        f_np_blob = f_np.flatten().tobytes()
        v_np_blob = v_np.tobytes()
        # vn_np_blob = vn_np.tobytes()
        vt_np_blob = vt_np.tobytes()
        albedo_blob = cv2.imencode('.png', albedo)[1].tobytes()

        gltf = pygltflib.GLTF2(
            scene=0,
            scenes=[pygltflib.Scene(nodes=[0])],
            nodes=[pygltflib.Node(mesh=0)],
            meshes=[pygltflib.Mesh(primitives=[
                pygltflib.Primitive(
                    # indices to accessors (0 is triangles)
                    attributes=pygltflib.Attributes(
                        POSITION=1, TEXCOORD_0=2,
                    ),
                    indices=0, material=0,
                )
            ])],
            materials=[
                pygltflib.Material(
                    pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(
                        baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0),
                        metallicFactor=0.0,
                        roughnessFactor=1.0,
                    ),
                    alphaCutoff=0,
                    doubleSided=True,
                )
            ],
            textures=[
                pygltflib.Texture(sampler=0, source=0),
            ],
            samplers=[
                pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT),
            ],
            images=[
                # use embedded (buffer) image
                pygltflib.Image(bufferView=3, mimeType="image/png"),
            ],
            buffers=[
                pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob) + len(vt_np_blob) + len(albedo_blob))
            ],
            # buffer views (based on dtype)
            bufferViews=[
                # triangles; as flattened (element) array
                pygltflib.BufferView(
                    buffer=0,
                    byteLength=len(f_np_blob),
                    target=pygltflib.ELEMENT_ARRAY_BUFFER,  # GL_ELEMENT_ARRAY_BUFFER (34963)
                ),
                # positions; as vec3 array
                pygltflib.BufferView(
                    buffer=0,
                    byteOffset=len(f_np_blob),
                    byteLength=len(v_np_blob),
                    byteStride=12,  # vec3
                    target=pygltflib.ARRAY_BUFFER,  # GL_ARRAY_BUFFER (34962)
                ),
                # texcoords; as vec2 array
                pygltflib.BufferView(
                    buffer=0,
                    byteOffset=len(f_np_blob) + len(v_np_blob),
                    byteLength=len(vt_np_blob),
                    byteStride=8,  # vec2
                    target=pygltflib.ARRAY_BUFFER,
                ),
                # texture; no target
                pygltflib.BufferView(
                    buffer=0,
                    byteOffset=len(f_np_blob) + len(v_np_blob) + len(vt_np_blob),
                    byteLength=len(albedo_blob),
                ),
            ],
            accessors=[
                # 0 = triangles
                pygltflib.Accessor(
                    bufferView=0,
                    componentType=pygltflib.UNSIGNED_INT,  # GL_UNSIGNED_INT (5125)
                    count=f_np.size,
                    type=pygltflib.SCALAR,
                    max=[int(f_np.max())],
                    min=[int(f_np.min())],
                ),
                # 1 = positions
                pygltflib.Accessor(
                    bufferView=1,
                    componentType=pygltflib.FLOAT,  # GL_FLOAT (5126)
                    count=len(v_np),
                    type=pygltflib.VEC3,
                    max=v_np.max(axis=0).tolist(),
                    min=v_np.min(axis=0).tolist(),
                ),
                # 2 = texcoords
                pygltflib.Accessor(
                    bufferView=2,
                    componentType=pygltflib.FLOAT,
                    count=len(vt_np),
                    type=pygltflib.VEC2,
                    max=vt_np.max(axis=0).tolist(),
                    min=vt_np.min(axis=0).tolist(),
                ),
            ],
        )

        # set actual data
        gltf.set_binary_blob(f_np_blob + v_np_blob + vt_np_blob + albedo_blob)

        # glb = b"".join(gltf.save_to_bytes())
        gltf.save(path)

    # write to obj file (geom + texture)
    def write_obj(self, path, enable_dino=False):

        mtl_path = path.replace(".obj", ".mtl")
        albedo_path = path.replace(".obj", "_albedo.png")
        feature_path = path.replace(".obj", "_feature.pt")

        v_np = self.v.detach().cpu().numpy()
        vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None
        vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None
        f_np = self.f.detach().cpu().numpy()
        ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None
        fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None

        with open(path, "w") as fp:
            fp.write(f"mtllib {os.path.basename(mtl_path)} \n")

            for v in v_np:
                fp.write(f"v {v[0]} {v[1]} {v[2]} \n")

            if vt_np is not None:
                for v in vt_np:
                    fp.write(f"vt {v[0]} {1 - v[1]} \n")

            if vn_np is not None:
                for v in vn_np:
                    fp.write(f"vn {v[0]} {v[1]} {v[2]} \n")

            fp.write(f"usemtl defaultMat \n")
            for i in range(len(f_np)):
                fp.write(
                    f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \
                      {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \
                      {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n'
                )

        with open(mtl_path, "w") as fp:
            fp.write(f"newmtl defaultMat \n")
            fp.write(f"Ka 1 1 1 \n")
            fp.write(f"Kd 1 1 1 \n")
            fp.write(f"Ks 0 0 0 \n")
            fp.write(f"Tr 1 \n")
            fp.write(f"illum 1 \n")
            fp.write(f"Ns 0 \n")
            fp.write(f"map_Kd {os.path.basename(albedo_path)} \n")
            if enable_dino:
                fp.write(f"map_Ft {os.path.basename(feature_path)} \n")

        albedo = self.albedo.detach().cpu().numpy()
        albedo = (albedo * 255).astype(np.uint8)
        cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))

        if enable_dino:
            feature = self.feature.detach().cpu()
            torch.save(feature, feature_path)
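
A minimal usage sketch for the Mesh class: load any trimesh-supported file, let the auto-fix steps normalize the size, normals and UVs (xatlas is used to unwrap if UVs are missing), and write the result back as an .obj with its baked albedo. The file names are placeholders.

import torch
from sparseags.mesh_utils.mesh import Mesh

mesh = Mesh.load("input.glb", resize=True)   # auto_size + auto_normal (+ auto_uv if needed)
mesh.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
mesh.write("output.obj")                     # also writes output.mtl and output_albedo.png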
sparseags/mesh_utils/mesh_renderer.py ADDED
@@ -0,0 +1,268 @@
import os
import math
import cv2
import trimesh
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import nvdiffrast.torch as dr
from sparseags.mesh_utils.mesh import Mesh, safe_normalize


def scale_img_nhwc(x, size, mag='bilinear', min='bilinear'):
    assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), "Trying to magnify image in one dimension and minify in the other"
    y = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
    if x.shape[1] > size[0] and x.shape[2] > size[1]:  # Minification, previous size was bigger
        y = torch.nn.functional.interpolate(y, size, mode=min)
    else:  # Magnification
        if mag == 'bilinear' or mag == 'bicubic':
            y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True)
        else:
            y = torch.nn.functional.interpolate(y, size, mode=mag)
    return y.permute(0, 2, 3, 1).contiguous()  # NCHW -> NHWC

def scale_img_hwc(x, size, mag='bilinear', min='bilinear'):
    return scale_img_nhwc(x[None, ...], size, mag, min)[0]

def scale_img_nhw(x, size, mag='bilinear', min='bilinear'):
    return scale_img_nhwc(x[..., None], size, mag, min)[..., 0]

def scale_img_hw(x, size, mag='bilinear', min='bilinear'):
    return scale_img_nhwc(x[None, ..., None], size, mag, min)[0, ..., 0]

def trunc_rev_sigmoid(x, eps=1e-6):
    x = x.clamp(eps, 1 - eps)
    return torch.log(x / (1 - x))

def make_divisible(x, m=8):
    return int(math.ceil(x / m) * m)

class Renderer(nn.Module):
    def __init__(self, opt):

        super().__init__()

        self.opt = opt
        self.enable_dino = self.opt.lambda_dino > 0

        self.mesh = Mesh.load(self.opt.mesh, resize=False, enable_dino=self.enable_dino)

        if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
            self.glctx = dr.RasterizeGLContext()
        else:
            self.glctx = dr.RasterizeCudaContext()

        self.v_offsets = torch.zeros_like(self.mesh.v)
        self.raw_albedo = trunc_rev_sigmoid(self.mesh.albedo)

        # extract trainable parameters
        if opt.trainable_texture:
            self.v_offsets = nn.Parameter(self.v_offsets)
            self.raw_albedo = nn.Parameter(self.raw_albedo)

        if self.enable_dino:
            self.raw_feature = nn.Parameter(self.mesh.feature)


    def get_params(self):

        params = [
            {'params': self.raw_albedo, 'lr': self.opt.texture_lr},
        ]

        if self.enable_dino:
            params.append({'params': self.raw_feature, 'lr': self.opt.texture_lr})

        if self.opt.train_geo:
            params.append({'params': self.v_offsets, 'lr': self.opt.geom_lr})

        return params

    @torch.no_grad()
    def export_mesh(self, save_path):
        self.mesh.v = (self.mesh.v + self.v_offsets).detach()
        self.mesh.albedo = torch.sigmoid(self.raw_albedo.detach())
        if self.enable_dino:
            self.mesh.feature = self.raw_feature.detach()
        self.mesh.write(save_path, self.enable_dino)


    def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1, texture_filter='linear-mipmap-linear'):

        # do super-sampling
        if ssaa != 1:
            h = make_divisible(h0 * ssaa, 8)
            w = make_divisible(w0 * ssaa, 8)
        else:
            h, w = h0, w0

        results = {}

        # get v
        if self.opt.train_geo:
            v = self.mesh.v + self.v_offsets  # [N, 3]
        else:
            v = self.mesh.v

        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
        proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)

        # get v_clip and render rgb
        v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
        v_clip = v_cam @ proj.T

        rast, rast_db = dr.rasterize(self.glctx, v_clip, self.mesh.f, (h, w))

        alpha = (rast[0, ..., 3:] > 0).float()
        depth, _ = dr.interpolate(-v_cam[..., [2]], rast, self.mesh.f)  # [1, H, W, 1]
        depth = depth.squeeze(0)  # [H, W, 1]

        texc, texc_db = dr.interpolate(self.mesh.vt.unsqueeze(0).contiguous(), rast, self.mesh.ft, rast_db=rast_db, diff_attrs='all')
        albedo = dr.texture(self.raw_albedo.unsqueeze(0), texc, uv_da=texc_db, filter_mode=texture_filter)  # [1, H, W, 3]
        albedo = torch.sigmoid(albedo)
        if self.enable_dino:
            # NOTE: backward error when using filter_mode='linear-mipmap-linear'
            feature = dr.texture(self.raw_feature.unsqueeze(0), texc, uv_da=texc_db, filter_mode='linear')
            # feature = torch.sigmoid(feature)
        # get vn and render normal
        if self.opt.train_geo:
            i0, i1, i2 = self.mesh.f[:, 0].long(), self.mesh.f[:, 1].long(), self.mesh.f[:, 2].long()
            v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]

            face_normals = torch.cross(v1 - v0, v2 - v0)
            face_normals = safe_normalize(face_normals)

            vn = torch.zeros_like(v)
            vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
            vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
            vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

            vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))
        else:
            vn = self.mesh.vn

        normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, self.mesh.fn)
        normal = safe_normalize(normal[0])

        # rotated normal (where [0, 0, 1] always faces camera)
        rot_normal = normal @ pose[:3, :3]
        viewcos = rot_normal[..., [2]]

        # antialias
        albedo = dr.antialias(albedo, rast, v_clip, self.mesh.f).squeeze(0)  # [H, W, 3]
        albedo = alpha * albedo + (1 - alpha) * bg_color

        if self.enable_dino:
            feature = dr.antialias(feature, rast, v_clip, self.mesh.f).squeeze(0)  # [H, W, C]
            feature = alpha * feature + (1 - alpha) * bg_color

        # ssaa
        if ssaa != 1:
            albedo = scale_img_hwc(albedo, (h0, w0))
            alpha = scale_img_hwc(alpha, (h0, w0))
            depth = scale_img_hwc(depth, (h0, w0))
            normal = scale_img_hwc(normal, (h0, w0))
            viewcos = scale_img_hwc(viewcos, (h0, w0))
            if self.enable_dino:
                feature = scale_img_hwc(feature, (h0, w0))

        results['image'] = albedo.clamp(0, 1)
        results['alpha'] = alpha
        results['depth'] = depth
        results['normal'] = (normal + 1) / 2
        results['viewcos'] = viewcos
        results['feature'] = feature if self.enable_dino else None  # [H, W, 384]

        return results


    def render_batch(self, pose, proj, h0, w0, ssaa=1, bg_color=1, texture_filter='linear-mipmap-linear'):

        # do super-sampling
        if ssaa != 1:
            h = make_divisible(h0 * ssaa, 8)
            w = make_divisible(w0 * ssaa, 8)
        else:
            h, w = h0, w0

        results = {}

        # get v
        if self.opt.train_geo:
            v = self.mesh.v + self.v_offsets  # [N, 3]
        else:
            v = self.mesh.v

        bs = pose.shape[0]
        pose = pose.to(v.device)
        proj = proj.to(v.device).transpose(1, 2)

        # get v_clip and render rgb
        v_cam = torch.bmm(F.pad(v, pad=(0, 1), mode='constant', value=1.0).expand(bs, -1, -1), torch.linalg.inv(pose).transpose(1, 2)).float()
        v_clip = torch.bmm(v_cam, proj)

        rast, rast_db = dr.rasterize(self.glctx, v_clip, self.mesh.f, (h, w))

        alpha = (rast[..., 3:] > 0).float()
        depth, _ = dr.interpolate(-v_cam[..., [2]], rast, self.mesh.f)  # [B, H, W, 1]

        texc, texc_db = dr.interpolate(self.mesh.vt.expand(bs, -1, -1).contiguous(), rast, self.mesh.ft, rast_db=rast_db, diff_attrs='all')
        albedo = dr.texture(self.raw_albedo.detach().unsqueeze(0).contiguous(), texc, uv_da=texc_db, filter_mode=texture_filter)  # [B, H, W, 3]
        albedo = torch.sigmoid(albedo)
        if self.enable_dino:
            # NOTE: backward error when using filter_mode='linear-mipmap-linear'
            feature = dr.texture(self.raw_feature.unsqueeze(0), texc, uv_da=texc_db, filter_mode='linear')
            # feature = torch.sigmoid(feature)
        # get vn and render normal
        if self.opt.train_geo:
            i0, i1, i2 = self.mesh.f[:, 0].long(), self.mesh.f[:, 1].long(), self.mesh.f[:, 2].long()
            v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]

            face_normals = torch.cross(v1 - v0, v2 - v0)
            face_normals = safe_normalize(face_normals)

            vn = torch.zeros_like(v)
            vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
            vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
            vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

            vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))
        else:
            vn = self.mesh.vn

        normal, _ = dr.interpolate(vn.expand(bs, -1, -1).contiguous(), rast, self.mesh.fn)
        normal = safe_normalize(normal).reshape(bs, -1, 3)

        # rotated normal (where [0, 0, 1] always faces camera)
        rot_normal = torch.bmm(normal, pose[:, :3, :3]).reshape(bs, h, w, 3)
        viewcos = rot_normal[..., [2]]

        # antialias
        albedo = dr.antialias(albedo, rast, v_clip, self.mesh.f)  # [B, H, W, 3]
        albedo = alpha * albedo + (1 - alpha) * bg_color

        if self.enable_dino:
            feature = dr.antialias(feature, rast, v_clip, self.mesh.f).squeeze(0)  # [B, H, W, C]
            feature = alpha * feature + (1 - alpha) * bg_color

        # ssaa
        if ssaa != 1:
            albedo = scale_img_hwc(albedo, (h0, w0))
            alpha = scale_img_hwc(alpha, (h0, w0))
            depth = scale_img_hwc(depth, (h0, w0))
            normal = scale_img_hwc(normal, (h0, w0))
            viewcos = scale_img_hwc(viewcos, (h0, w0))
            if self.enable_dino:
                feature = scale_img_hwc(feature, (h0, w0))

        results['image'] = albedo.clamp(0, 1)
        results['alpha'] = alpha
        results['depth'] = depth
        results['normal'] = (normal + 1) / 2
        results['viewcos'] = viewcos
        results['feature'] = feature if self.enable_dino else None  # [H, W, 384]

        return results
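A minimal sketch (not part of the commit) of the texture parametrization used above: the albedo is stored as unconstrained logits via trunc_rev_sigmoid and squashed back through sigmoid at render time, which keeps optimization well-behaved near 0 and 1. It assumes only PyTorch; the helper is restated locally so the snippet is self-contained.

import torch

def trunc_rev_sigmoid(x, eps=1e-6):  # restated from mesh_renderer.py above
    x = x.clamp(eps, 1 - eps)        # avoid +/- inf logits at exactly 0 or 1
    return torch.log(x / (1 - x))    # logit(x), the inverse of sigmoid

albedo = torch.rand(4, 4, 3)         # texture values in [0, 1]
logits = trunc_rev_sigmoid(albedo)   # unconstrained trainable parameters
recovered = torch.sigmoid(logits)    # mapped back to [0, 1] at render time
assert torch.allclose(albedo, recovered, atol=1e-5)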
sparseags/mesh_utils/mesh_utils.py
ADDED
@@ -0,0 +1,147 @@
import numpy as np
import pymeshlab as pml


def poisson_mesh_reconstruction(points, normals=None):
    # points/normals: [N, 3] np.ndarray

    import open3d as o3d

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)

    # outlier removal
    pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10)

    # normals
    if normals is None:
        pcd.estimate_normals()
    else:
        pcd.normals = o3d.utility.Vector3dVector(normals[ind])

    # visualize
    o3d.visualization.draw_geometries([pcd], point_show_normal=False)

    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
        pcd, depth=9
    )
    vertices_to_remove = densities < np.quantile(densities, 0.1)
    mesh.remove_vertices_by_mask(vertices_to_remove)

    # visualize
    o3d.visualization.draw_geometries([mesh])

    vertices = np.asarray(mesh.vertices)
    triangles = np.asarray(mesh.triangles)

    print(
        f"[INFO] poisson mesh reconstruction: {points.shape} --> {vertices.shape} / {triangles.shape}"
    )

    return vertices, triangles


def decimate_mesh(
    verts, faces, target, backend="pymeshlab", remesh=False, optimalplacement=True
):
    # optimalplacement: default is True, but for flat meshes it must be False to prevent spike artifacts.

    _ori_vert_shape = verts.shape
    _ori_face_shape = faces.shape

    if backend == "pyfqmr":
        import pyfqmr

        solver = pyfqmr.Simplify()
        solver.setMesh(verts, faces)
        solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)
        verts, faces, normals = solver.getMesh()
    else:
        m = pml.Mesh(verts, faces)
        ms = pml.MeshSet()
        ms.add_mesh(m, "mesh")  # will copy!

        # filters
        # ms.meshing_decimation_clustering(threshold=pml.Percentage(1))
        ms.meshing_decimation_quadric_edge_collapse(
            targetfacenum=int(target), optimalplacement=optimalplacement
        )

        if remesh:
            # ms.apply_coord_taubin_smoothing()
            ms.meshing_isotropic_explicit_remeshing(
                iterations=3, targetlen=pml.Percentage(1)
            )

        # extract mesh
        m = ms.current_mesh()
        verts = m.vertex_matrix()
        faces = m.face_matrix()

    print(
        f"[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}"
    )

    return verts, faces


def clean_mesh(
    verts,
    faces,
    v_pct=1,
    min_f=64,
    min_d=20,
    repair=True,
    remesh=True,
    remesh_size=0.01,
):
    # verts: [N, 3]
    # faces: [N, 3]

    _ori_vert_shape = verts.shape
    _ori_face_shape = faces.shape

    m = pml.Mesh(verts, faces)
    ms = pml.MeshSet()
    ms.add_mesh(m, "mesh")  # will copy!

    # filters
    ms.meshing_remove_unreferenced_vertices()  # verts not referenced by any faces

    if v_pct > 0:
        ms.meshing_merge_close_vertices(
            threshold=pml.Percentage(v_pct)
        )  # 1/10000 of bounding box diagonal

    ms.meshing_remove_duplicate_faces()  # faces defined by the same verts
    ms.meshing_remove_null_faces()  # faces with area == 0

    if min_d > 0:
        ms.meshing_remove_connected_component_by_diameter(
            mincomponentdiag=pml.Percentage(min_d)
        )

    if min_f > 0:
        ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)

    if repair:
        # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)
        ms.meshing_repair_non_manifold_edges(method=0)
        ms.meshing_repair_non_manifold_vertices(vertdispratio=0)

    if remesh:
        # ms.apply_coord_taubin_smoothing()
        ms.meshing_isotropic_explicit_remeshing(
            iterations=3, targetlen=pml.AbsoluteValue(remesh_size)
        )

    # extract mesh
    m = ms.current_mesh()
    verts = m.vertex_matrix()
    faces = m.face_matrix()

    print(
        f"[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}"
    )

    return verts, faces
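A minimal usage sketch for the two pymeshlab-backed helpers above (not part of the commit). It assumes pymeshlab is installed and, purely to fabricate test geometry, trimesh (already imported by mesh_renderer.py); the import path mirrors this file's location, and the remesh/target values are hypothetical.

import numpy as np
import trimesh
from sparseags.mesh_utils.mesh_utils import clean_mesh, decimate_mesh

# fabricate a dense test mesh (~5k faces), then clean and simplify it
sphere = trimesh.creation.icosphere(subdivisions=4)
verts = np.asarray(sphere.vertices)
faces = np.asarray(sphere.faces)

verts, faces = clean_mesh(verts, faces, remesh=True, remesh_size=0.05)
verts, faces = decimate_mesh(verts, faces, target=1000)  # quadric edge collapse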
sparseags/render_utils/gs_renderer.py
ADDED
@@ -0,0 +1,1102 @@
import os
import math
import numpy as np
from typing import NamedTuple
from plyfile import PlyData, PlyElement

import torch
from torch import nn

from liegroups.torch import SE3
from simple_knn._C import distCUDA2

from sparseags.sh_utils import eval_sh, SH2RGB, RGB2SH
from sparseags.mesh_utils.mesh import Mesh
from sparseags.mesh_utils.mesh_utils import decimate_mesh, clean_mesh
from sparseags.cam_utils import sample_points_from_voxel

import kiui


def inverse_sigmoid(x):
    return torch.log(x / (1 - x))


def get_expon_lr_func(
    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):

    def helper(step):
        if lr_init == lr_final:
            # constant lr, ignore other params
            return lr_init
        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
            # Disable this parameter
            return 0.0
        if lr_delay_steps > 0:
            # A kind of reverse cosine decay.
            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
                0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
            )
        else:
            delay_rate = 1.0
        t = np.clip(step / max_steps, 0, 1)
        log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
        return delay_rate * log_lerp

    return helper
+
def strip_lowerdiag(L):
|
51 |
+
uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
|
52 |
+
|
53 |
+
uncertainty[:, 0] = L[:, 0, 0]
|
54 |
+
uncertainty[:, 1] = L[:, 0, 1]
|
55 |
+
uncertainty[:, 2] = L[:, 0, 2]
|
56 |
+
uncertainty[:, 3] = L[:, 1, 1]
|
57 |
+
uncertainty[:, 4] = L[:, 1, 2]
|
58 |
+
uncertainty[:, 5] = L[:, 2, 2]
|
59 |
+
return uncertainty
|
60 |
+
|
61 |
+
def strip_symmetric(sym):
|
62 |
+
return strip_lowerdiag(sym)
|
63 |
+
|
64 |
+
|
65 |
+
def gaussian_3d_coeff(xyzs, covs):
|
66 |
+
# xyzs: [N, 3]
|
67 |
+
# covs: [N, 6]
|
68 |
+
x, y, z = xyzs[:, 0], xyzs[:, 1], xyzs[:, 2]
|
69 |
+
a, b, c, d, e, f = covs[:, 0], covs[:, 1], covs[:, 2], covs[:, 3], covs[:, 4], covs[:, 5]
|
70 |
+
|
71 |
+
# eps must be small enough !!!
|
72 |
+
inv_det = 1 / (a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24)
|
73 |
+
inv_a = (d * f - e**2) * inv_det
|
74 |
+
inv_b = (e * c - b * f) * inv_det
|
75 |
+
inv_c = (e * b - c * d) * inv_det
|
76 |
+
inv_d = (a * f - c**2) * inv_det
|
77 |
+
inv_e = (b * c - e * a) * inv_det
|
78 |
+
inv_f = (a * d - b**2) * inv_det
|
79 |
+
|
80 |
+
power = -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f) - x * y * inv_b - x * z * inv_c - y * z * inv_e
|
81 |
+
|
82 |
+
power[power > 0] = -1e10 # abnormal values... make weights 0
|
83 |
+
|
84 |
+
return torch.exp(power)
|
85 |
+
|
86 |
+
|
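# --- Illustration (not in the original commit): gaussian_3d_coeff evaluates
# the unnormalized density exp(-0.5 * d^T Sigma^{-1} d) given offsets d and
# covariances packed as the six upper-triangular entries
# (xx, xy, xz, yy, yz, zz), matching strip_lowerdiag's layout above.
def _demo_gaussian_3d_coeff():
    sigma2 = 0.1                                                # isotropic variance
    offsets = torch.tensor([[0.0, 0.0, 0.0], [0.3, 0.0, 0.0]])  # point - center
    covs = torch.tensor([[sigma2, 0.0, 0.0, sigma2, 0.0, sigma2]]).repeat(2, 1)
    w = gaussian_3d_coeff(offsets, covs)
    # w[0] == 1 at the center; w[1] == exp(-0.5 * 0.3**2 / sigma2) ~= 0.638
    return w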
def build_rotation(r):
    norm = torch.sqrt(r[:, 0]*r[:, 0] + r[:, 1]*r[:, 1] + r[:, 2]*r[:, 2] + r[:, 3]*r[:, 3])

    q = r / norm[:, None]

    R = torch.zeros((q.size(0), 3, 3), device='cuda')

    r = q[:, 0]
    x = q[:, 1]
    y = q[:, 2]
    z = q[:, 3]

    R[:, 0, 0] = 1 - 2 * (y*y + z*z)
    R[:, 0, 1] = 2 * (x*y - r*z)
    R[:, 0, 2] = 2 * (x*z + r*y)
    R[:, 1, 0] = 2 * (x*y + r*z)
    R[:, 1, 1] = 1 - 2 * (x*x + z*z)
    R[:, 1, 2] = 2 * (y*z - r*x)
    R[:, 2, 0] = 2 * (x*z - r*y)
    R[:, 2, 1] = 2 * (y*z + r*x)
    R[:, 2, 2] = 1 - 2 * (x*x + y*y)
    return R


def build_scaling_rotation(s, r):
    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
    R = build_rotation(r)

    L[:, 0, 0] = s[:, 0]
    L[:, 1, 1] = s[:, 1]
    L[:, 2, 2] = s[:, 2]

    L = R @ L
    return L
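# --- Illustration (not in the original commit): build_rotation follows the
# (w, x, y, z) quaternion convention, so the identity quaternion yields the
# identity matrix. A CUDA device is required, matching the hard-coded
# device='cuda' above.
def _demo_build_rotation():
    q = torch.tensor([[1.0, 0.0, 0.0, 0.0]], device="cuda")  # identity rotation
    return build_rotation(q)  # -> a single 3x3 identity matrix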
class BasicPointCloud(NamedTuple):
    points: np.array
    colors: np.array
    normals: np.array


class GaussianModel:

    def setup_functions(self):
        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
            L = build_scaling_rotation(scaling_modifier * scaling, rotation)
            actual_covariance = L @ L.transpose(1, 2)
            symm = strip_symmetric(actual_covariance)
            return symm

        self.scaling_activation = torch.exp
        self.scaling_inverse_activation = torch.log

        self.covariance_activation = build_covariance_from_scaling_rotation

        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = inverse_sigmoid

        self.rotation_activation = torch.nn.functional.normalize


    def __init__(self, sh_degree: int):
        self.active_sh_degree = 0
        self.max_sh_degree = sh_degree
        self._xyz = torch.empty(0)
        self._features_dc = torch.empty(0)
        self._features_rest = torch.empty(0)
        self._scaling = torch.empty(0)
        self._rotation = torch.empty(0)
        self._opacity = torch.empty(0)
        self.max_radii2D = torch.empty(0)
        self.xyz_gradient_accum = torch.empty(0)
        self.denom = torch.empty(0)
        self.optimizer = None
        self.percent_dense = 0
        self.spatial_lr_scale = 0
        self.setup_functions()

    def capture(self):
        return (
            self.active_sh_degree,
            self._xyz,
            self._features_dc,
            self._features_rest,
            self._scaling,
            self._rotation,
            self._opacity,
            self.max_radii2D,
            self.xyz_gradient_accum,
            self.denom,
            self.optimizer.state_dict(),
            self.spatial_lr_scale,
        )

    def restore(self, model_args, training_args):
        (self.active_sh_degree,
         self._xyz,
         self._features_dc,
         self._features_rest,
         self._scaling,
         self._rotation,
         self._opacity,
         self.max_radii2D,
         xyz_gradient_accum,
         denom,
         opt_dict,
         self.spatial_lr_scale) = model_args
        self.training_setup(training_args)
        self.xyz_gradient_accum = xyz_gradient_accum
        self.denom = denom
        self.optimizer.load_state_dict(opt_dict)

    @property
    def get_scaling(self):
        return self.scaling_activation(self._scaling)

    @property
    def get_rotation(self):
        return self.rotation_activation(self._rotation)

    @property
    def get_xyz(self):
        return self._xyz

    @property
    def get_features(self):
        features_dc = self._features_dc
        features_rest = self._features_rest
        if self.enable_dino:
            return torch.cat((features_dc, features_rest[..., :3]), dim=1), features_rest[..., 3:].reshape(features_rest.shape[0], 1, -1)[..., :self.dino_feat_dim]
        else:
            return torch.cat((features_dc, features_rest), dim=1)

    @property
    def get_opacity(self):
        return self.opacity_activation(self._opacity)

    @torch.no_grad()
    def extract_fields(self, resolution=128, num_blocks=16, relax_ratio=1.5):
        # resolution: resolution of field

        block_size = 2 / num_blocks

        assert resolution % block_size == 0
        split_size = resolution // num_blocks

        opacities = self.get_opacity

        # pre-filter low opacity gaussians to save computation
        mask = (opacities > 0.005).squeeze(1)

        opacities = opacities[mask]
        xyzs = self.get_xyz[mask]
        stds = self.get_scaling[mask]

        # normalize to ~ [-1, 1]
        mn, mx = xyzs.amin(0), xyzs.amax(0)
        self.center = (mn + mx) / 2
        self.scale = 1.8 / (mx - mn).amax().item()

        xyzs = (xyzs - self.center) * self.scale
        stds = stds * self.scale

        covs = self.covariance_activation(stds, 1, self._rotation[mask])

        # tile
        device = opacities.device
        occ = torch.zeros([resolution] * 3, dtype=torch.float32, device=device)

        X = torch.linspace(-1, 1, resolution).split(split_size)
        Y = torch.linspace(-1, 1, resolution).split(split_size)
        Z = torch.linspace(-1, 1, resolution).split(split_size)

        # loop blocks (assume max size of gaussian is smaller than relax_ratio * block_size !!!)
        for xi, xs in enumerate(X):
            for yi, ys in enumerate(Y):
                for zi, zs in enumerate(Z):
                    xx, yy, zz = torch.meshgrid(xs, ys, zs)
                    # sample points [M, 3]
                    pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).to(device)
                    # in-tile gaussians mask
                    vmin, vmax = pts.amin(0), pts.amax(0)
                    vmin -= block_size * relax_ratio
                    vmax += block_size * relax_ratio
                    mask = (xyzs < vmax).all(-1) & (xyzs > vmin).all(-1)
                    # if hit no gaussian, continue to next block
                    if not mask.any():
                        continue
                    mask_xyzs = xyzs[mask]  # [L, 3]
                    mask_covs = covs[mask]  # [L, 6]
                    mask_opas = opacities[mask].view(1, -1)  # [L, 1] --> [1, L]

                    # query per point-gaussian pair.
                    g_pts = pts.unsqueeze(1).repeat(1, mask_covs.shape[0], 1) - mask_xyzs.unsqueeze(0)  # [M, L, 3]
                    g_covs = mask_covs.unsqueeze(0).repeat(pts.shape[0], 1, 1)  # [M, L, 6]

                    # batch on gaussian to avoid OOM
                    batch_g = 1024
                    val = 0
                    for start in range(0, g_covs.shape[1], batch_g):
                        end = min(start + batch_g, g_covs.shape[1])
                        w = gaussian_3d_coeff(g_pts[:, start:end].reshape(-1, 3), g_covs[:, start:end].reshape(-1, 6)).reshape(pts.shape[0], -1)  # [M, l]
                        val += (mask_opas[:, start:end] * w).sum(-1)

                    # kiui.lo(val, mask_opas, w)

                    occ[xi * split_size: xi * split_size + len(xs),
                        yi * split_size: yi * split_size + len(ys),
                        zi * split_size: zi * split_size + len(zs)] = val.reshape(len(xs), len(ys), len(zs))

        kiui.lo(occ, verbose=1)

        return occ

    def extract_mesh(self, path, density_thresh=1, resolution=128, decimate_target=1e5):

        os.makedirs(os.path.dirname(path), exist_ok=True)

        occ = self.extract_fields(resolution).detach().cpu().numpy()

        import mcubes
        vertices, triangles = mcubes.marching_cubes(occ, density_thresh)
        vertices = vertices / (resolution - 1.0) * 2 - 1

        # transform back to the original space
        vertices = vertices / self.scale + self.center.detach().cpu().numpy()

        vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.015)
        if decimate_target > 0 and triangles.shape[0] > decimate_target:
            vertices, triangles = decimate_mesh(vertices, triangles, decimate_target)

        v = torch.from_numpy(vertices.astype(np.float32)).contiguous().cuda()
        f = torch.from_numpy(triangles.astype(np.int32)).contiguous().cuda()

        print(
            f"[INFO] marching cubes result: {v.shape} ({v.min().item()}-{v.max().item()}), {f.shape}"
        )

        mesh = Mesh(v=v, f=f, device='cuda')

        return mesh

    def get_covariance(self, scaling_modifier=1):
        return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)

    def oneupSHdegree(self):
        if self.active_sh_degree < self.max_sh_degree:
            self.active_sh_degree += 1

    def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float = 1):
        self.spatial_lr_scale = spatial_lr_scale
        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
        features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
        features[:, :3, 0] = fused_color
        features[:, 3:, 1:] = 0.0

        print("Number of points at initialisation : ", fused_point_cloud.shape[0])

        dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
        rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
        rots[:, 0] = 1

        opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))

        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
        self._features_dc = nn.Parameter(features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True))
        if self.enable_dino:
            # Override the original features
            _features_rest = features[:, :, 1:].transpose(1, 2).contiguous().cuda()
            dim_rest = _features_rest.shape[1]
            _semantic_features = torch.randn(self._xyz.shape[0], dim_rest, self.dino_feat_dim//dim_rest + 1).cuda()
            self._features_rest = nn.Parameter(torch.cat([_features_rest, _semantic_features], dim=-1).requires_grad_(True))
        else:
            self._features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True))
        self._scaling = nn.Parameter(scales.requires_grad_(True))
        self._rotation = nn.Parameter(rots.requires_grad_(True))
        self._opacity = nn.Parameter(opacities.requires_grad_(True))
        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")

    def training_setup(self, training_args):
        self.percent_dense = training_args.percent_dense
        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")

        l = [
            {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
            {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
            {'params': [self._features_rest], 'lr': training_args.feature_lr / 20, "name": "f_rest"},  # /20
            {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
            {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
            {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"},
        ]

        if training_args.opt_cam:
            l.append({'params': self.cam_params, 'lr': training_args.camera_lr, "name": "cam_params"})

        self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
        self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
                                                    lr_final=training_args.position_lr_final*self.spatial_lr_scale,
                                                    lr_delay_mult=training_args.position_lr_delay_mult,
                                                    max_steps=training_args.position_lr_max_steps)

    def update_learning_rate(self, iteration):
        ''' Learning rate scheduling per step '''
        for param_group in self.optimizer.param_groups:
            if param_group["name"] == "xyz":
                if iteration > 500:
                    iteration = iteration % 500
                lr = self.xyz_scheduler_args(iteration)
                param_group['lr'] = lr

    def construct_list_of_attributes(self):
        l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
        # All channels except the 3 DC
        for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
            l.append('f_dc_{}'.format(i))
        for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
            l.append('f_rest_{}'.format(i))
        l.append('opacity')
        for i in range(self._scaling.shape[1]):
            l.append('scale_{}'.format(i))
        for i in range(self._rotation.shape[1]):
            l.append('rot_{}'.format(i))
        return l

    def save_ply(self, path):
        os.makedirs(os.path.dirname(path), exist_ok=True)

        xyz = self._xyz.detach().cpu().numpy()
        normals = np.zeros_like(xyz)
        f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
        f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
        opacities = self._opacity.detach().cpu().numpy()
        scale = self._scaling.detach().cpu().numpy()
        rotation = self._rotation.detach().cpu().numpy()

        dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]

        elements = np.empty(xyz.shape[0], dtype=dtype_full)
        attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, 'vertex')
        PlyData([el]).write(path)

    def reset_opacity(self):
        opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
        optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
        self._opacity = optimizable_tensors["opacity"]

    def load_ply(self, path):
        plydata = PlyData.read(path)

        xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
                        np.asarray(plydata.elements[0]["y"]),
                        np.asarray(plydata.elements[0]["z"])), axis=1)
        opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]

        print("Number of points at loading : ", xyz.shape[0])

        features_dc = np.zeros((xyz.shape[0], 3, 1))
        features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
        features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
        features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])

        extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
        assert len(extra_f_names) == 3*(self.max_sh_degree + 1) ** 2 - 3
        features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
        for idx, attr_name in enumerate(extra_f_names):
            features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
        # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
        features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))

        scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
        scales = np.zeros((xyz.shape[0], len(scale_names)))
        for idx, attr_name in enumerate(scale_names):
            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])

        rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
        rots = np.zeros((xyz.shape[0], len(rot_names)))
        for idx, attr_name in enumerate(rot_names):
            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])

        self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
        self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
        self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
        self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
        self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
        self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))

        self.active_sh_degree = self.max_sh_degree

    def replace_tensor_to_optimizer(self, tensor, name):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            if len(group["params"]) != 1:
                continue
            if group["name"] == name:
                stored_state = self.optimizer.state.get(group['params'][0], None)
                stored_state["exp_avg"] = torch.zeros_like(tensor)
                stored_state["exp_avg_sq"] = torch.zeros_like(tensor)

                del self.optimizer.state[group['params'][0]]
                group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
                self.optimizer.state[group['params'][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def _prune_optimizer(self, mask):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            if len(group["params"]) != 1:
                continue

            stored_state = self.optimizer.state.get(group['params'][0], None)
            if stored_state is not None:
                stored_state["exp_avg"] = stored_state["exp_avg"][mask]
                stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]

                del self.optimizer.state[group['params'][0]]
                group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
                self.optimizer.state[group['params'][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
            else:
                group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def prune_points(self, mask):
        valid_points_mask = ~mask
        optimizable_tensors = self._prune_optimizer(valid_points_mask)

        self._xyz = optimizable_tensors["xyz"]
        self._features_dc = optimizable_tensors["f_dc"]
        self._features_rest = optimizable_tensors["f_rest"]
        self._opacity = optimizable_tensors["opacity"]
        self._scaling = optimizable_tensors["scaling"]
        self._rotation = optimizable_tensors["rotation"]

        self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]

        self.denom = self.denom[valid_points_mask]
        self.max_radii2D = self.max_radii2D[valid_points_mask]

    def cat_tensors_to_optimizer(self, tensors_dict):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            if len(group["params"]) != 1:
                continue
            assert len(group["params"]) == 1
            extension_tensor = tensors_dict[group["name"]]
            stored_state = self.optimizer.state.get(group['params'][0], None)
            if stored_state is not None:

                stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
                stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)

                del self.optimizer.state[group['params'][0]]
                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
                self.optimizer.state[group['params'][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
            else:
                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
                optimizable_tensors[group["name"]] = group["params"][0]

        return optimizable_tensors

    def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation):
        d = {"xyz": new_xyz,
             "f_dc": new_features_dc,
             "f_rest": new_features_rest,
             "opacity": new_opacities,
             "scaling": new_scaling,
             "rotation": new_rotation}

        optimizable_tensors = self.cat_tensors_to_optimizer(d)
        self._xyz = optimizable_tensors["xyz"]
        self._features_dc = optimizable_tensors["f_dc"]
        self._features_rest = optimizable_tensors["f_rest"]
        self._opacity = optimizable_tensors["opacity"]
        self._scaling = optimizable_tensors["scaling"]
        self._rotation = optimizable_tensors["rotation"]

        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")

    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
        n_init_points = self.get_xyz.shape[0]
        # Extract points that satisfy the gradient condition
        padded_grad = torch.zeros((n_init_points), device="cuda")
        padded_grad[:grads.shape[0]] = grads.squeeze()
        selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
        selected_pts_mask = torch.logical_and(selected_pts_mask,
            torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent
        )

        stds = self.get_scaling[selected_pts_mask].repeat(N, 1)
        means = torch.zeros((stds.size(0), 3), device="cuda")
        samples = torch.normal(mean=means, std=stds)
        rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1)
        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
        new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8*N))
        new_rotation = self._rotation[selected_pts_mask].repeat(N, 1)
        new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1)
        new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1)
        new_opacity = self._opacity[selected_pts_mask].repeat(N, 1)

        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)

        prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
        self.prune_points(prune_filter)

    def densify_and_clone(self, grads, grad_threshold, scene_extent):
        # Extract points that satisfy the gradient condition
        selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
        selected_pts_mask = torch.logical_and(selected_pts_mask,
            torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent
        )

        new_xyz = self._xyz[selected_pts_mask]
        new_features_dc = self._features_dc[selected_pts_mask]
        new_features_rest = self._features_rest[selected_pts_mask]
        new_opacities = self._opacity[selected_pts_mask]
        new_scaling = self._scaling[selected_pts_mask]
        new_rotation = self._rotation[selected_pts_mask]

        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)

    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
        grads = self.xyz_gradient_accum / self.denom
        grads[grads.isnan()] = 0.0

        self.densify_and_clone(grads, max_grad, extent)
        self.densify_and_split(grads, max_grad, extent)

        prune_mask = (self.get_opacity < min_opacity).squeeze()
        if max_screen_size:
            big_points_vs = self.max_radii2D > max_screen_size
            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
            prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
        self.prune_points(prune_mask)

        torch.cuda.empty_cache()

    def prune(self, min_opacity, extent, max_screen_size):

        prune_mask = (self.get_opacity < min_opacity).squeeze()
        if max_screen_size:
            big_points_vs = self.max_radii2D > max_screen_size
            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
            prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
        self.prune_points(prune_mask)

        torch.cuda.empty_cache()


    def add_densification_stats(self, viewspace_point_tensor, update_filter):
        self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True)
        self.denom[update_filter] += 1


def getProjectionMatrix(znear, zfar, fx, fy, cx, cy):
    # TODO: remove hard-coded image size
    P = torch.zeros(4, 4)

    z_sign = 1.0

    P[0, 0] = 2 * fx / 256
    P[1, 1] = 2 * fy / 256
    P[0, 2] = 2 * (cx / 256) - 1
    P[1, 2] = 2 * (cy / 256) - 1
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[3, 2] = z_sign
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P


def getProjectionMatrixFoV(znear, zfar, fovX, fovY):
    tanHalfFovY = math.tan((fovY / 2))
    tanHalfFovX = math.tan((fovX / 2))

    P = torch.zeros(4, 4)

    z_sign = 1.0

    P[0, 0] = 1 / tanHalfFovX
    P[1, 1] = 1 / tanHalfFovY
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P
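# --- Illustration (not in the original commit): for the hard-coded 256x256
# image, an intrinsics matrix with the principal point at the image center
# (cx == cy == 128) matches the FoV-based matrix, since
# 1 / tan(fov/2) == fx / 128 == 2 * fx / 256. Focal lengths are hypothetical.
def _demo_projection_consistency(fx=500.0, fy=500.0):
    fovx = 2 * math.atan(256 / 2 / fx)
    fovy = 2 * math.atan(256 / 2 / fy)
    P_intr = getProjectionMatrix(0.01, 100, fx, fy, 128.0, 128.0)
    P_fov = getProjectionMatrixFoV(0.01, 100, fovx, fovy)
    return torch.allclose(P_intr, P_fov, atol=1e-6)  # True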
class Camera:
    def __init__(self, c2w, width, height, fx, fy, cx, cy, znear=0.01, zfar=100, opt_pose=False):
        # c2w (pose) should be in NeRF convention.

        self.image_width = width
        self.image_height = height
        self.fx, self.fy = fx, fy
        self.cx, self.cy = cx, cy
        self.FoVy = 2 * np.arctan(256 / 2 / self.fy)
        self.FoVx = 2 * np.arctan(256 / 2 / self.fx)
        self.znear = znear
        self.zfar = zfar
        self.opt_pose = opt_pose

        self.projection_matrix = (
            getProjectionMatrix(
                znear=self.znear,
                zfar=self.zfar,
                fx=self.fx,
                fy=self.fy,
                cx=self.cx,
                cy=self.cy,
            )
            .transpose(0, 1)
            .cuda()
        )

        w2c = np.linalg.inv(c2w)

        # OpenGL to OpenCV
        w2c[1:3] *= -1

        self.world_view_transform = torch.tensor(w2c).transpose(0, 1).cuda()
        self.full_proj_transform = self.world_view_transform @ self.projection_matrix
        self.camera_center = torch.tensor(c2w[:3, 3]).cuda()


class FoVCamera:
    def __init__(self, c2w, width, height, fovy, fovx, znear, zfar, cam_params=None, opt_pose=False):
        # c2w (pose) should be in NeRF convention.

        self.image_width = width
        self.image_height = height
        self.FoVy = fovy
        self.FoVx = fovx
        self.znear = znear
        self.zfar = zfar
        self.opt_pose = opt_pose

        self.projection_matrix = (
            getProjectionMatrixFoV(
                znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy
            )
            .transpose(0, 1)
            .cuda()
        )

        w2c = np.linalg.inv(c2w)

        # OpenGL to OpenCV
        w2c[1:3] *= -1

        self.world_view_transform = torch.tensor(w2c).transpose(0, 1).cuda()
        self.full_proj_transform = self.world_view_transform @ self.projection_matrix
        self.camera_center = torch.tensor(c2w[:3, 3]).cuda()


class CustomCamera:
    def __init__(self, cam_params=None, index=None, c2w=None, opt_pose=False):
        # TODO: remove hard-coded image size
        # c2w (pose) should be in NeRF convention.
        # This is the camera class that supports pose optimization.

        self.image_width, self.image_height = 256, 256
        self.fx, self.fy = cam_params["focal_length"]
        self.cx, self.cy = cam_params["principal_point"]
        self.FoVy = 2 * np.arctan(self.image_height / 2 / self.fy)
        self.FoVx = 2 * np.arctan(self.image_width / 2 / self.fx)
        self.R = torch.tensor(cam_params["R"])
        self.T = torch.tensor(cam_params["T"])
        self.znear = 0.01
        self.zfar = 100
        self.opt_pose = opt_pose
        self.index = index

        self.projection_matrix = (
            getProjectionMatrix(
                znear=self.znear,
                zfar=self.zfar,
                fx=self.fx,
                fy=self.fy,
                cx=self.cx,
                cy=self.cy,
            )
            .transpose(0, 1)
            .cuda()
        )

        if not opt_pose:
            if c2w:
                w2c = torch.from_numpy(c2w)
                w2c[1:3] *= -1  # OpenGL to OpenCV
            else:
                R = self.R.T  # note the transpose here
                T = self.T
                upper = torch.cat([R, T[:, None]], dim=1)  # Upper 3x4 part of the matrix
                lower = torch.tensor([[0, 0, 0, 1]], device=R.device, dtype=R.dtype)  # Last row
                w2c = torch.cat([upper, lower], dim=0)

                w2c[:2] *= -1  # PyTorch3D to OpenCV

            self.w2c = w2c
            self.cam_params = torch.zeros(6)
            self.world_view_transform = w2c.transpose(0, 1).cuda()
            self.full_proj_transform = self.world_view_transform @ self.projection_matrix
            self.camera_center = self.world_view_transform.inverse()[3, :3]
        else:
            R = self.R.T  # note the transpose here
            T = self.T
            upper = torch.cat([R, T[:, None]], dim=1)  # Upper 3x4 part of the matrix
            lower = torch.tensor([[0, 0, 0, 1]], device=R.device, dtype=R.dtype)  # Last row
            w2c = torch.cat([upper, lower], dim=0)

            w2c[:2] *= -1  # PyTorch3D to OpenCV

            self.w2c = w2c
            self.cam_params = torch.randn(6) * 1e-6
            self.cam_params.requires_grad_()

            self.world_view_transform = w2c.transpose(0, 1).cuda()
            self.full_proj_transform = self.world_view_transform @ self.projection_matrix
            self.camera_center = self.world_view_transform.inverse()[3, :3]

    @property
    def perspective(self):
        P = torch.zeros(4, 4)

        z_sign = -1.0

        P[0, 0] = 2 * self.fx / 256
        P[1, 1] = -2 * self.fy / 256
        P[0, 2] = -(2 * (self.cx / 256) - 1)
        P[1, 2] = -(2 * (self.cy / 256) - 1)
        P[2, 2] = z_sign * self.zfar / (self.zfar - self.znear)
        P[3, 2] = z_sign
        P[2, 3] = -(self.zfar * self.znear) / (self.zfar - self.znear)
        return P.numpy()

    @property
    def c2w(self):
        if self.opt_pose:
            w2c = self.w2c @ SE3.exp(self.cam_params.detach()).as_matrix()
            w2c[1:3] *= -1  # OpenCV to OpenGL
        else:
            R = self.R.T  # note the transpose here
            T = self.T
            upper = torch.cat([R, T[:, None]], dim=1)  # Upper 3x4 part of the matrix
            lower = torch.tensor([[0, 0, 0, 1]], device=R.device, dtype=R.dtype)  # Last row
            w2c = torch.cat([upper, lower], dim=0)
            w2c[:2, :] *= -1  # PyTorch3D to OpenCV
            w2c[1:3, :] *= -1  # OpenCV to OpenGL

        return torch.inverse(w2c).numpy()

    @property
    def focal_length(self):
        return np.array([self.fx, self.fy])

    @property
    def rotation(self):
        w2c = self.w2c @ SE3.exp(self.cam_params.detach()).as_matrix()
        w2c[:2] *= -1
        return w2c[:3, :3].T

    @property
    def translation(self):
        w2c = self.w2c @ SE3.exp(self.cam_params.detach()).as_matrix()
        w2c[:2] *= -1
        return w2c[:3, 3]
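# --- Illustration (not in the original commit): CustomCamera's pose
# optimization composes a fixed base extrinsic with a learnable se(3)
# increment, w2c @ SE3.exp(xi), using the same liegroups call as above;
# a zero increment leaves the pose unchanged.
def _demo_se3_perturbation():
    xi = torch.zeros(6)              # the 6-DoF cam_params, initialized near zero
    delta = SE3.exp(xi).as_matrix()  # 4x4 identity when xi == 0
    w2c = torch.eye(4)               # stand-in base world-to-camera matrix
    return w2c @ delta               # equals w2c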
class Renderer:
    def __init__(self, sh_degree=3, white_background=True, radius=1):

        self.sh_degree = sh_degree
        self.white_background = white_background
        self.radius = radius
        self.enable_dino = None

        self.gaussians = GaussianModel(sh_degree)

        self.bg_color = torch.tensor(
            [1, 1, 1] if white_background else [0, 0, 0],
            dtype=torch.float32,
            device="cuda",
        )

    def initialize(self, input=None, num_pts=5000, radius=0.5, cameras=None, imgs=None, masks=None, point_maps=None, mode='sphere'):
        # load checkpoint
        if input is None and mode in ['sphere', "carve", "inverse_carve"]:
            # init from random point cloud

            if mode == 'sphere':
                phis = np.random.random((num_pts,)) * 2 * np.pi
                costheta = np.random.random((num_pts,)) * 2 - 1
                thetas = np.arccos(costheta)
                mu = np.random.random((num_pts,))
                radius = radius * np.cbrt(mu)
                x = radius * np.sin(thetas) * np.cos(phis)
                y = radius * np.sin(thetas) * np.sin(phis)
                z = radius * np.cos(thetas)
                xyz = np.stack((x, y, z), axis=1)

            elif mode == "carve":
                try:
                    xyz = sample_points_from_voxel(cameras, masks, radius, N=num_pts).cpu().numpy()
                except RuntimeError:
                    radius = 0.3
                    phis = np.random.random((num_pts,)) * 2 * np.pi
                    costheta = np.random.random((num_pts,)) * 2 - 1
                    thetas = np.arccos(costheta)
                    mu = np.random.random((num_pts,))
                    radius = radius * np.cbrt(mu)
                    x = radius * np.sin(thetas) * np.cos(phis)
                    y = radius * np.sin(thetas) * np.sin(phis)
                    z = radius * np.cos(thetas)
                    xyz = np.stack((x, y, z), axis=1)

            elif mode == "inverse_carve":
                xyz = sample_points_from_voxel(cameras, masks, radius, N=num_pts, inverse=True).cpu().numpy()

            shs = np.random.random((num_pts, 3)) / 255.0
            pcd = BasicPointCloud(
                points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))
            )
            self.gaussians.create_from_pcd(pcd, 10)

        elif input is None and mode == "dust3r":
            num_points = sum([np.count_nonzero(masks[i]) for i in range(8)])
            xyz = np.zeros((num_points, 3))
            colors = np.zeros((num_points, 3))

            # Iterate through data and add points to xyz and colors arrays
            index = 0
            for i in range(len(point_maps)):
                rgb = imgs[i].reshape(-1, 3)
                point_map = point_maps[i].reshape(-1, 3).detach().cpu().numpy()
                for j, include_point in enumerate((masks[i] > 0.5).flatten()):
                    if include_point == 1:
                        xyz[index] = point_map[j]
                        colors[index] = rgb[j]
                        index += 1

            # Check if index matches expected number of points
            assert index == num_points, "Number of points does not match expected count"

            pcd = BasicPointCloud(
                points=xyz, colors=colors, normals=np.zeros((len(point_maps)*224*224, 3))
            )
            self.gaussians.create_from_pcd(pcd, 10)

        elif isinstance(input, BasicPointCloud):
            # load from a provided pcd
            self.gaussians.create_from_pcd(input, 1)
        else:
            # load from saved ply
            self.gaussians.load_ply(input)

    def render(
        self,
        viewpoint_camera,
        scaling_modifier=1.0,
        bg_color=None,
        override_color=None,
        compute_cov3D_python=False,
        convert_SHs_python=False,
    ):
        if self.enable_dino:
            from diff_gaussian_rasterization_feature import (
                GaussianRasterizationSettings,
                GaussianRasterizer,
            )
        else:
            from diff_gaussian_rasterization import (
                GaussianRasterizationSettings,
                GaussianRasterizer,
            )

        if viewpoint_camera.opt_pose:
            w2c = viewpoint_camera.w2c @ SE3.exp(viewpoint_camera.cam_params).as_matrix()
            w2c = w2c.to("cuda")

            viewpoint_camera.world_view_transform = w2c.transpose(0, 1)
            viewpoint_camera.full_proj_transform = viewpoint_camera.world_view_transform @ viewpoint_camera.projection_matrix
            viewpoint_camera.camera_center = viewpoint_camera.world_view_transform.inverse()[3, :3]

        # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
        screenspace_points = (
|
984 |
+
torch.zeros_like(
|
985 |
+
self.gaussians.get_xyz,
|
986 |
+
dtype=self.gaussians.get_xyz.dtype,
|
987 |
+
requires_grad=True,
|
988 |
+
device="cuda",
|
989 |
+
)
|
990 |
+
+ 0
|
991 |
+
)
|
992 |
+
try:
|
993 |
+
screenspace_points.retain_grad()
|
994 |
+
except:
|
995 |
+
pass
|
996 |
+
|
997 |
+
# Set up rasterization configuration
|
998 |
+
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
|
999 |
+
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
|
1000 |
+
|
1001 |
+
raster_settings = GaussianRasterizationSettings(
|
1002 |
+
image_height=int(viewpoint_camera.image_height),
|
1003 |
+
image_width=int(viewpoint_camera.image_width),
|
1004 |
+
tanfovx=tanfovx,
|
1005 |
+
tanfovy=tanfovy,
|
1006 |
+
bg=self.bg_color if bg_color is None else bg_color,
|
1007 |
+
scale_modifier=scaling_modifier,
|
1008 |
+
viewmatrix=viewpoint_camera.world_view_transform,
|
1009 |
+
perspectivematrix=viewpoint_camera.projection_matrix,
|
1010 |
+
projmatrix=viewpoint_camera.full_proj_transform,
|
1011 |
+
sh_degree=self.gaussians.active_sh_degree,
|
1012 |
+
campos=viewpoint_camera.camera_center,
|
1013 |
+
prefiltered=False,
|
1014 |
+
debug=False,
|
1015 |
+
)
|
1016 |
+
|
1017 |
+
rasterizer = GaussianRasterizer(raster_settings=raster_settings)
|
1018 |
+
|
1019 |
+
means3D = self.gaussians.get_xyz
|
1020 |
+
means2D = screenspace_points
|
1021 |
+
opacity = self.gaussians.get_opacity
|
1022 |
+
|
1023 |
+
# If precomputed 3d covariance is provided, use it. If not, then it will be computed from
|
1024 |
+
# scaling / rotation by the rasterizer.
|
1025 |
+
scales = None
|
1026 |
+
rotations = None
|
1027 |
+
cov3D_precomp = None
|
1028 |
+
if compute_cov3D_python:
|
1029 |
+
cov3D_precomp = self.gaussians.get_covariance(scaling_modifier)
|
1030 |
+
else:
|
1031 |
+
scales = self.gaussians.get_scaling
|
1032 |
+
rotations = self.gaussians.get_rotation
|
1033 |
+
|
1034 |
+
# If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
|
1035 |
+
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
|
1036 |
+
shs = None
|
1037 |
+
colors_precomp = None
|
1038 |
+
if colors_precomp is None:
|
1039 |
+
if convert_SHs_python:
|
1040 |
+
shs_view = self.gaussians.get_features.transpose(1, 2).view(
|
1041 |
+
-1, 3, (self.gaussians.max_sh_degree + 1) ** 2
|
1042 |
+
)
|
1043 |
+
dir_pp = self.gaussians.get_xyz - viewpoint_camera.camera_center.repeat(
|
1044 |
+
self.gaussians.get_features.shape[0], 1
|
1045 |
+
)
|
1046 |
+
dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
|
1047 |
+
sh2rgb = eval_sh(
|
1048 |
+
self.gaussians.active_sh_degree, shs_view, dir_pp_normalized
|
1049 |
+
)
|
1050 |
+
colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
|
1051 |
+
else:
|
1052 |
+
shs = self.gaussians.get_features
|
1053 |
+
else:
|
1054 |
+
colors_precomp = override_color
|
1055 |
+
|
1056 |
+
if self.enable_dino:
|
1057 |
+
shs, semantic_feature = shs
|
1058 |
+
|
1059 |
+
rendered_image, rendered_feature, radii, rendered_depth, rendered_alpha = rasterizer(
|
1060 |
+
means3D=means3D,
|
1061 |
+
means2D=means2D,
|
1062 |
+
shs=shs,
|
1063 |
+
semantic_feature=semantic_feature,
|
1064 |
+
colors_precomp=colors_precomp,
|
1065 |
+
opacities=opacity,
|
1066 |
+
scales=scales,
|
1067 |
+
rotations=rotations,
|
1068 |
+
cov3D_precomp=cov3D_precomp,
|
1069 |
+
viewmat=viewpoint_camera.world_view_transform,
|
1070 |
+
)
|
1071 |
+
|
1072 |
+
else:
|
1073 |
+
# Rasterize visible Gaussians to image, obtain their radii (on screen).
|
1074 |
+
rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
|
1075 |
+
means3D=means3D,
|
1076 |
+
means2D=means2D,
|
1077 |
+
shs=shs,
|
1078 |
+
colors_precomp=colors_precomp,
|
1079 |
+
opacities=opacity,
|
1080 |
+
scales=scales,
|
1081 |
+
rotations=rotations,
|
1082 |
+
cov3D_precomp=cov3D_precomp,
|
1083 |
+
viewmat=viewpoint_camera.world_view_transform,
|
1084 |
+
)
|
1085 |
+
|
1086 |
+
rendered_image = rendered_image.clamp(0, 1)
|
1087 |
+
|
1088 |
+
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
|
1089 |
+
# They will be excluded from value updates used in the splitting criteria.
|
1090 |
+
ret = {
|
1091 |
+
"image": rendered_image,
|
1092 |
+
"depth": rendered_depth,
|
1093 |
+
"alpha": rendered_alpha,
|
1094 |
+
"viewspace_points": screenspace_points,
|
1095 |
+
"visibility_filter": radii > 0,
|
1096 |
+
"radii": radii,
|
1097 |
+
}
|
1098 |
+
|
1099 |
+
if self.enable_dino:
|
1100 |
+
ret["feature"] = rendered_feature
|
1101 |
+
|
1102 |
+
return ret
|
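The repeated in-place sign flips above are coordinate-convention changes, not anything pose-specific. A minimal standalone sketch of the same row flips, with an identity `w2c` standing in as a placeholder matrix (not a value from this repo):

import numpy as np

w2c = np.eye(4)  # placeholder world-to-camera matrix in PyTorch3D convention

w2c_cv = w2c.copy()
w2c_cv[:2, :] *= -1   # PyTorch3D (+x left, +y up) -> OpenCV (+x right, +y down)

w2c_gl = w2c_cv.copy()
w2c_gl[1:3, :] *= -1  # OpenCV (+y down, +z forward) -> OpenGL (+y up, -z forward)

c2w_gl = np.linalg.inv(w2c_gl)  # camera-to-world pose, as returned by the c2w property

Both flips are involutions (negating the same rows twice restores the original matrix), which is why the same lines can convert in either direction.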
sparseags/render_utils/util.py
ADDED
@@ -0,0 +1,510 @@
import os
import cv2
import gc
import copy
import tqdm
import torchvision
import shutil
import argparse
import numpy as np
from PIL import Image
from torchvision.utils import save_image
from omegaconf import OmegaConf
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F

from kiui.lpips import LPIPS
from liegroups.torch import SE3

import sys
sys.path.append('./')

from sparseags.render_utils.gs_renderer import CustomCamera
from sparseags.mesh_utils.mesh_renderer import Renderer
from sparseags.cam_utils import OrbitCamera, mat2latlon


def safe_normalize(x):
    return x / x.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-8)


def look_at(campos, target, opengl=True):
    if not opengl:
        # camera forward aligns with the camera -> target direction
        forward_vector = safe_normalize(target - campos)
        up_vector = torch.tensor([0, 1, 0], dtype=campos.dtype, device=campos.device).expand_as(forward_vector)
        right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
        up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))
    else:
        # OpenGL convention: camera forward points from target to camera
        forward_vector = safe_normalize(campos - target)
        up_vector = torch.tensor([0, 1, 0], dtype=campos.dtype, device=campos.device).expand_as(forward_vector)
        right_vector = safe_normalize(torch.cross(up_vector, forward_vector, dim=-1))
        up_vector = safe_normalize(torch.cross(forward_vector, right_vector, dim=-1))
    R = torch.stack([right_vector, up_vector, forward_vector], dim=-1)
    return R


def orbit_camera(elevation, azimuth, radius=1.0, is_degree=True, target=None, opengl=True):
    """Converts elevation & azimuth to a batch of camera pose matrices."""
    if is_degree:
        elevation = torch.deg2rad(elevation)
        azimuth = torch.deg2rad(azimuth)
    x = radius * torch.cos(elevation) * torch.sin(azimuth)
    y = -radius * torch.sin(elevation)
    z = radius * torch.cos(elevation) * torch.cos(azimuth)
    if target is None:
        target = torch.zeros(3, dtype=torch.float32, device=elevation.device)
    campos = torch.stack([x, y, z], dim=-1) + target
    R = look_at(campos, target.unsqueeze(0).expand_as(campos), opengl)
    T = torch.eye(4, dtype=torch.float32, device=elevation.device).unsqueeze(0).expand(campos.shape[0], -1, -1).clone()
    T[:, :3, :3] = R
    T[:, :3, 3] = campos
    return T


def render_and_compare(camera_data, mesh_path, save_path, num_views=8):
    parser = argparse.ArgumentParser()
    parser.add_argument('--object', type=str, help="path to mesh (obj, ply, glb, ...)")
    parser.add_argument('--path', type=str, help="path to mesh (obj, ply, glb, ...)")
    parser.add_argument('--front_dir', type=str, default='+z', help="mesh front-facing dir")
    parser.add_argument('--mode', default='albedo', type=str, choices=['lambertian', 'albedo', 'normal', 'depth'], help="rendering mode")
    parser.add_argument('--W', type=int, default=256, help="GUI width")
    parser.add_argument('--H', type=int, default=256, help="GUI height")
    parser.add_argument("--wogui", type=bool, default=True, help="disable all dpg GUI")
    parser.add_argument("--force_cuda_rast", action='store_true', help="force to use RasterizeCudaContext.")
    parser.add_argument("--config", default='configs/navi.yaml', help="path to the yaml config file")
    parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center")
    parser.add_argument('--fovy', type=float, default=49.1, help="default GUI camera fovy")
    args, extras = parser.parse_known_args()

    # override default config from cli
    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
    data = camera_data

    opt.mesh = mesh_path
    opt.trainable_texture = False
    renderer = Renderer(opt).to(torch.device("cuda"))
    target = renderer.mesh.v.mean(dim=0)

    cameras = [CustomCamera(cam_params) for cam_params in data.values()]
    # cams = [(cam.c2w, cam.perspective, cam.focal_length) for cam in cameras]
    img_paths = [v["filepath"] for k, v in data.items()]
    flags = [int(v["flag"]) for k, v in data.items()]

    # reference camera distribution, computed only from the trusted (flagged) views
    cam_centers = [mat2latlon(cam.camera_center - target) for idx, cam in enumerate(cameras) if flags[idx]]
    ref_polars = [float(cam[0]) for cam in cam_centers]
    ref_azimuths = [float(cam[1]) for cam in cam_centers]
    ref_radii = [float(cam[2]) for cam in cam_centers]

    base_cam = copy.copy(cameras[0])
    base_cam.fx = np.array([cam.fx for idx, cam in enumerate(cameras) if flags[idx]], dtype=np.float32).mean()
    base_cam.fy = np.array([cam.fy for idx, cam in enumerate(cameras) if flags[idx]], dtype=np.float32).mean()
    base_cam.cx = 128
    base_cam.cy = 128

    lpips_loss = LPIPS(net='vgg').cuda()
    elevation_range = (max([min(ref_polars) - 20, -89.9]), min([max(ref_polars) + 20, 89.9]))
    azimuth_range = (-180, 180)
    radius_range = (min(ref_radii) - 0.2, max(ref_radii) + 0.2)

    elevation_steps = torch.arange(elevation_range[0], elevation_range[1], 15, dtype=torch.float32)
    azimuth_steps = torch.arange(azimuth_range[0], azimuth_range[1], 15, dtype=torch.float32)
    radius_steps = torch.arange(radius_range[0], radius_range[1], 0.2, dtype=torch.float32)
    elevation_grid, azimuth_grid, radius_grid = torch.meshgrid(elevation_steps, azimuth_steps, radius_steps, indexing='ij')
    pose_grid = torch.stack((elevation_grid.flatten(), azimuth_grid.flatten(), radius_grid.flatten()), dim=1)

    poses = orbit_camera(pose_grid[:, 0], pose_grid[:, 1], pose_grid[:, 2], target=target.cpu())
    print("Number of hypotheses:", poses.shape[0])
    s1_steps = 128
    s2_steps = 256
    beta = 0.25
    chunk_size = 512

    for i in tqdm.tqdm(range(num_views)):
        # only re-estimate poses for untrusted (unflagged) views
        if flags[i]:
            continue

        pose_grid = torch.stack((elevation_grid.flatten(), azimuth_grid.flatten(), radius_grid.flatten()), dim=1)

        poses = orbit_camera(pose_grid[:, 0], pose_grid[:, 1], pose_grid[:, 2], target=target.cpu())

        img_path = img_paths[i]
        base_cam.fx = cameras[i].fx
        base_cam.fy = cameras[i].fy
        perspectives = torch.from_numpy(base_cam.perspective).expand(pose_grid.shape[0], -1, -1)

        learnable_cam_params = torch.randn(pose_grid.shape[0], 6) * 1e-6
        learnable_cam_params.requires_grad_()

        loss_MSE_grid = np.zeros(pose_grid.shape[0])
        loss_LPIPS_grid = np.zeros(pose_grid.shape[0])
        loss = 0

        gt_img = Image.open(img_path)
        if gt_img.mode == 'RGBA':
            gt_img = np.asarray(gt_img, dtype=np.uint8).copy()
            gt_mask = (gt_img[..., 3:] > 128).astype(np.float32)
            gt_img[gt_img[:, :, -1] <= 255 * 0.9] = [255., 255., 255., 255.]  # threshold the alpha channel to get a white background
            gt_img = gt_img[:, :, :3]

        gt_tensor = torch.from_numpy(gt_img).float().unsqueeze(0).cuda() / 255.
        gt_mask_tensor = torch.from_numpy(gt_mask).float().unsqueeze(0).cuda()

        num_batches = pose_grid.shape[0] // chunk_size + int(pose_grid.shape[0] % chunk_size > 0)

        # render initial hypotheses for visualization
        vis_img = torch.zeros(pose_grid.shape[0], 256, 256, 3)
        for j in tqdm.tqdm(range(num_batches)):
            batch_poses = poses[j * chunk_size:(j + 1) * chunk_size]
            batch_perspectives = perspectives[j * chunk_size:(j + 1) * chunk_size]
            with torch.no_grad():
                out = renderer.render_batch(batch_poses, batch_perspectives, 256, 256, ssaa=1)  # [B, 256, 256, 3]
            batch_image = out["image"].detach().cpu()
            vis_img[j * chunk_size:(j + 1) * chunk_size] = batch_image

        l = [{'params': learnable_cam_params, 'lr': 5e-3, "name": "cam_params"}]
        optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

        # stage 1: refine every hypothesis with a learnable se(3) residual
        init_lr = optimizer.param_groups[0]['lr']
        for j in tqdm.tqdm(range(num_batches)):
            batch_poses = poses[j * chunk_size:(j + 1) * chunk_size]
            batch_perspectives = perspectives[j * chunk_size:(j + 1) * chunk_size]
            optimizer.param_groups[0]['lr'] = init_lr
            for k in tqdm.tqdm(range(s1_steps)):
                batch_residuals = SE3.exp(learnable_cam_params[j * chunk_size:(j + 1) * chunk_size]).as_matrix()  # [B, 4, 4]
                batch_poses_opt = torch.bmm(batch_poses, batch_residuals)
                out = renderer.render_batch(batch_poses_opt, batch_perspectives, 256, 256, ssaa=1)  # [B, 256, 256, 3]
                pred_tensor = out["image"]
                valid_mask = (out["alpha"] > 0) & (out["viewcos"] > 0.5)  # [B, 256, 256, 1]

                if k == s1_steps - 1:
                    # record the final per-hypothesis loss
                    loss = F.mse_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='none').mean(dim=(1, 2, 3))
                    loss_MSE_grid[j * chunk_size:(j + 1) * chunk_size] = loss.detach().cpu().numpy()
                    loss = loss.mean()
                else:
                    loss = F.mse_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='mean')

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

        # Render optimized images for visualization
        # vis_img_optimized = torch.zeros(pose_grid.shape[0], 256, 256, 3)
        # for j in tqdm.tqdm(range(num_batches)):
        #     batch_poses = poses[j*chunk_size:(j+1)*chunk_size]
        #     batch_perspectives = perspectives[j*chunk_size:(j+1)*chunk_size]
        #     batch_residuals = SE3.exp(learnable_cam_params[j*chunk_size:(j+1)*chunk_size]).as_matrix()
        #     batch_poses_opt = torch.bmm(batch_poses, batch_residuals)
        #     with torch.no_grad():
        #         out = renderer.render_batch(batch_poses_opt, batch_perspectives, 256, 256, ssaa=1)
        #     batch_image = out["image"].detach().cpu()
        #     vis_img_optimized[j*chunk_size:(j+1)*chunk_size] = batch_image

        # indices = np.argsort(loss_MSE_grid)
        # padding = (pose_grid.shape[0] // 10 + int(pose_grid.shape[0]%10 > 0)) * 10 - pose_grid.shape[0]
        # grid = vis_img[indices].permute(0, 3, 1, 2).contiguous()
        # padded_grid = torch.cat([grid, torch.ones(padding, 3, 256, 256)], dim=0)
        # padded_grid = padded_grid.view((padding + pose_grid.shape[0]) // 10, 10, 3, 256, 256).permute(2, 0, 3, 1, 4)
        # padded_grid = padded_grid.reshape(3, -1, 2560)
        # output_path = os.path.join(save_path, f'vis1_candidates_{i}.png')
        # save_image(padded_grid, output_path)

        # grid = vis_img_optimized[indices].permute(0, 3, 1, 2).contiguous()
        # padded_grid = torch.cat([grid, torch.ones(padding, 3, 256, 256)], dim=0)
        # padded_grid = padded_grid.view((padding + pose_grid.shape[0]) // 10, 10, 3, 256, 256).permute(2, 0, 3, 1, 4)
        # padded_grid = padded_grid.reshape(3, -1, 2560)
        # output_path = os.path.join(save_path, f'vis1_optimized_candidates_{i}.png')
        # save_image(padded_grid, output_path)

        # stage 2, round 1: keep the best beta fraction, perturb, and refine again
        beta = 0.1
        indices = np.argsort(loss_MSE_grid)[:max(int(loss_MSE_grid.shape[0] * beta), 64)]
        batch_poses = poses[indices]
        batch_residuals = SE3.exp(learnable_cam_params[indices].detach()).as_matrix()  # [B, 4, 4]
        poses = torch.bmm(batch_poses, batch_residuals)
        poses = poses.repeat(4, 1, 1)

        learnable_cam_params = torch.randn(poses.shape[0], 6) * 1e-1
        learnable_cam_params.requires_grad_()

        optimizer.param_groups = []
        optimizer.add_param_group({'params': learnable_cam_params})

        perspectives = torch.from_numpy(cameras[i].perspective).expand(poses.shape[0], -1, -1)
        loss_MSE_grid = np.zeros(poses.shape[0])

        num_batches = poses.shape[0] // chunk_size + int(poses.shape[0] % chunk_size > 0)
        for j in tqdm.tqdm(range(num_batches)):
            batch_poses = poses[j * chunk_size:(j + 1) * chunk_size]
            batch_perspectives = perspectives[j * chunk_size:(j + 1) * chunk_size]
            optimizer.param_groups[0]['lr'] = 1e-3
            for k in tqdm.tqdm(range(s2_steps)):
                batch_residuals = SE3.exp(learnable_cam_params[j * chunk_size:(j + 1) * chunk_size]).as_matrix()  # [B, 4, 4]
                batch_poses_opt = torch.bmm(batch_poses, batch_residuals)
                out = renderer.render_batch(batch_poses_opt, batch_perspectives, 256, 256, ssaa=1)  # [B, 256, 256, 3]
                pred_tensor = out["image"]
                valid_mask = (out["alpha"] > 0) & (out["viewcos"] > 0.5)  # [B, 256, 256, 1]

                if k == s2_steps - 1:
                    loss = F.mse_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='none').mean(dim=(1, 2, 3))
                    # loss += F.mse_loss(valid_mask, gt_mask_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='none').mean(dim=(1, 2, 3))
                    loss_MSE_grid[j * chunk_size:(j + 1) * chunk_size] = loss.detach().cpu().numpy()
                    loss = loss.mean()
                else:
                    loss = F.mse_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='mean')

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

        # stage 2, round 2: prune again and refine with a smaller perturbation
        beta = 0.1
        indices = np.argsort(loss_MSE_grid)[:max(int(loss_MSE_grid.shape[0] * beta), 64)]
        batch_poses = poses[indices]
        batch_residuals = SE3.exp(learnable_cam_params[indices].detach()).as_matrix()  # [B, 4, 4]
        poses = torch.bmm(batch_poses, batch_residuals)
        poses = poses.repeat(4, 1, 1)

        learnable_cam_params = torch.randn(poses.shape[0], 6) * 1e-2
        learnable_cam_params.requires_grad_()

        optimizer.param_groups = []
        optimizer.add_param_group({'params': learnable_cam_params})

        perspectives = torch.from_numpy(cameras[i].perspective).expand(poses.shape[0], -1, -1)
        loss_MSE_grid = np.zeros(poses.shape[0])

        num_batches = poses.shape[0] // chunk_size + int(poses.shape[0] % chunk_size > 0)
        for j in tqdm.tqdm(range(num_batches)):
            batch_poses = poses[j * chunk_size:(j + 1) * chunk_size]
            batch_perspectives = perspectives[j * chunk_size:(j + 1) * chunk_size]
            optimizer.param_groups[0]['lr'] = 1e-3
            for k in tqdm.tqdm(range(s2_steps)):
                batch_residuals = SE3.exp(learnable_cam_params[j * chunk_size:(j + 1) * chunk_size]).as_matrix()  # [B, 4, 4]
                batch_poses_opt = torch.bmm(batch_poses, batch_residuals)
                out = renderer.render_batch(batch_poses_opt, batch_perspectives, 256, 256, ssaa=1)  # [B, 256, 256, 3]
                pred_tensor = out["image"]
                valid_mask = (out["alpha"] > 0) & (out["viewcos"] > 0.5)  # [B, 256, 256, 1]

                if k == s2_steps - 1:
                    loss = F.mse_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='none').mean(dim=(1, 2, 3))
                    # loss += F.mse_loss(valid_mask, gt_mask_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='none').mean(dim=(1, 2, 3))
                    loss_MSE_grid[j * chunk_size:(j + 1) * chunk_size] = loss.detach().cpu().numpy()
                    loss = loss.mean()
                else:
                    loss = F.mse_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='mean')

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

        pose_grid = poses
        loss_LPIPS_grid = np.zeros(poses.shape[0])

        # score the surviving candidates with LPIPS as well
        chunk_size = 64
        gt_tensor = gt_tensor.permute(0, 3, 1, 2).contiguous()
        vis_img_opt = np.zeros((pose_grid.shape[0], 256, 256, 3), dtype=np.uint8)
        num_batches = pose_grid.shape[0] // chunk_size + int(pose_grid.shape[0] % chunk_size > 0)
        for j in tqdm.tqdm(range(num_batches)):
            batch_poses = poses[j * chunk_size:(j + 1) * chunk_size]
            batch_residuals = SE3.exp(learnable_cam_params[j * chunk_size:(j + 1) * chunk_size]).as_matrix()  # [B, 4, 4]
            batch_poses_opt = torch.bmm(batch_poses, batch_residuals)
            batch_perspectives = perspectives[j * chunk_size:(j + 1) * chunk_size]
            with torch.no_grad():
                out = renderer.render_batch(batch_poses_opt, batch_perspectives, 256, 256, ssaa=1)  # [B, 256, 256, 3]
            batch_image = (out["image"].detach().cpu().numpy() * 255).astype(np.uint8)
            vis_img_opt[j * chunk_size:(j + 1) * chunk_size] = batch_image

            pred_tensor = out["image"].permute(0, 3, 1, 2).contiguous()
            with torch.no_grad():
                loss_LPIPS_grid[j * chunk_size:(j + 1) * chunk_size] = lpips_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1)).squeeze().cpu().numpy()

        # rank candidates by the sum of their MSE and LPIPS ranks
        # indices_of_smallest = np.argsort(loss_MSE_grid)[:15]
        indices1 = np.argsort(loss_MSE_grid)
        indices2 = np.argsort(loss_LPIPS_grid)

        ranks1 = np.zeros_like(loss_MSE_grid)
        ranks2 = np.zeros_like(loss_LPIPS_grid)

        ranks1[indices1] = np.arange(1, loss_MSE_grid.size + 1)
        ranks2[indices2] = np.arange(1, loss_LPIPS_grid.size + 1)

        total_ranks = ranks1 + ranks2
        indices_of_smallest = np.argsort(total_ranks)[:15]

        # write the best pose back into the camera data (stored in PyTorch3D convention)
        index = indices_of_smallest[0]
        residual = SE3.exp(learnable_cam_params[index].detach()).as_matrix()
        c2w = poses[index] @ residual
        w2c = torch.inverse(c2w)

        w2c[1:3, :] *= -1  # OpenGL to OpenCV
        w2c[:2, :] *= -1   # OpenCV to PyTorch3D

        data[list(data.keys())[i]]["R"] = w2c[:3, :3].T.tolist()
        data[list(data.keys())[i]]["T"] = w2c[:3, 3].tolist()

        num_frames = 16
        cmap = plt.get_cmap("hot")
        num_rows = 2
        num_cols = 8
        # plt.subplots_adjust(top=0.2)
        figsize = (num_cols * 2, num_rows * 2.4)
        fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize)
        fig.suptitle("Input Image vs. Top 15 Similar Renderings", fontsize=16, y=0.93)
        plt.subplots_adjust(top=0.9)
        axs = axs.flatten()
        for idx in range(num_rows * num_cols):
            if idx < num_frames:
                if idx == 0:
                    axs[idx].imshow(gt_img.reshape(256, 256, 3))
                    axs[idx].set_xlabel('Input Image', fontsize=10)
                else:
                    axs[idx].imshow(vis_img_opt[indices_of_smallest[idx - 1]].reshape(256, 256, 3))
                    loss_text = f"MSE: {loss_MSE_grid[indices_of_smallest[idx - 1]]:.4f}_{int(ranks1[indices_of_smallest[idx - 1]]):d}\nLPIPS: {loss_LPIPS_grid[indices_of_smallest[idx - 1]]:.4f}_{int(ranks2[indices_of_smallest[idx - 1]]):d}"
                    axs[idx].text(0.05, 0.95, loss_text, color='black', fontsize=8,
                                  ha='left', va='top', transform=axs[idx].transAxes)
                for s in ["bottom", "top", "left", "right"]:
                    if idx == 0:
                        axs[idx].spines[s].set_color("green")
                    else:
                        axs[idx].spines[s].set_color(cmap(0.8 * idx / num_frames))
                    axs[idx].spines[s].set_linewidth(5)
                axs[idx].set_xticks([])
                axs[idx].set_yticks([])

                # if i >= args.all_views:
                #     axs[i].set_xlabel(f'MSE: {mse_losses[i%args.all_views]:.4f}\nLPIPS: {lpips_losses[i%args.all_views]:.4f}', fontsize=10)
            else:
                axs[idx].axis("off")
        plt.tight_layout()

        output_path = os.path.join(save_path, f'vis_{i}_render_and_compare.png')
        plt.savefig(output_path)  # save the figure to a file
        plt.close(fig)
        print(f"Visualization file written to {output_path}")

    del lpips_loss, renderer, learnable_cam_params
    gc.collect()
    torch.cuda.empty_cache()

    return data


def align_to_mesh(camera_data, mesh_path, save_path, num_views=8):
    parser = argparse.ArgumentParser()
    parser.add_argument('--object', type=str, help="path to mesh (obj, ply, glb, ...)")
    parser.add_argument('--path', type=str, help="path to mesh (obj, ply, glb, ...)")
    parser.add_argument('--front_dir', type=str, default='+z', help="mesh front-facing dir")
    parser.add_argument('--mode', default='albedo', type=str, choices=['lambertian', 'albedo', 'normal', 'depth'], help="rendering mode")
    parser.add_argument('--W', type=int, default=256, help="GUI width")
    parser.add_argument('--H', type=int, default=256, help="GUI height")
    parser.add_argument("--wogui", type=bool, default=True, help="disable all dpg GUI")
    parser.add_argument("--force_cuda_rast", action='store_true', help="force to use RasterizeCudaContext.")
    parser.add_argument("--config", default='configs/navi.yaml', help="path to the yaml config file")
    parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center")
    parser.add_argument('--fovy', type=float, default=49.1, help="default GUI camera fovy")
    args, extras = parser.parse_known_args()

    # override default config from cli
    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
    data = camera_data

    opt.mesh = mesh_path
    opt.trainable_texture = False
    renderer = Renderer(opt).to(torch.device("cuda"))

    cameras = [CustomCamera(cam_params) for cam_params in data.values()]
    # cams = [(cam.c2w, cam.perspective, cam.focal_length) for cam in cameras]
    img_paths = [v["filepath"] for k, v in data.items()]
    flags = [int(v["flag"]) for k, v in data.items()]

    s1_steps = 128
    num_hypotheses = 64
    chunk_size = 512
    print("Number of hypotheses:", num_hypotheses)

    for i in tqdm.tqdm(range(num_views)):
        if flags[i]:
            continue

        loss_MSE_grid = np.zeros(num_hypotheses)
        vis_img_opt = torch.zeros(num_hypotheses, 256, 256, 3)
        poses = torch.from_numpy(cameras[i].c2w).expand(num_hypotheses, -1, -1)
        perspectives = torch.from_numpy(cameras[i].perspective).expand(num_hypotheses, -1, -1)

        learnable_cam_params = torch.randn(num_hypotheses, 6) * 1e-3
        learnable_cam_params.requires_grad_()

        img_path = img_paths[i]
        gt_img = Image.open(img_path)
        if gt_img.mode == 'RGBA':
            gt_img = np.asarray(gt_img, dtype=np.uint8).copy()
            gt_mask = (gt_img[..., 3:] > 128).astype(np.float32)
            gt_img[gt_img[:, :, -1] <= 255 * 0.9] = [255., 255., 255., 255.]  # threshold the alpha channel to get a white background
            gt_img = gt_img[:, :, :3]

        gt_tensor = torch.from_numpy(gt_img).float().unsqueeze(0).cuda() / 255.
        gt_mask_tensor = torch.from_numpy(gt_mask).float().unsqueeze(0).cuda()

        num_batches = num_hypotheses // chunk_size + int(num_hypotheses % chunk_size > 0)

        l = [{'params': learnable_cam_params, 'lr': 5e-3, "name": "cam_params"}]
        optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

        init_lr = optimizer.param_groups[0]['lr']
        for j in tqdm.tqdm(range(num_batches)):
            batch_poses = poses[j * chunk_size:(j + 1) * chunk_size]
            batch_perspectives = perspectives[j * chunk_size:(j + 1) * chunk_size]
            optimizer.param_groups[0]['lr'] = init_lr
            for k in tqdm.tqdm(range(s1_steps)):
                batch_residuals = SE3.exp(learnable_cam_params[j * chunk_size:(j + 1) * chunk_size]).as_matrix()  # [B, 4, 4]
                batch_poses_opt = torch.bmm(batch_poses, batch_residuals)
                out = renderer.render_batch(batch_poses_opt, batch_perspectives, 256, 256, ssaa=1)  # [B, 256, 256, 3]
                pred_tensor = out["image"]

                if k == s1_steps - 1:
                    loss = F.mse_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='none').mean(dim=(1, 2, 3))
                    # loss += F.mse_loss(valid_mask, gt_mask_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='none').mean(dim=(1, 2, 3))
                    loss_MSE_grid[j * chunk_size:(j + 1) * chunk_size] = loss.detach().cpu().numpy()
                    batch_image = pred_tensor.detach().cpu()
                    vis_img_opt[j * chunk_size:(j + 1) * chunk_size] = batch_image
                    loss = loss.mean()
                else:
                    loss = F.mse_loss(pred_tensor, gt_tensor.expand(pred_tensor.shape[0], -1, -1, -1), reduction='mean')

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

        # keep the best residual and write the aligned pose back (PyTorch3D convention)
        indices = np.argsort(loss_MSE_grid)
        residual = SE3.exp(learnable_cam_params[indices[0]].detach()).as_matrix()
        c2w = torch.from_numpy(cameras[i].c2w) @ residual
        w2c = torch.inverse(c2w)

        w2c[1:3, :] *= -1  # OpenGL to OpenCV
        w2c[:2, :] *= -1   # OpenCV to PyTorch3D

        data[list(data.keys())[i]]["R"] = w2c[:3, :3].T.tolist()
        data[list(data.keys())[i]]["T"] = w2c[:3, 3].tolist()

        grid = vis_img_opt[indices].permute(0, 3, 1, 2).contiguous()
        grid = grid.view(8, 8, 3, 256, 256).permute(2, 0, 3, 1, 4)
        grid = grid.reshape(3, -1, int(256 * 8))
        output_path = os.path.join(save_path, f'vis_aligned_candidates_{i}.png')
        save_image(grid, output_path)

    return data
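Both functions above share the same refinement primitive: a frozen base pose composed with a learnable se(3) residual through liegroups' SE3.exp. A minimal sketch of that update loop in isolation; the base poses and the target point below are toy placeholders, not values from this repo:

import torch
from liegroups.torch import SE3

base_poses = torch.eye(4).repeat(8, 1, 1)         # placeholder batch of c2w poses
xi = (torch.randn(8, 6) * 1e-6).requires_grad_()  # se(3) residuals, near identity

optimizer = torch.optim.Adam([xi], lr=5e-3)
for _ in range(100):
    residuals = SE3.exp(xi).as_matrix()           # [8, 4, 4] rigid transforms
    poses = torch.bmm(base_poses, residuals)      # compose residual onto each base pose
    # toy stand-in for the photometric loss: pull each camera center toward a point
    loss = (poses[:, :3, 3] - torch.tensor([0.0, 0.0, 0.1])).pow(2).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Parameterizing the update in the tangent space keeps every iterate a valid rigid transform, which is why the code never has to re-orthonormalize rotation matrices between steps.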
sparseags/sh_utils.py
ADDED
@@ -0,0 +1,118 @@
# Copyright 2021 The PlenOctree Authors.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import torch

C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [
    1.0925484305920792,
    -1.0925484305920792,
    0.31539156525252005,
    -1.0925484305920792,
    0.5462742152960396
]
C3 = [
    -0.5900435899266435,
    2.890611442640554,
    -0.4570457994644658,
    0.3731763325901154,
    -0.4570457994644658,
    1.445305721320277,
    -0.5900435899266435
]
C4 = [
    2.5033429417967046,
    -1.7701307697799304,
    0.9461746957575601,
    -0.6690465435572892,
    0.10578554691520431,
    -0.6690465435572892,
    0.47308734787878004,
    -1.7701307697799304,
    0.6258357354491761,
]


def eval_sh(deg, sh, dirs):
    """
    Evaluate spherical harmonics at unit directions
    using hardcoded SH polynomials.
    Works with torch/np/jnp.
    ... Can be 0 or more batch dimensions.
    Args:
        deg: int SH deg. Currently, 0-4 supported
        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
        dirs: jnp.ndarray unit directions [..., 3]
    Returns:
        [..., C]
    """
    assert deg <= 4 and deg >= 0
    coeff = (deg + 1) ** 2
    assert sh.shape[-1] >= coeff

    result = C0 * sh[..., 0]
    if deg > 0:
        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
        result = (result -
                  C1 * y * sh[..., 1] +
                  C1 * z * sh[..., 2] -
                  C1 * x * sh[..., 3])

        if deg > 1:
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            result = (result +
                      C2[0] * xy * sh[..., 4] +
                      C2[1] * yz * sh[..., 5] +
                      C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
                      C2[3] * xz * sh[..., 7] +
                      C2[4] * (xx - yy) * sh[..., 8])

            if deg > 2:
                result = (result +
                          C3[0] * y * (3 * xx - yy) * sh[..., 9] +
                          C3[1] * xy * z * sh[..., 10] +
                          C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
                          C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
                          C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
                          C3[5] * z * (xx - yy) * sh[..., 14] +
                          C3[6] * x * (xx - 3 * yy) * sh[..., 15])

                if deg > 3:
                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
                              C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
                              C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
                              C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
                              C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
                              C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
                              C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
                              C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
                              C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
    return result


def RGB2SH(rgb):
    return (rgb - 0.5) / C0


def SH2RGB(sh):
    return sh * C0 + 0.5
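As a quick sanity check of the helpers above, using only functions defined in this file: at degree 0, RGB2SH followed by eval_sh recovers a color up to the 0.5 offset that the rasterization path adds back after SH evaluation.

import torch

rgb = torch.tensor([[0.2, 0.5, 0.8]])     # [N, 3] colors in [0, 1]
sh = RGB2SH(rgb).unsqueeze(-1)            # [N, 3, 1] degree-0 SH coefficients
dirs = torch.tensor([[0.0, 0.0, 1.0]])    # unit directions (unused at degree 0)
recon = eval_sh(0, sh, dirs) + 0.5        # the renderer adds 0.5 after eval_sh
assert torch.allclose(recon, rgb)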
sparseags/visual_utils.py
ADDED
@@ -0,0 +1,243 @@
import os
import re
import cv2
import csv
import json
import math
import tqdm
import shutil
import argparse
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
import nvdiffrast.torch as dr

from kiui.mesh import Mesh
from kiui.cam import OrbitCamera
from kiui.op import safe_normalize
from kiui.lpips import LPIPS

import sys
from sparseags.mesh_utils.mesh_renderer import Renderer
from sparseags.cam_utils import orbit_camera, OrbitCamera  # overrides kiui's OrbitCamera
from sparseags.render_utils.gs_renderer import CustomCamera


class GUI:
    def __init__(self, opt):
        self.opt = opt
        self.W = opt.W
        self.H = opt.H
        self.wogui = opt.wogui  # disable gui and run in cmd
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
        self.bg_color = torch.ones(3, dtype=torch.float32).cuda()  # default white bg
        # self.bg_color = torch.zeros(3, dtype=torch.float32).cuda()  # black bg

        self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
        self.need_update = True  # camera moved, should reset accumulation
        self.light_dir = np.array([0, 0])
        self.ambient_ratio = 0.5

        # auto-rotate
        self.auto_rotate_cam = False
        self.auto_rotate_light = False

        self.mode = opt.mode
        self.render_modes = ['albedo', 'depth', 'normal', 'lambertian']

        # load mesh
        self.mesh = Mesh.load(opt.mesh, front_dir=opt.front_dir)

        if not opt.force_cuda_rast and (self.wogui or os.name == 'nt'):
            self.glctx = dr.RasterizeGLContext()
        else:
            self.glctx = dr.RasterizeCudaContext()

    def step(self):

        if not self.need_update:
            return

        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        starter.record()

        # do MVP for vertices
        pose = torch.from_numpy(self.cam.pose.astype(np.float32)).cuda()
        proj = torch.from_numpy(self.cam.perspective.astype(np.float32)).cuda()

        v_cam = torch.matmul(F.pad(self.mesh.v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
        v_clip = v_cam @ proj.T

        rast, rast_db = dr.rasterize(self.glctx, v_clip, self.mesh.f, (self.H, self.W))

        alpha = (rast[..., 3:] > 0).float()
        alpha = dr.antialias(alpha, rast, v_clip, self.mesh.f).squeeze(0).clamp(0, 1)  # [H, W, 3]

        if self.mode == 'depth':
            depth, _ = dr.interpolate(-v_cam[..., [2]], rast, self.mesh.f)  # [1, H, W, 1]
            depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-20)
            buffer = depth.squeeze(0).detach().cpu().numpy().repeat(3, -1)  # [H, W, 3]
        else:
            # use vertex color if it exists
            if self.mesh.vc is not None:
                albedo, _ = dr.interpolate(self.mesh.vc.unsqueeze(0).contiguous(), rast, self.mesh.f)
            # otherwise use the texture image
            else:
                texc, _ = dr.interpolate(self.mesh.vt.unsqueeze(0).contiguous(), rast, self.mesh.ft)
                albedo = dr.texture(self.mesh.albedo.unsqueeze(0), texc, filter_mode='linear')  # [1, H, W, 3]

            albedo = torch.where(rast[..., 3:] > 0, albedo, torch.tensor(0).to(albedo.device))  # remove background
            albedo = dr.antialias(albedo, rast, v_clip, self.mesh.f).clamp(0, 1)  # [1, H, W, 3]
            if self.mode == 'albedo':
                albedo = albedo * alpha + self.bg_color * (1 - alpha)
                buffer = albedo[0].detach().cpu().numpy()
            else:
                normal, _ = dr.interpolate(self.mesh.vn.unsqueeze(0).contiguous(), rast, self.mesh.fn)
                normal = safe_normalize(normal)
                if self.mode == 'normal':
                    normal_image = (normal[0] + 1) / 2
                    normal_image = torch.where(rast[..., 3:] > 0, normal_image, torch.tensor(1).to(normal_image.device))  # remove background
                    buffer = normal_image.detach().cpu().numpy()
                elif self.mode == 'lambertian':
                    light_d = np.deg2rad(self.light_dir)
                    light_d = np.array([
                        np.cos(light_d[0]) * np.sin(light_d[1]),
                        -np.sin(light_d[0]),
                        np.cos(light_d[0]) * np.cos(light_d[1]),
                    ], dtype=np.float32)
                    light_d = torch.from_numpy(light_d).to(albedo.device)
                    lambertian = self.ambient_ratio + (1 - self.ambient_ratio) * (normal @ light_d).float().clamp(min=0)
                    albedo = (albedo * lambertian.unsqueeze(-1)) * alpha + self.bg_color * (1 - alpha)
                    buffer = albedo[0].detach().cpu().numpy()

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        self.render_buffer = buffer
        self.need_update = False

        if self.auto_rotate_cam:
            self.cam.orbit(5, 0)
            self.need_update = True

        if self.auto_rotate_light:
            self.light_dir[1] += 3
            self.need_update = True


def vis_output(camera_data, mesh_path=None, save_path=None, num_views=8):
    parser = argparse.ArgumentParser()
    parser.add_argument('--front_dir', type=str, default='+z', help="mesh front-facing dir")
    parser.add_argument('--mode', default='albedo', type=str, choices=['lambertian', 'albedo', 'normal', 'depth'], help="rendering mode")
    parser.add_argument('--W', type=int, default=256, help="GUI width")
    parser.add_argument('--H', type=int, default=256, help="GUI height")
    parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center")
    parser.add_argument('--fovy', type=float, default=49.1, help="default GUI camera fovy")
    parser.add_argument("--wogui", type=bool, default=True, help="disable all dpg GUI")
    parser.add_argument("--force_cuda_rast", action='store_true', help="force to use RasterizeCudaContext.")
    parser.add_argument('--elevation', type=int, default=0, help="rendering elevation")
    parser.add_argument('--save_video', type=str, default=None, help="path to save rendered video")
    parser.add_argument('--idx', type=int, default=0, help="view index")
    parser.add_argument('--config', default='configs/navi.yaml', type=str, help='Path to config directory, which contains image.yaml')
    args, extras = parser.parse_known_args()

    # override default config from cli
    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
    data = camera_data

    cameras = [CustomCamera(cam_params) for cam_params in data.values()]
    cams = [(cam.c2w, cam.perspective, cam.focal_length) for cam in cameras]
    img_paths = [v["filepath"] for k, v in data.items()]

    opt.mesh = mesh_path
    opt.trainable_texture = False
    renderer = Renderer(opt).to(torch.device("cuda"))

    lpips_loss = LPIPS(net='vgg').cuda()
    mse_losses = []
    lpips_losses = []
    flags = [int(v["flag"]) for k, v in data.items()]
    images = np.zeros((2, num_views, 256, 256, 3), dtype=np.uint8)  # row 0: inputs, row 1: renderings

    for i in tqdm.tqdm(range(len(cams))):

        img_path = img_paths[i]

        img = Image.open(img_path)
        if img.mode == 'RGBA':
            img = np.asarray(img, dtype=np.uint8).copy()
            img[img[:, :, -1] <= 255 * 0.9] = [255., 255., 255., 255.]  # threshold the alpha channel to get a white background
            img = img[:, :, :3]

        gt_tensor = torch.from_numpy(img).permute(2, 0, 1).float().unsqueeze(0).cuda() / 255.0

        images[0, i] = img

        with torch.no_grad():
            out = renderer.render(*cams[i][:2], 256, 256, ssaa=1)

        # rgb loss, normalized by the on-screen object scale
        image = (out["image"].detach().cpu().numpy() * 255).astype(np.uint8)
        pred_tensor = out["image"].permute(2, 0, 1).float().unsqueeze(0).cuda()
        # obj_scale = ((out["alpha"] > 0) & (out["viewcos"] > 0.5)).detach().sum().float()
        obj_scale = (out["alpha"] > 0).detach().sum().float()
        obj_scale /= 256 ** 2

        images[1, i] = image
        with torch.no_grad():
            mse_losses.append(F.mse_loss(pred_tensor, gt_tensor).squeeze().cpu().numpy() / obj_scale.item())
            lpips_losses.append(lpips_loss(pred_tensor, gt_tensor).squeeze().cpu().numpy() / obj_scale.item())

    mean_mse = np.mean(np.array(mse_losses)[:num_views])
    mean_lpips = np.mean(np.array(lpips_losses)[:num_views])

    num_frames = 2 * num_views
    cmap = plt.get_cmap("hsv")
    num_rows = 2
    num_cols = num_views
    plt.subplots_adjust(top=0.2)
    figsize = (num_cols * 2, num_rows * 2.2)
    fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize)
    fig.suptitle(f"Avg MSE: {mean_mse:.4f}, Avg LPIPS: {mean_lpips:.4f}", fontsize=16, y=0.97)
    axs = axs.flatten()
    for i in range(num_rows * num_cols):
        if i < num_frames:
            axs[i].imshow(images.reshape(-1, 256, 256, 3)[i])
            for s in ["bottom", "top", "left", "right"]:
                if i % num_views <= num_views - 1:
                    # red border: untrusted view; green border: trusted view
                    if not flags[i % num_views]:
                        axs[i].spines[s].set_color("red")
                    else:
                        axs[i].spines[s].set_color("green")
                else:
                    axs[i].spines[s].set_color(cmap(i / num_frames))
                axs[i].spines[s].set_linewidth(5)
            axs[i].set_xticks([])
            axs[i].set_yticks([])

            if i >= num_views:
                axs[i].set_xlabel(f'MSE: {mse_losses[i % num_views]:.4f}\nLPIPS: {lpips_losses[i % num_views]:.4f}', fontsize=10)
        else:
            axs[i].axis("off")
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close(fig)
    print(f"Visualization file written to {save_path}")

    out_dir = save_path.replace('vis.png', 'reprojections')
    os.makedirs(out_dir, exist_ok=True)

    for i in range(num_views):
        gt = Image.fromarray(images[0, i])
        pred = Image.fromarray(images[1, i])
        gt.save(os.path.join(out_dir, f"gt_{i}.png"))
        pred.save(os.path.join(out_dir, f"pred_{i}.png"))

    return np.array(lpips_losses), np.array(mse_losses)
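For completeness, a hypothetical driver for vis_output. The JSON path, mesh path, and output path below are assumptions for illustration; camera_data must follow the per-view schema consumed above (CustomCamera parameters plus "filepath" and "flag" entries):

import json

with open("data/camera_data.json") as f:  # hypothetical input file
    camera_data = json.load(f)

lpips_arr, mse_arr = vis_output(
    camera_data,
    mesh_path="logs/mesh.obj",  # hypothetical mesh path
    save_path="logs/vis.png",   # vis_output derives its reprojections dir from 'vis.png'
    num_views=8,
)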