|
|
|
|
|
from .utils.transforms import * |
|
from .base.batched_sampler import BatchedRandomSampler |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from .my_spring import SpringDatasets |
|
from .my_sceneflow import SceneFlowDatasets |
|
from .my_vkitti2 import VkittiDatasets |
|
from .my_PointOdyssey import PointodysseyDatasets |
|
from .my_Tartanair import TartanairDatasets |
|
from .my_sintel import SintelDatasets |
|
def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    The sampler is obtained from ``dataset.make_sampler(...)`` when the
    dataset offers it. If that call raises ``AttributeError`` or
    ``NotImplementedError``, a stock PyTorch sampler is used instead:
    ``DistributedSampler`` when ``torch.distributed`` is initialized,
    otherwise ``RandomSampler`` (``shuffle`` truthy) or ``SequentialSampler``.

    Args:
        dataset: A dataset object, or a string that is evaluated with
            ``eval`` to produce one.
        batch_size: Forwarded to ``make_sampler`` and to the ``DataLoader``.
        num_workers: Forwarded to the ``DataLoader``.
        shuffle: Forwarded to ``make_sampler`` / ``DistributedSampler``; also
            selects random vs. sequential sampling in the non-distributed
            fallback.
        drop_last: Forwarded to the sampler (where accepted) and to the
            ``DataLoader``.
        pin_mem: Forwarded to the ``DataLoader`` as ``pin_memory``.

    Returns:
        The constructed ``torch.utils.data.DataLoader``.
    """
    import torch
    from croco.utils.misc import get_world_size, get_rank

    # NOTE(review): a string spec is executed with eval(), so it can run
    # arbitrary code and resolves names against this module's globals and this
    # function's locals. Only pass strings from trusted configuration.
    if isinstance(dataset, str):
        dataset = eval(dataset)

    num_replicas, replica_rank = get_world_size(), get_rank()
    data_utils = torch.utils.data

    try:
        # Prefer the dataset's own sampler when it provides one.
        chosen_sampler = dataset.make_sampler(
            batch_size,
            shuffle=shuffle,
            world_size=num_replicas,
            rank=replica_rank,
            drop_last=drop_last,
        )
    except (AttributeError, NotImplementedError):
        # The dataset has no usable make_sampler: fall back to PyTorch's
        # built-in samplers.
        if torch.distributed.is_initialized():
            chosen_sampler = data_utils.DistributedSampler(
                dataset,
                num_replicas=num_replicas,
                rank=replica_rank,
                shuffle=shuffle,
                drop_last=drop_last,
            )
        else:
            sampler_cls = data_utils.RandomSampler if shuffle else data_utils.SequentialSampler
            chosen_sampler = sampler_cls(dataset)

    return data_utils.DataLoader(
        dataset,
        sampler=chosen_sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )
|
|