""" Quick n Simple Image Folder, Tarfile based DataSet

Hacked together by / Copyright 2019, Ross Wightman
"""
import torch.utils.data as data | |
import os | |
import torch | |
import logging | |
from PIL import Image | |
from .parsers import create_parser | |
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
# Upper bound on consecutive failed sample loads; ImageDataset.__getitem__
# compares its running error counter against this before re-raising.
_ERROR_RETRY = 50
class ImageDataset(data.Dataset):
    """Map-style dataset producing ``(image, target)`` pairs from a parser.

    The parser is indexed to obtain a ``(source, target)`` pair. The source is
    either read as raw bytes (``load_bytes=True``) or opened with PIL and
    converted to RGB. A sample that fails to load is skipped in favour of the
    next index, up to ``_ERROR_RETRY`` consecutive failures.
    """

    def __init__(
        self,
        root,
        parser=None,
        class_map=None,
        load_bytes=False,
        transform=None,
        target_transform=None,
    ):
        """Store the configuration, building a parser when none is supplied.

        Args:
            root: Passed to ``create_parser`` when a parser must be built.
            parser: Parser instance, parser name (str), or None. A str or None
                is resolved through ``create_parser`` (None becomes ``''``).
            class_map: Passed to ``create_parser`` when a parser must be built.
            load_bytes: If true, return the result of ``.read()`` on the source
                instead of a decoded RGB image.
            transform: Optional callable applied to the loaded sample.
            target_transform: Optional callable applied to non-None targets.
        """
        build_from_name = parser is None or isinstance(parser, str)
        if build_from_name:
            parser = create_parser(parser or '', root=root, class_map=class_map)
        self.parser = parser
        self.load_bytes = load_bytes
        self.transform = transform
        self.target_transform = target_transform
        # Number of load failures seen in a row; reset on every success.
        self._consecutive_errors = 0

    def __getitem__(self, index):
        """Return the ``(sample, target)`` pair for ``index``.

        On a load failure a warning is logged and the next index (wrapping
        around the parser length) is tried instead. Once the consecutive
        failure count reaches ``_ERROR_RETRY`` the exception is re-raised.
        A target of None is returned as ``-1`` without applying
        ``target_transform``.
        """
        source, label = self.parser[index]
        try:
            if self.load_bytes:
                sample = source.read()
            else:
                sample = Image.open(source).convert('RGB')
        except Exception as e:
            _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
            self._consecutive_errors += 1
            if self._consecutive_errors >= _ERROR_RETRY:
                raise e
            # Fall through to the following sample, wrapping at the end.
            return self.__getitem__((index + 1) % len(self.parser))

        self._consecutive_errors = 0
        if self.transform is not None:
            sample = self.transform(sample)

        if label is None:
            return sample, -1
        if self.target_transform is not None:
            label = self.target_transform(label)
        return sample, label

    def __len__(self):
        """Return the number of samples reported by the parser."""
        return len(self.parser)

    def filename(self, index, basename=False, absolute=False):
        """Delegate a single filename lookup to the parser."""
        return self.parser.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        """Delegate the full filename listing to the parser."""
        return self.parser.filenames(basename, absolute)
class IterableImageDataset(data.IterableDataset):
    """Iterable-style dataset that streams ``(image, target)`` pairs from a parser.

    The parser itself is iterated; each yielded pair is passed through the
    optional ``transform`` / ``target_transform`` callables.
    """

    def __init__(
        self,
        root,
        parser=None,
        split='train',
        is_training=False,
        batch_size=None,
        repeats=0,
        download=False,
        transform=None,
        target_transform=None,
    ):
        """Store the configuration, building a parser from its name if needed.

        Args:
            root: Passed to ``create_parser`` when ``parser`` is a str.
            parser: Parser name (str) or parser instance. Must not be None.
            split, is_training, batch_size, repeats, download: Forwarded to
                ``create_parser`` when ``parser`` is a str; otherwise unused.
            transform: Optional callable applied to each image.
            target_transform: Optional callable applied to each target.

        Raises:
            AssertionError: If ``parser`` is None.
        """
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type matches what the assert used to produce.
        if parser is None:
            raise AssertionError('parser must be a parser name or parser instance, not None')
        if isinstance(parser, str):
            self.parser = create_parser(
                parser, root=root, split=split, is_training=is_training,
                batch_size=batch_size, repeats=repeats, download=download)
        else:
            self.parser = parser
        self.transform = transform
        self.target_transform = target_transform
        # Not read anywhere in this class; kept so the attribute stays available.
        self._consecutive_errors = 0

    def __iter__(self):
        """Yield ``(img, target)`` pairs from the parser with transforms applied."""
        for img, target in self.parser:
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                target = self.target_transform(target)
            yield img, target

    def __len__(self):
        """Return the parser's length, or 0 when the parser defines no ``__len__``."""
        if hasattr(self.parser, '__len__'):
            return len(self.parser)
        return 0

    def filename(self, index, basename=False, absolute=False):
        """Always fail: index-based filename lookup is not supported here.

        Raises:
            AssertionError: Unconditionally. Raised explicitly because
                ``assert False`` is removed under ``python -O``, which would
                make this method silently return None.
        """
        raise AssertionError('Filename lookup by index not supported, use filenames().')

    def filenames(self, basename=False, absolute=False):
        """Delegate the filename listing to the parser."""
        return self.parser.filenames(basename, absolute)
class AugMixDataset(torch.utils.data.Dataset):
    """Dataset wrapper to perform AugMix or other clean/augmentation mixes.

    Wraps another dataset and returns, per sample, a tuple of ``num_splits``
    views: the first is the base-transformed sample with only normalization
    applied (the 'clean' split), the rest additionally pass through the
    augmentation transform before normalization.
    """

    def __init__(self, dataset, num_splits=2):
        """Wrap ``dataset``, splitting its transform if one is already set.

        Args:
            dataset: Wrapped dataset; must expose a ``transform`` attribute and
                return ``(x, y)`` pairs when indexed.
            num_splits: Number of views returned per sample.
        """
        self.augmentation = None
        self.normalize = None
        self.dataset = dataset
        if self.dataset.transform is not None:
            self._set_transforms(self.dataset.transform)
        self.num_splits = num_splits

    def _set_transforms(self, x):
        """Split a 3-element transform sequence into base / augmentation / normalize."""
        assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
        self.dataset.transform = x[0]
        self.augmentation = x[1]
        self.normalize = x[2]

    # `transform` is a property/setter pair. As two plain methods the second
    # definition replaced the first, and `ds.transform = (...)` merely shadowed
    # the method without ever calling `_set_transforms`.
    @property
    def transform(self):
        """Base transform currently installed on the wrapped dataset."""
        return self.dataset.transform

    @transform.setter
    def transform(self, x):
        self._set_transforms(x)

    def _normalize(self, x):
        """Apply the normalize transform, or return ``x`` unchanged if unset."""
        return x if self.normalize is None else self.normalize(x)

    def __getitem__(self, i):
        """Return ``(views, target)`` where ``views`` is a tuple of ``num_splits`` items.

        Requires ``self.augmentation`` to be set when ``num_splits > 1``.
        """
        x, y = self.dataset[i]  # all splits share the same dataset base transform
        x_list = [self._normalize(x)]  # first split only normalizes (this is the 'clean' split)
        # run the full augmentation on the remaining splits
        for _ in range(self.num_splits - 1):
            x_list.append(self._normalize(self.augmentation(x)))
        return tuple(x_list), y

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)