File size: 2,151 Bytes
7754b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import json
from torch.utils import data
from torchvision.datasets import ImageFolder
import torch
import os
from PIL import Image
import numpy as np
import argparse
from tqdm import tqdm
from munkres import Munkres
import multiprocessing
from multiprocessing import Process, Manager
import collections
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import random
import torchvision
import cv2
from label_str_to_imagenet_classes import label_str_to_imagenet_classes

# Seed torch's global RNG with a fixed value at import time.
torch.manual_seed(0)

# Record describing one sample: the image's file name and the name of the
# sub-directory ("tag") it was found in.
ImageItem = collections.namedtuple('ImageItem', ['image_name', 'tag'])

# Per-channel normalisation: (x - 0.5) / 0.5 on each of the three channels.
normalize = transforms.Normalize(
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
)

# Preprocessing applied to every image before it is returned by the dataset:
# resize to 256, take the central 224x224 crop, convert to a tensor, then
# apply the normalisation above.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)

class RobustnessDataset(ImageFolder):
    """Image dataset over a directory laid out as ``<root>/<tag>/<image file>``.

    Every sub-directory of ``imagenet_path`` is treated as one "tag" and every
    entry inside it as one image. Indexing yields ``(image, class_index)``
    where ``image`` has been passed through the module-level ``transform``
    and ``class_index`` is derived from the tag according to the mode flags.

    Args:
        imagenet_path: Root directory containing one sub-directory per tag.
        imagenet_classes_path: JSON file mapping a tag to its class index;
            used when neither ``isV2`` nor ``isSI`` is set.
        isV2: If true, the tag (directory name) itself is parsed as the
            integer class index.
        isSI: If true, the tag is looked up in
            ``label_str_to_imagenet_classes``. Ignored when ``isV2`` is true.

    Note:
        ``ImageFolder.__init__`` is not called; this class keeps its own
        index of samples in ``self._all_images``.
    """

    def __init__(self, imagenet_path, imagenet_classes_path='imagenet_classes.json', isV2=False, isSI=False):
        self._isV2 = isV2
        self._isSI = isSI
        self._imagenet_path = imagenet_path
        with open(imagenet_classes_path, 'r') as f:
            self._imagenet_classes = json.load(f)

        # os.listdir returns entries in arbitrary order, so sort to make the
        # index -> sample mapping reproducible across machines and runs.
        # Stray non-directory entries in the root (e.g. .DS_Store) are skipped
        # instead of crashing the nested listing below.
        self._tag_list = [
            tag for tag in sorted(os.listdir(self._imagenet_path))
            if os.path.isdir(os.path.join(self._imagenet_path, tag))
        ]

        self._all_images = []
        for tag in self._tag_list:
            base_dir = os.path.join(self._imagenet_path, tag)
            self._all_images.extend(
                ImageItem(file, tag) for file in sorted(os.listdir(base_dir))
            )

    def __getitem__(self, item):
        """Return ``(transformed_image, class_index)`` for sample ``item``.

        Raises:
            IndexError: If ``item`` is outside the range of indexed images.
            KeyError: If the sample's tag is missing from the lookup table
                selected by the mode flags.
            ValueError: If the tag or looked-up value cannot be parsed as an
                integer.
        """
        image_item = self._all_images[item]
        image_path = os.path.join(self._imagenet_path, image_item.tag, image_item.image_name)

        # convert() loads the pixel data and returns an independent image, so
        # the underlying file can be closed as soon as the conversion is done.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image = transform(image)

        if self._isV2:
            # Directory names are already the numeric class indices.
            class_name = int(image_item.tag)
        elif self._isSI:
            class_name = int(label_str_to_imagenet_classes[image_item.tag])
        else:
            class_name = int(self._imagenet_classes[image_item.tag])

        return image, class_name

    def __len__(self):
        """Return the number of indexed images."""
        return len(self._all_images)