File size: 879 Bytes
7754b29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
from torchvision.datasets import ImageFolder
import torch
import os
import collections
# Seed PyTorch's global random number generator with a fixed value at import time.
torch.manual_seed(0)

# Immutable record describing one image: its file name and the tag it is filed under.
ImageItem = collections.namedtuple(
    'ImageItem',
    field_names=('image_name', 'tag'),
)
class RobustnessDataset(ImageFolder):
    """Index of image files laid out as ``dataset_path/<tag>/<image file>``.

    Indexing the dataset returns the *path* of an image file as a string;
    no image data is loaded or decoded here.

    NOTE(review): ``ImageFolder.__init__`` is not called by this class, so
    attributes that the base class would normally set up are not available
    on instances. Confirm that the base class is only needed for type
    compatibility with callers.
    """

    def __init__(self, dataset_path):
        """Scan ``dataset_path`` and record every file inside each tag directory.

        Args:
            dataset_path: Directory whose immediate sub-directories are the
                tags; every entry inside a tag directory is recorded as an
                image of that tag.

        Raises:
            OSError: If ``dataset_path`` or one of its tag directories
                cannot be listed.
        """
        self._dataset_path = dataset_path

        # os.listdir() returns entries in arbitrary order, so sort to keep the
        # index -> image mapping identical across runs and platforms.
        # Non-directory entries at the top level (e.g. ".DS_Store") are not
        # tags; skipping them avoids a NotADirectoryError when listing them.
        self._tag_list = sorted(
            entry
            for entry in os.listdir(dataset_path)
            if os.path.isdir(os.path.join(dataset_path, entry))
        )

        self._all_images = []
        for tag in self._tag_list:
            tag_dir = os.path.join(dataset_path, tag)
            self._all_images.extend(
                ImageItem(file_name, tag)
                for file_name in sorted(os.listdir(tag_dir))
            )

    def __getitem__(self, item):
        """Return the file path of the image stored at index ``item``.

        Args:
            item: Integer position in the scanned image list.

        Returns:
            ``dataset_path/<tag>/<image_name>`` joined with ``os.path.join``.

        Raises:
            IndexError: If ``item`` is outside the range of scanned images.
        """
        image_item = self._all_images[item]
        return os.path.join(
            self._dataset_path, image_item.tag, image_item.image_name
        )

    def __len__(self):
        """Return the number of image files found during the scan."""
        return len(self._all_images)