|
from datasets.ytb_vos import YoutubeVOSDataset |
|
from datasets.ytb_vis import YoutubeVISDataset |
|
from datasets.saliency_modular import SaliencyDataset |
|
from datasets.vipseg import VIPSegDataset |
|
from datasets.mvimagenet import MVImageNetDataset |
|
from datasets.sam import SAMDataset |
|
from datasets.dreambooth import DreamBoothDataset |
|
from datasets.uvo import UVODataset |
|
from datasets.uvo_val import UVOValDataset |
|
from datasets.mose import MoseDataset |
|
from datasets.vitonhd import VitonHDDataset |
|
from datasets.fashiontryon import FashionTryonDataset |
|
from datasets.lvis import LvisDataset |
|
from torch.utils.data import ConcatDataset |
|
from torch.utils.data import DataLoader |
|
import numpy as np |
|
import cv2 |
|
from omegaconf import OmegaConf |
|
|
|
|
|
# Load the dataset configuration; every constructor below receives its
# keyword arguments from the matching entry under the `Train` section.
DConf = OmegaConf.load('./configs/datasets.yaml')
_train_cfg = DConf.Train

# All datasets are instantiated eagerly (in this order) so that any of them
# can be selected for inspection further down.
dataset1 = YoutubeVOSDataset(**_train_cfg.YoutubeVOS)
dataset2 = SaliencyDataset(**_train_cfg.Saliency)
dataset3 = VIPSegDataset(**_train_cfg.VIPSeg)
dataset4 = YoutubeVISDataset(**_train_cfg.YoutubeVIS)
dataset5 = MVImageNetDataset(**_train_cfg.MVImageNet)
dataset6 = SAMDataset(**_train_cfg.SAM)
# UVO has two sub-entries (`train` and `val`), each feeding its own class.
dataset7 = UVODataset(**_train_cfg.UVO.train)
dataset8 = VitonHDDataset(**_train_cfg.VitonHD)
dataset9 = UVOValDataset(**_train_cfg.UVO.val)
dataset10 = MoseDataset(**_train_cfg.Mose)
dataset11 = FashionTryonDataset(**_train_cfg.FashionTryon)
dataset12 = LvisDataset(**_train_cfg.Lvis)

# The dataset that the visualisation loop below will iterate over.
dataset = dataset5
|
|
|
|
|
def vis_sample(item, save_path='sample_vis.jpg'):
    """Write a side-by-side preview of the first sample of a batch.

    The saved panel shows, left to right: the reference image, the hint
    image (all hint channels except the last), the hint mask (the last hint
    channel, replicated to three channels) and the target image.

    Args:
        item: Batch mapping with the keys ``'ref'``, ``'jpg'``, ``'hint'``
            and ``'time_steps'``. The first three are indexed with the batch
            on axis 0 and converted with ``.numpy()``; ``'hint'`` is indexed
            as ``[batch, row, col, channel]``.
        save_path: Destination file for the preview image. Defaults to
            ``'sample_vis.jpg'``.

    Returns:
        None. The shapes of the four entries are printed and the preview is
        written to ``save_path``.
    """
    # NOTE(review): the scaling suggests 'ref' is stored in [0, 1] and
    # 'jpg' / 'hint' in [-1, 1] -- confirm against the dataset classes.
    ref = item['ref'] * 255
    tar = item['jpg'] * 127.5 + 127.5
    hint = item['hint'] * 127.5 + 127.5
    step = item['time_steps']
    print(ref.shape, tar.shape, hint.shape, step.shape)

    # Only the first element of the batch is visualised.
    ref = ref[0].numpy()
    tar = tar[0].numpy()
    hint_image = hint[0, :, :, :-1].numpy()
    hint_mask = hint[0, :, :, -1].numpy()
    hint_mask = np.stack([hint_mask] * 3, -1)

    # Resize the reference to the target's own size instead of a hard-coded
    # 512x512 so that hconcat also works for other resolutions.
    # cv2.resize takes the size as (width, height).
    height, width = tar.shape[:2]
    ref = cv2.resize(ref.astype(np.uint8), (width, height))

    panels = [ref, hint_image, hint_mask, tar]
    vis = cv2.hconcat([panel.astype(np.float32) for panel in panels])

    # Convert to 8-bit explicitly (round + saturate) rather than relying on
    # the encoder's implicit conversion of float images.
    vis = np.clip(np.rint(vis), 0, 255).astype(np.uint8)

    # cv2.imwrite expects BGR channel order, hence the channel reversal.
    # NOTE(review): this presumes the tensors are RGB -- confirm.
    cv2.imwrite(save_path, np.ascontiguousarray(vis[:, :, ::-1]))
|
|
|
|
|
# Constructing the loader does not start any worker process; only iterating
# over it does.
dataloader = DataLoader(dataset, num_workers=8, batch_size=4, shuffle=True)

# Worker processes started with the "spawn" method (Windows, macOS) re-import
# the main module, so the iteration must be guarded; otherwise every worker
# would try to run this loop again and multiprocessing raises a RuntimeError.
if __name__ == '__main__':
    print('len dataloader: ', len(dataloader))
    # Each batch overwrites the same preview file written by vis_sample.
    for data in dataloader:
        vis_sample(data)
|
|
|
|
|
|