|
import json |
|
from torchvision.datasets import ImageFolder |
|
import torch |
|
import os |
|
from PIL import Image |
|
import collections |
|
import torchvision.transforms as transforms |
|
from label_str_to_imagenet_classes import label_str_to_imagenet_classes |
|
|
|
# Seed torch's global RNG as a side effect of importing this module.
# NOTE(review): none of the transforms defined below are random, so presumably
# this seed is meant for other code in the process (e.g. model init) — confirm.
torch.manual_seed(0)

# Lightweight record for one dataset entry: the file name and the sub-folder
# (tag) it lives in under the dataset root.
ImageItem = collections.namedtuple('ImageItem', 'image_name tag')

# Per-channel (x - 0.5) / 0.5: maps ToTensor's [0, 1] range onto [-1, 1]
# for all three RGB channels.
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Evaluation-time preprocessing: shorter side resized to 256, center 224x224
# crop, conversion to a CHW float tensor, then the normalization above.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)
|
|
|
class RobustnessDataset(ImageFolder):
    """Dataset of every file in ``<imagenet_path>/<folder>``, all sharing one label.

    The integer label is derived from ``folder`` (see ``_label_for``) and is the
    same for every sample.

    NOTE(review): ``ImageFolder.__init__`` is deliberately not called, so the
    attributes it normally sets (``samples``, ``classes``, ``transform``, ...)
    do not exist on instances; only ``__getitem__``, ``__len__`` and
    ``get_classname`` defined here are meant to be used.
    """

    def __init__(self, imagenet_path, folder, imagenet_classes_path='imagenet_classes.json', isV2=False, isSI=False):
        """Index the files of one folder.

        Args:
            imagenet_path: root directory that contains ``folder``.
            folder: sub-directory whose regular files make up the dataset; also
                the key from which the label is derived.
            imagenet_classes_path: JSON file mapping folder names to class
                indices. Always loaded, but only consulted when neither
                ``isV2`` nor ``isSI`` is set.
            isV2: if True, the folder name itself is parsed as the integer label.
                Takes precedence over ``isSI``.
            isSI: if True, the label is looked up in
                ``label_str_to_imagenet_classes``.
        """
        self._isV2 = isV2
        self._isSI = isSI
        self._folder = folder
        self._imagenet_path = imagenet_path
        with open(imagenet_classes_path, 'r') as f:
            self._imagenet_classes = json.load(f)

        base_dir = os.path.join(self._imagenet_path, folder)
        # Sort because os.listdir order is arbitrary; this keeps the
        # index -> file mapping reproducible. Sub-directories are skipped since
        # they cannot be opened as images in __getitem__.
        self._all_images = [
            ImageItem(name, folder)
            for name in sorted(os.listdir(base_dir))
            if os.path.isfile(os.path.join(base_dir, name))
        ]

    def _label_for(self, tag):
        """Map a folder name to its integer class index according to the dataset flags."""
        if self._isV2:
            return int(tag)
        if self._isSI:
            return int(label_str_to_imagenet_classes[tag])
        return int(self._imagenet_classes[tag])

    def __getitem__(self, item):
        """Return ``(image, label)`` for the ``item``-th file.

        ``image`` is the RGB image passed through the module-level ``transform``;
        ``label`` is the int from ``_label_for``.
        """
        image_item = self._all_images[item]
        image_path = os.path.join(self._imagenet_path, image_item.tag, image_item.image_name)
        # The context manager releases the file handle; convert() returns a new,
        # fully loaded image, so it remains valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image = transform(image)
        return image, self._label_for(image_item.tag)

    def __len__(self):
        """Number of indexed files."""
        return len(self._all_images)

    def get_classname(self):
        """Return the integer label shared by every sample of this dataset."""
        return self._label_for(self._folder)