File size: 2,105 Bytes
7754b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import json
from torchvision.datasets import ImageFolder
import torch
import os
from PIL import Image
import collections
import torchvision.transforms as transforms
from label_str_to_imagenet_classes import label_str_to_imagenet_classes

# Seed torch's global RNG once at import time.
torch.manual_seed(0)

# Record describing one image on disk: its file name and the folder ("tag")
# it was found in.
ImageItem = collections.namedtuple('ImageItem', ['image_name', 'tag'])

# Per-channel normalization with mean 0.5 and std 0.5 on all three channels.
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Preprocessing pipeline applied to every image: resize to 256, center-crop
# to 224, convert to a tensor, then normalize.
_preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]
transform = transforms.Compose(_preprocessing_steps)

class RobustnessDataset(ImageFolder):
    """Dataset over the files of a single class folder.

    Every entry of ``<imagenet_path>/<folder>`` becomes one sample. All samples
    share the same label, which is derived from the folder name:

    * ``isV2=True``  -> the folder name itself is parsed as an integer;
    * ``isSI=True``  -> the folder name is looked up in
      ``label_str_to_imagenet_classes``;
    * otherwise      -> the folder name is looked up in the JSON mapping loaded
      from ``imagenet_classes_path``.

    ``isV2`` takes precedence over ``isSI`` when both are set.

    NOTE(review): ``ImageFolder.__init__`` is never invoked, so attributes the
    parent class normally sets up are not available on instances -- confirm
    this is intended. The flags presumably select between dataset variants
    with different folder-naming schemes; verify against the callers.
    """

    def __init__(self, imagenet_path, folder, imagenet_classes_path='imagenet_classes.json', isV2=False, isSI=False):
        """Index the files found in ``<imagenet_path>/<folder>``.

        Args:
            imagenet_path: Root directory that contains ``folder``.
            folder: Name of the sub-directory to index; also used as the tag
                from which the label is derived.
            imagenet_classes_path: Path to a JSON file mapping folder names to
                class indices. It is always loaded, even when ``isV2`` or
                ``isSI`` is set.
            isV2: If true, the folder name is itself the integer label.
            isSI: If true (and ``isV2`` is false), the label comes from
                ``label_str_to_imagenet_classes``.

        Raises:
            OSError: If the JSON file or the folder cannot be read.
        """
        self._isV2 = isV2
        self._isSI = isSI
        self._folder = folder
        self._imagenet_path = imagenet_path
        with open(imagenet_classes_path, 'r') as f:
            self._imagenet_classes = json.load(f)

        base_dir = os.path.join(self._imagenet_path, folder)
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index always refers to the same file across runs and machines.
        self._all_images = [
            ImageItem(file_name, folder)
            for file_name in sorted(os.listdir(base_dir))
        ]

    def _label_for_tag(self, tag):
        """Translate a folder name (tag) into its integer class label."""
        if self._isV2:
            return int(tag)
        if self._isSI:
            return int(label_str_to_imagenet_classes[tag])
        return int(self._imagenet_classes[tag])

    def __getitem__(self, item):
        """Return ``(image, label)`` for the sample at index ``item``.

        The image is loaded from disk, converted to RGB and passed through the
        module-level ``transform``; the label is an ``int``.
        """
        image_item = self._all_images[item]
        image_path = os.path.join(self._imagenet_path, image_item.tag, image_item.image_name)
        # Use a context manager so the underlying file handle is released
        # deterministically; convert() returns an independent, loaded image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image = transform(image)

        return image, self._label_for_tag(image_item.tag)

    def __len__(self):
        """Number of indexed files in the folder."""
        return len(self._all_images)

    def get_classname(self):
        """Return the integer class label shared by every sample of this dataset."""
        return self._label_for_tag(self._folder)