flamehaze1115 commited on
Commit
651aae3
·
1 Parent(s): 649375f

Delete mvdiffusion/data/objaverse_dataset.py

Browse files
Files changed (1) hide show
  1. mvdiffusion/data/objaverse_dataset.py +0 -608
mvdiffusion/data/objaverse_dataset.py DELETED
@@ -1,608 +0,0 @@
1
- from typing import Dict
2
- import numpy as np
3
- from omegaconf import DictConfig, ListConfig
4
- import torch
5
- from torch.utils.data import Dataset
6
- from pathlib import Path
7
- import json
8
- from PIL import Image
9
- from torchvision import transforms
10
- from einops import rearrange
11
- from typing import Literal, Tuple, Optional, Any
12
- import cv2
13
- import random
14
-
15
- import json
16
- import os, sys
17
- import math
18
-
19
- import PIL.Image
20
- from .normal_utils import trans_normal, normal2img, img2normal
21
- import pdb
22
-
23
- def shift_list(lst, n):
24
- length = len(lst)
25
- n = n % length # Ensure n is within the range of the list length
26
- return lst[-n:] + lst[:-n]
27
-
28
-
29
- class ObjaverseDataset(Dataset):
30
- def __init__(self,
31
- root_dir: str,
32
- num_views: int,
33
- bg_color: Any,
34
- img_wh: Tuple[int, int],
35
- object_list: str,
36
- groups_num: int=1,
37
- validation: bool = False,
38
- random_views: bool = False,
39
- num_validation_samples: int = 64,
40
- num_samples: Optional[int] = None,
41
- invalid_list: Optional[str] = None,
42
- trans_norm_system: bool = True, # if True, transform all normals map into the cam system of front view
43
- augment_data: bool = False,
44
- read_normal: bool = True,
45
- read_color: bool = False,
46
- read_depth: bool = False,
47
- mix_color_normal: bool = False,
48
- random_view_and_domain: bool = False
49
- ) -> None:
50
- """Create a dataset from a folder of images.
51
- If you pass in a root directory it will be searched for images
52
- ending in ext (ext can be a list)
53
- """
54
- self.root_dir = Path(root_dir)
55
- self.num_views = num_views
56
- self.bg_color = bg_color
57
- self.validation = validation
58
- self.num_samples = num_samples
59
- self.trans_norm_system = trans_norm_system
60
- self.augment_data = augment_data
61
- self.invalid_list = invalid_list
62
- self.groups_num = groups_num
63
- print("augment data: ", self.augment_data)
64
- self.img_wh = img_wh
65
- self.read_normal = read_normal
66
- self.read_color = read_color
67
- self.read_depth = read_depth
68
- self.mix_color_normal = mix_color_normal # mix load color and normal maps
69
- self.random_view_and_domain = random_view_and_domain # load normal or rgb of a single view
70
- self.random_views = random_views
71
- if not self.random_views:
72
- if self.num_views == 4:
73
- self.view_types = ['front', 'right', 'back', 'left']
74
- elif self.num_views == 5:
75
- self.view_types = ['front', 'front_right', 'right', 'back', 'left']
76
- elif self.num_views == 6 or self.num_views==1:
77
- self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
78
- else:
79
- self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
80
-
81
- self.fix_cam_pose_dir = "./mvdiffusion/data/fixed_poses/nine_views"
82
-
83
- self.fix_cam_poses = self.load_fixed_poses() # world2cam matrix
84
-
85
- if object_list is not None:
86
- with open(object_list) as f:
87
- self.objects = json.load(f)
88
- self.objects = [os.path.basename(o).replace(".glb", "") for o in self.objects]
89
- else:
90
- self.objects = os.listdir(self.root_dir)
91
- self.objects = sorted(self.objects)
92
-
93
- if self.invalid_list is not None:
94
- with open(self.invalid_list) as f:
95
- self.invalid_objects = json.load(f)
96
- self.invalid_objects = [os.path.basename(o).replace(".glb", "") for o in self.invalid_objects]
97
- else:
98
- self.invalid_objects = []
99
-
100
-
101
- self.all_objects = set(self.objects) - (set(self.invalid_objects) & set(self.objects))
102
- self.all_objects = list(self.all_objects)
103
-
104
- if not validation:
105
- self.all_objects = self.all_objects[:-num_validation_samples]
106
- else:
107
- self.all_objects = self.all_objects[-num_validation_samples:]
108
- if num_samples is not None:
109
- self.all_objects = self.all_objects[:num_samples]
110
-
111
- print("loading ", len(self.all_objects), " objects in the dataset")
112
-
113
- if self.mix_color_normal:
114
- self.backup_data = self.__getitem_mix__(0, "9438abf986c7453a9f4df7c34aa2e65b")
115
- elif self.random_view_and_domain:
116
- self.backup_data = self.__getitem_random_viewanddomain__(0, "9438abf986c7453a9f4df7c34aa2e65b")
117
- else:
118
- self.backup_data = self.__getitem_norm__(0, "9438abf986c7453a9f4df7c34aa2e65b") # "66b2134b7e3645b29d7c349645291f78")
119
-
120
- def __len__(self):
121
- return len(self.objects)*self.total_view
122
-
123
- def load_fixed_poses(self):
124
- poses = {}
125
- for face in self.view_types:
126
- RT = np.loadtxt(os.path.join(self.fix_cam_pose_dir,'%03d_%s_RT.txt'%(0, face)))
127
- poses[face] = RT
128
-
129
- return poses
130
-
131
- def cartesian_to_spherical(self, xyz):
132
- ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
133
- xy = xyz[:,0]**2 + xyz[:,1]**2
134
- z = np.sqrt(xy + xyz[:,2]**2)
135
- theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
136
- #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
137
- azimuth = np.arctan2(xyz[:,1], xyz[:,0])
138
- return np.array([theta, azimuth, z])
139
-
140
- def get_T(self, target_RT, cond_RT):
141
- R, T = target_RT[:3, :3], target_RT[:, -1]
142
- T_target = -R.T @ T # change to cam2world
143
-
144
- R, T = cond_RT[:3, :3], cond_RT[:, -1]
145
- T_cond = -R.T @ T
146
-
147
- theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
148
- theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
149
-
150
- d_theta = theta_target - theta_cond
151
- d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
152
- d_z = z_target - z_cond
153
-
154
- # d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
155
- return d_theta, d_azimuth
156
-
157
- def get_bg_color(self):
158
- if self.bg_color == 'white':
159
- bg_color = np.array([1., 1., 1.], dtype=np.float32)
160
- elif self.bg_color == 'black':
161
- bg_color = np.array([0., 0., 0.], dtype=np.float32)
162
- elif self.bg_color == 'gray':
163
- bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
164
- elif self.bg_color == 'random':
165
- bg_color = np.random.rand(3)
166
- elif self.bg_color == 'three_choices':
167
- white = np.array([1., 1., 1.], dtype=np.float32)
168
- black = np.array([0., 0., 0.], dtype=np.float32)
169
- gray = np.array([0.5, 0.5, 0.5], dtype=np.float32)
170
- bg_color = random.choice([white, black, gray])
171
- elif isinstance(self.bg_color, float):
172
- bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
173
- else:
174
- raise NotImplementedError
175
- return bg_color
176
-
177
-
178
-
179
- def load_mask(self, img_path, return_type='np'):
180
- # not using cv2 as may load in uint16 format
181
- # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
182
- # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
183
- # pil always returns uint8
184
- img = np.array(Image.open(img_path).resize(self.img_wh))
185
- img = np.float32(img > 0)
186
-
187
- assert len(np.shape(img)) == 2
188
-
189
- if return_type == "np":
190
- pass
191
- elif return_type == "pt":
192
- img = torch.from_numpy(img)
193
- else:
194
- raise NotImplementedError
195
-
196
- return img
197
-
198
- def load_image(self, img_path, bg_color, alpha, return_type='np'):
199
- # not using cv2 as may load in uint16 format
200
- # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
201
- # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
202
- # pil always returns uint8
203
- img = np.array(Image.open(img_path).resize(self.img_wh))
204
- img = img.astype(np.float32) / 255. # [0, 1]
205
- assert img.shape[-1] == 3 # RGB
206
-
207
- if alpha.shape[-1] != 1:
208
- alpha = alpha[:, :, None]
209
-
210
- img = img[...,:3] * alpha + bg_color * (1 - alpha)
211
-
212
- if return_type == "np":
213
- pass
214
- elif return_type == "pt":
215
- img = torch.from_numpy(img)
216
- else:
217
- raise NotImplementedError
218
-
219
- return img
220
-
221
- def load_depth(self, img_path, bg_color, alpha, return_type='np'):
222
- # not using cv2 as may load in uint16 format
223
- # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
224
- # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
225
- # pil always returns uint8
226
- img = np.array(Image.open(img_path).resize(self.img_wh))
227
- img = img.astype(np.float32) / 65535. # [0, 1]
228
-
229
- img[img > 0.4] = 0
230
- img = img / 0.4
231
-
232
- assert img.ndim == 2 # depth
233
- img = np.stack([img]*3, axis=-1)
234
-
235
- if alpha.shape[-1] != 1:
236
- alpha = alpha[:, :, None]
237
-
238
- # print(np.max(img[:, :, 0]))
239
-
240
- img = img[...,:3] * alpha + bg_color * (1 - alpha)
241
-
242
- if return_type == "np":
243
- pass
244
- elif return_type == "pt":
245
- img = torch.from_numpy(img)
246
- else:
247
- raise NotImplementedError
248
-
249
- return img
250
-
251
- def load_normal(self, img_path, bg_color, alpha, RT_w2c=None, RT_w2c_cond=None, return_type='np'):
252
- # not using cv2 as may load in uint16 format
253
- # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
254
- # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
255
- # pil always returns uint8
256
- normal = np.array(Image.open(img_path).resize(self.img_wh))
257
-
258
- assert normal.shape[-1] == 3 # RGB
259
-
260
- normal = trans_normal(img2normal(normal), RT_w2c, RT_w2c_cond)
261
-
262
- img = (normal*0.5 + 0.5).astype(np.float32) # [0, 1]
263
-
264
- if alpha.shape[-1] != 1:
265
- alpha = alpha[:, :, None]
266
-
267
- img = img[...,:3] * alpha + bg_color * (1 - alpha)
268
-
269
- if return_type == "np":
270
- pass
271
- elif return_type == "pt":
272
- img = torch.from_numpy(img)
273
- else:
274
- raise NotImplementedError
275
-
276
- return img
277
-
278
- def __len__(self):
279
- return len(self.all_objects)
280
-
281
- def __getitem_mix__(self, index, debug_object=None):
282
- if debug_object is not None:
283
- object_name = debug_object #
284
- set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
285
- else:
286
- object_name = self.all_objects[index%len(self.all_objects)]
287
- set_idx = 0
288
-
289
- if self.augment_data:
290
- cond_view = random.sample(self.view_types, k=1)[0]
291
- else:
292
- cond_view = 'front'
293
-
294
- if random.random() < 0.5:
295
- read_color, read_normal, read_depth = True, False, False
296
- else:
297
- read_color, read_normal, read_depth = False, True, True
298
-
299
- read_normal = read_normal & self.read_normal
300
- read_depth = read_depth & self.read_depth
301
-
302
- assert (read_color and (read_normal or read_depth)) is False
303
-
304
- view_types = self.view_types
305
-
306
- cond_w2c = self.fix_cam_poses[cond_view]
307
-
308
- tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
309
-
310
- elevations = []
311
- azimuths = []
312
-
313
- # get the bg color
314
- bg_color = self.get_bg_color()
315
-
316
- cond_alpha = self.load_mask(os.path.join(self.root_dir, object_name[:3], object_name, "mask_%03d_%s.png" % (set_idx, cond_view)), return_type='np')
317
- img_tensors_in = [
318
- self.load_image(os.path.join(self.root_dir, object_name[:3], object_name, "rgb_%03d_%s.png" % (set_idx, cond_view)), bg_color, cond_alpha, return_type='pt').permute(2, 0, 1)
319
- ] * self.num_views
320
- img_tensors_out = []
321
-
322
- for view, tgt_w2c in zip(view_types, tgt_w2cs):
323
- img_path = os.path.join(self.root_dir, object_name[:3], object_name, "rgb_%03d_%s.png" % (set_idx, view))
324
- mask_path = os.path.join(self.root_dir, object_name[:3], object_name, "mask_%03d_%s.png" % (set_idx, view))
325
- normal_path = os.path.join(self.root_dir, object_name[:3], object_name, "normals_%03d_%s.png" % (set_idx, view))
326
- depth_path = os.path.join(self.root_dir, object_name[:3], object_name, "depth_%03d_%s.png" % (set_idx, view))
327
- alpha = self.load_mask(mask_path, return_type='np')
328
-
329
- if read_color:
330
- img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt")
331
- img_tensor = img_tensor.permute(2, 0, 1)
332
- img_tensors_out.append(img_tensor)
333
-
334
- if read_normal:
335
- normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt").permute(2, 0, 1)
336
- img_tensors_out.append(normal_tensor)
337
- if read_depth:
338
- depth_tensor = self.load_depth(depth_path, bg_color, alpha, return_type="pt").permute(2, 0, 1)
339
- img_tensors_out.append(depth_tensor)
340
-
341
- # evelations, azimuths
342
- elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
343
- elevations.append(elevation)
344
- azimuths.append(azimuth)
345
-
346
- img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
347
- img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
348
-
349
-
350
- elevations = torch.as_tensor(elevations).float().squeeze(1)
351
- azimuths = torch.as_tensor(azimuths).float().squeeze(1)
352
- elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
353
- camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1) # (Nv, 3)
354
-
355
- normal_class = torch.tensor([1, 0]).float()
356
- normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2)
357
- color_class = torch.tensor([0, 1]).float()
358
- color_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2)
359
- if read_normal or read_depth:
360
- task_embeddings = normal_task_embeddings
361
- if read_color:
362
- task_embeddings = color_task_embeddings
363
-
364
- return {
365
- 'elevations_cond': elevations_cond,
366
- 'elevations_cond_deg': torch.rad2deg(elevations_cond),
367
- 'elevations': elevations,
368
- 'azimuths': azimuths,
369
- 'elevations_deg': torch.rad2deg(elevations),
370
- 'azimuths_deg': torch.rad2deg(azimuths),
371
- 'imgs_in': img_tensors_in,
372
- 'imgs_out': img_tensors_out,
373
- 'camera_embeddings': camera_embeddings,
374
- 'task_embeddings': task_embeddings
375
- }
376
-
377
-
378
- def __getitem_random_viewanddomain__(self, index, debug_object=None):
379
- if debug_object is not None:
380
- object_name = debug_object #
381
- set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
382
- else:
383
- object_name = self.all_objects[index%len(self.all_objects)]
384
- set_idx = 0
385
-
386
- if self.augment_data:
387
- cond_view = random.sample(self.view_types, k=1)[0]
388
- else:
389
- cond_view = 'front'
390
-
391
- if random.random() < 0.5:
392
- read_color, read_normal, read_depth = True, False, False
393
- else:
394
- read_color, read_normal, read_depth = False, True, True
395
-
396
- read_normal = read_normal & self.read_normal
397
- read_depth = read_depth & self.read_depth
398
-
399
- assert (read_color and (read_normal or read_depth)) is False
400
-
401
- view_types = self.view_types
402
-
403
- cond_w2c = self.fix_cam_poses[cond_view]
404
-
405
- tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
406
-
407
- elevations = []
408
- azimuths = []
409
-
410
- # get the bg color
411
- bg_color = self.get_bg_color()
412
-
413
- cond_alpha = self.load_mask(os.path.join(self.root_dir, object_name[:3], object_name, "mask_%03d_%s.png" % (set_idx, cond_view)), return_type='np')
414
- img_tensors_in = [
415
- self.load_image(os.path.join(self.root_dir, object_name[:3], object_name, "rgb_%03d_%s.png" % (set_idx, cond_view)), bg_color, cond_alpha, return_type='pt').permute(2, 0, 1)
416
- ] * self.num_views
417
- img_tensors_out = []
418
-
419
- random_viewidx = random.randint(0, len(view_types)-1)
420
-
421
- for view, tgt_w2c in zip([view_types[random_viewidx]], [tgt_w2cs[random_viewidx]]):
422
- img_path = os.path.join(self.root_dir, object_name[:3], object_name, "rgb_%03d_%s.png" % (set_idx, view))
423
- mask_path = os.path.join(self.root_dir, object_name[:3], object_name, "mask_%03d_%s.png" % (set_idx, view))
424
- normal_path = os.path.join(self.root_dir, object_name[:3], object_name, "normals_%03d_%s.png" % (set_idx, view))
425
- depth_path = os.path.join(self.root_dir, object_name[:3], object_name, "depth_%03d_%s.png" % (set_idx, view))
426
- alpha = self.load_mask(mask_path, return_type='np')
427
-
428
- if read_color:
429
- img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt")
430
- img_tensor = img_tensor.permute(2, 0, 1)
431
- img_tensors_out.append(img_tensor)
432
-
433
- if read_normal:
434
- normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt").permute(2, 0, 1)
435
- img_tensors_out.append(normal_tensor)
436
- if read_depth:
437
- depth_tensor = self.load_depth(depth_path, bg_color, alpha, return_type="pt").permute(2, 0, 1)
438
- img_tensors_out.append(depth_tensor)
439
-
440
- # evelations, azimuths
441
- elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
442
- elevations.append(elevation)
443
- azimuths.append(azimuth)
444
-
445
- img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
446
- img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
447
-
448
-
449
- elevations = torch.as_tensor(elevations).float().squeeze(1)
450
- azimuths = torch.as_tensor(azimuths).float().squeeze(1)
451
- elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
452
- camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1) # (Nv, 3)
453
-
454
- normal_class = torch.tensor([1, 0]).float()
455
- normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2)
456
- color_class = torch.tensor([0, 1]).float()
457
- color_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2)
458
- if read_normal or read_depth:
459
- task_embeddings = normal_task_embeddings
460
- if read_color:
461
- task_embeddings = color_task_embeddings
462
-
463
- return {
464
- 'elevations_cond': elevations_cond,
465
- 'elevations_cond_deg': torch.rad2deg(elevations_cond),
466
- 'elevations': elevations,
467
- 'azimuths': azimuths,
468
- 'elevations_deg': torch.rad2deg(elevations),
469
- 'azimuths_deg': torch.rad2deg(azimuths),
470
- 'imgs_in': img_tensors_in,
471
- 'imgs_out': img_tensors_out,
472
- 'camera_embeddings': camera_embeddings,
473
- 'task_embeddings': task_embeddings
474
- }
475
-
476
-
477
- def __getitem_norm__(self, index, debug_object=None):
478
- if debug_object is not None:
479
- object_name = debug_object #
480
- set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
481
- else:
482
- object_name = self.all_objects[index%len(self.all_objects)]
483
- set_idx = 0
484
-
485
- if self.augment_data:
486
- cond_view = random.sample(self.view_types, k=1)[0]
487
- else:
488
- cond_view = 'front'
489
-
490
- # if self.random_views:
491
- # view_types = ['front']+random.sample(self.view_types[1:], 3)
492
- # else:
493
- # view_types = self.view_types
494
-
495
- view_types = self.view_types
496
-
497
- cond_w2c = self.fix_cam_poses[cond_view]
498
-
499
- tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
500
-
501
- elevations = []
502
- azimuths = []
503
-
504
- # get the bg color
505
- bg_color = self.get_bg_color()
506
-
507
- cond_alpha = self.load_mask(os.path.join(self.root_dir, object_name[:3], object_name, "mask_%03d_%s.png" % (set_idx, cond_view)), return_type='np')
508
- img_tensors_in = [
509
- self.load_image(os.path.join(self.root_dir, object_name[:3], object_name, "rgb_%03d_%s.png" % (set_idx, cond_view)), bg_color, cond_alpha, return_type='pt').permute(2, 0, 1)
510
- ] * self.num_views
511
- img_tensors_out = []
512
- normal_tensors_out = []
513
- for view, tgt_w2c in zip(view_types, tgt_w2cs):
514
- img_path = os.path.join(self.root_dir, object_name[:3], object_name, "rgb_%03d_%s.png" % (set_idx, view))
515
- mask_path = os.path.join(self.root_dir, object_name[:3], object_name, "mask_%03d_%s.png" % (set_idx, view))
516
- alpha = self.load_mask(mask_path, return_type='np')
517
-
518
- if self.read_color:
519
- img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt")
520
- img_tensor = img_tensor.permute(2, 0, 1)
521
- img_tensors_out.append(img_tensor)
522
-
523
- if self.read_normal:
524
- normal_path = os.path.join(self.root_dir, object_name[:3], object_name, "normals_%03d_%s.png" % (set_idx, view))
525
- normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt").permute(2, 0, 1)
526
- normal_tensors_out.append(normal_tensor)
527
-
528
- # evelations, azimuths
529
- elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
530
- elevations.append(elevation)
531
- azimuths.append(azimuth)
532
-
533
- img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
534
- if self.read_color:
535
- img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
536
- if self.read_normal:
537
- normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float() # (Nv, 3, H, W)
538
-
539
- elevations = torch.as_tensor(elevations).float().squeeze(1)
540
- azimuths = torch.as_tensor(azimuths).float().squeeze(1)
541
- elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
542
-
543
- camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1) # (Nv, 3)
544
-
545
- normal_class = torch.tensor([1, 0]).float()
546
- normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2)
547
- color_class = torch.tensor([0, 1]).float()
548
- color_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2)
549
-
550
- return {
551
- 'elevations_cond': elevations_cond,
552
- 'elevations_cond_deg': torch.rad2deg(elevations_cond),
553
- 'elevations': elevations,
554
- 'azimuths': azimuths,
555
- 'elevations_deg': torch.rad2deg(elevations),
556
- 'azimuths_deg': torch.rad2deg(azimuths),
557
- 'imgs_in': img_tensors_in,
558
- 'imgs_out': img_tensors_out,
559
- 'normals_out': normal_tensors_out,
560
- 'camera_embeddings': camera_embeddings,
561
- 'normal_task_embeddings': normal_task_embeddings,
562
- 'color_task_embeddings': color_task_embeddings
563
- }
564
-
565
- def __getitem__(self, index):
566
-
567
- try:
568
- if self.mix_color_normal:
569
- data = self.__getitem_mix__(index)
570
- elif self.random_view_and_domain:
571
- data = self.__getitem_random_viewanddomain__(index)
572
- else:
573
- data = self.__getitem_norm__(index)
574
- return data
575
- except:
576
- print("load error ", self.all_objects[index%len(self.all_objects)] )
577
- return self.backup_data
578
-
579
-
580
- class ConcatDataset(torch.utils.data.Dataset):
581
- def __init__(self, datasets, weights):
582
- self.datasets = datasets
583
- self.weights = weights
584
- self.num_datasets = len(datasets)
585
-
586
- def __getitem__(self, i):
587
-
588
- chosen = random.choices(self.datasets, self.weights, k=1)[0]
589
- return chosen[i]
590
-
591
- def __len__(self):
592
- return max(len(d) for d in self.datasets)
593
-
594
- if __name__ == "__main__":
595
- train_dataset = ObjaverseDataset(
596
- root_dir="/ghome/l5/xxlong/.objaverse/hf-objaverse-v1/renderings",
597
- size=(128, 128),
598
- ext="hdf5",
599
- default_trans=torch.zeros(3),
600
- return_paths=False,
601
- total_view=8,
602
- validation=False,
603
- object_list=None,
604
- views_mode='fourviews'
605
- )
606
- data0 = train_dataset[0]
607
- data1 = train_dataset[50]
608
- # print(data)