File size: 879 Bytes
7754b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from torchvision.datasets import ImageFolder
import torch
import os
import collections


# Seed PyTorch's global random number generator once, at import time.
torch.manual_seed(0)

# One indexed image: the file name and the tag (sub-directory) it was found in.
ImageItem = collections.namedtuple("ImageItem", "image_name tag")

class RobustnessDataset(ImageFolder):
    """Index every file under ``<dataset_path>/<tag>/`` and serve file paths.

    Each sub-directory of ``dataset_path`` is a tag; each regular file directly
    inside a tag directory becomes one item. Indexing returns the item's full
    path rather than a loaded image.

    NOTE(review): ``ImageFolder.__init__`` is never called, so the attributes it
    would normally set (e.g. ``classes``, ``samples``, ``transform``) do not
    exist on instances; ImageFolder is effectively used only as a Dataset base.
    """

    def __init__(self, dataset_path):
        """Scan *dataset_path* and record one ``ImageItem`` per file.

        Args:
            dataset_path: Root directory whose sub-directories are tags.

        Raises:
            OSError: If *dataset_path* or a tag directory cannot be listed.
        """
        self._dataset_path = dataset_path
        # Only directories are tags: a stray top-level file (README, .DS_Store)
        # would otherwise make os.listdir(base_dir) raise NotADirectoryError.
        # Sorted because os.listdir order is arbitrary; sorting keeps item
        # indices stable across runs and filesystems.
        self._tag_list = sorted(
            entry
            for entry in os.listdir(dataset_path)
            if os.path.isdir(os.path.join(dataset_path, entry))
        )
        self._all_images = []
        for tag in self._tag_list:
            base_dir = os.path.join(dataset_path, tag)
            # Skip nested directories: only files can be images. Sorted for
            # the same index-stability reason as the tag list.
            self._all_images.extend(
                ImageItem(name, tag)
                for name in sorted(os.listdir(base_dir))
                if os.path.isfile(os.path.join(base_dir, name))
            )

    def __getitem__(self, item):
        """Return the full path of the *item*-th indexed image.

        Raises:
            IndexError: If *item* is out of range.
        """
        image_item = self._all_images[item]
        return os.path.join(self._dataset_path, image_item.tag, image_item.image_name)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self._all_images)