Spaces: Running on Zero
import torch | |
import torch.distributed as dist | |
from utils.distributed import get_rank, is_dist_avail_and_initialized, is_main_process | |
import random | |
import logging | |
logger = logging.getLogger(__name__) | |
class MetaLoader(object):
    """Wraps multiple dataloaders and interleaves their batches in one shared random order."""

    def __init__(self, name2loader):
        """Build a shuffled schedule that draws every batch of every dataloader once.

        The schedule holds ``len(loader)`` entries per wrapped loader, so a full
        pass over this object yields ``len(self)`` batches in total. It ends once
        every wrapped loader has been consumed once.

        When torch.distributed is initialized, rank 0's schedule is broadcast so
        that all processes draw from the same dataloader at every step.

        Args:
            name2loader: Dict, {name: dataloader}. Each loader must support
                ``len()`` and ``iter()`` and expose a ``batch_size`` attribute.
        """
        self.name2loader = name2loader
        self.name2iter = {name: iter(l) for name, l in name2loader.items()}
        name2index = {name: idx for idx, name in enumerate(name2loader)}
        index2name = {idx: name for name, idx in name2index.items()}

        # One schedule entry per batch of each loader, then shuffled locally.
        iter_order = []
        for n, l in name2loader.items():
            iter_order.extend([name2index[n]] * len(l))
        random.shuffle(iter_order)

        # Use CUDA when present (needed for NCCL broadcast) but fall back to CPU
        # so CPU-only / gloo runs work.
        # Use int64 rather than uint8 so more than 256 loaders cannot wrap around.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        iter_order = torch.tensor(iter_order, dtype=torch.long, device=device)

        # sync
        if is_dist_avail_and_initialized():
            # make sure all processes have the same order so that
            # each step they will have data from the same loader
            dist.broadcast(iter_order, src=0)

        self.iter_order = [index2name[i] for i in iter_order.cpu().tolist()]
        logger.info(str(self))

    def __str__(self):
        """Return a multi-line summary of the wrapped dataloaders."""
        output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
        for idx, (name, loader) in enumerate(self.name2loader.items()):
            output.append(
                f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} "
            )
        return "\n".join(output)

    def __len__(self):
        """Return the number of steps in the schedule (total batches)."""
        return len(self.iter_order)

    def __iter__(self):
        """Yield ``(name, batch)`` pairs following the shared schedule.

        Stops after ``len(self)`` steps. The per-loader iterators are created
        once in ``__init__``, so a second full pass will exhaust them.
        """
        for name in self.iter_order:
            _iter = self.name2iter[name]
            # If a loader yields fewer batches than its len(), the StopIteration
            # raised here surfaces as RuntimeError (PEP 479).
            batch = next(_iter)
            yield name, batch
class MetaLoader_rs(object):
    """Wraps multiple dataloaders in one shared random order and can skip leading steps (resume)."""

    def __init__(self, name2loader, skip_num=0):
        """Build a shared shuffled schedule and optionally drop its first ``skip_num`` steps.

        The schedule holds ``len(loader)`` entries per wrapped loader. When
        torch.distributed is initialized, rank 0's schedule is broadcast so all
        processes draw from the same dataloader at every step.

        Args:
            name2loader: Dict, {name: dataloader}. Each loader must support
                ``len()`` and ``iter()``, expose ``batch_size``, and have a
                ``sampler`` with a ``set_start_iter(int)`` method.
            skip_num: number of leading schedule steps to drop. For each loader,
                the count of its entries among the skipped steps is passed to
                ``sampler.set_start_iter``. When ``skip_num <= 0`` every sampler
                gets ``set_start_iter(0)``.
        """
        self.name2loader = name2loader
        name2index = {name: idx for idx, name in enumerate(name2loader)}
        index2name = {idx: name for name, idx in name2index.items()}

        # One schedule entry per batch of each loader, then shuffled locally.
        iter_order = []
        for n, l in name2loader.items():
            iter_order.extend([name2index[n]] * len(l))
        random.shuffle(iter_order)

        # Use CUDA when present (needed for NCCL broadcast) but fall back to CPU
        # so CPU-only / gloo runs work.
        # Use int64 rather than uint8 so more than 256 loaders cannot wrap around.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        iter_order = torch.tensor(iter_order, dtype=torch.long, device=device)

        # sync
        if is_dist_avail_and_initialized():
            # make sure all processes have the same order so that
            # each step they will have data from the same loader
            dist.broadcast(iter_order, src=0)

        if skip_num > 0:
            iter_order_skip = iter_order[:skip_num]
            for k, v in index2name.items():
                # Number of this loader's batches that fall inside the skipped prefix.
                media_step = (iter_order_skip == k).sum().item()
                name2loader[v].sampler.set_start_iter(media_step)
                logger.info(f"{v} dataloder skip steps: {media_step}")
            iter_order = iter_order[skip_num:]
        else:
            logger.info("Do not skip steps for any dataloader!")
            for loader in name2loader.values():
                loader.sampler.set_start_iter(0)

        # Iterators are created only after set_start_iter(), presumably so each
        # sampler picks up its start offset — confirm against the sampler class.
        self.name2iter = {name: iter(l) for name, l in name2loader.items()}
        self.iter_idx = iter_order  # remaining schedule as loader indices (tensor)
        self.iter_order = [index2name[i] for i in iter_order.cpu().tolist()]
        logger.info(str(self))

    def __str__(self):
        """Return a multi-line summary, counting the remaining (post-skip) batches per loader."""
        output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
        for idx, (name, loader) in enumerate(self.name2loader.items()):
            # int() so the summary shows a number rather than a tensor repr.
            length = int((self.iter_idx == idx).sum())
            output.append(
                f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={length} "
            )
        return "\n".join(output)

    def __len__(self):
        """Return the number of remaining steps in the schedule."""
        return len(self.iter_order)

    def __iter__(self):
        """Yield ``(name, batch)`` pairs following the remaining shared schedule.

        Stops after ``len(self)`` steps. The per-loader iterators are created
        once in ``__init__``, so a second full pass will exhaust them.
        """
        for name in self.iter_order:
            _iter = self.name2iter[name]
            # If a loader yields fewer batches than scheduled, the StopIteration
            # raised here surfaces as RuntimeError (PEP 479).
            batch = next(_iter)
            yield name, batch