|
import json |
|
from torch.utils import data |
|
from torchvision.datasets import ImageFolder |
|
import torch |
|
import os |
|
from PIL import Image |
|
import numpy as np |
|
import argparse |
|
from tqdm import tqdm |
|
from munkres import Munkres |
|
import multiprocessing |
|
from multiprocessing import Process, Manager |
|
import collections |
|
import torchvision.transforms as transforms |
|
import torchvision.transforms.functional as TF |
|
import random |
|
import torchvision |
|
import cv2 |
|
from label_str_to_imagenet_classes import label_str_to_imagenet_classes |
|
|
|
# Seed torch's global random number generator at import time.
torch.manual_seed(0)

# Lightweight record for one image on disk: its file name and the name of the
# directory ("tag") that contains it.
ImageItem = collections.namedtuple('ImageItem', ['image_name', 'tag'])
|
# Per-channel normalization using mean 0.5 and std 0.5 for each of the
# three channels.
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Preprocessing applied to every image, in order: resize (256), center crop
# (224), conversion to a tensor, then the normalization defined above.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)
|
|
|
class RobustnessDataset(ImageFolder):
    """Image dataset laid out on disk as ``<imagenet_path>/<tag>/<image file>``.

    Every sub-directory of ``imagenet_path`` is a "tag"; every entry inside a
    tag directory is treated as one image of that tag. Indexing yields
    ``(image, class_name)`` where ``image`` is the result of the module-level
    ``transform`` and ``class_name`` is an ``int`` derived from the tag:

    * ``isV2=True``: the tag itself is parsed as the integer class.
    * ``isSI=True`` (and ``isV2`` false): the tag is looked up in
      ``label_str_to_imagenet_classes``.
    * otherwise: the tag is looked up in the JSON mapping loaded from
      ``imagenet_classes_path``.

    NOTE(review): ``ImageFolder.__init__`` is not called, so attributes that
    the base class would normally set up are not populated here; only
    ``__getitem__`` and ``__len__`` below are meant to be used -- confirm no
    caller relies on base-class attributes.
    """

    def __init__(self, imagenet_path, imagenet_classes_path='imagenet_classes.json', isV2=False, isSI=False):
        """Index all images found under ``imagenet_path``.

        Args:
            imagenet_path: Root directory containing one sub-directory per tag.
            imagenet_classes_path: Path to a JSON file mapping tag -> class;
                it is always loaded, but only consulted when both ``isV2``
                and ``isSI`` are false.
            isV2: If true, tag directory names are themselves integer classes.
            isSI: If true (and ``isV2`` is false), tags are translated with
                ``label_str_to_imagenet_classes``.

        Raises:
            OSError: If ``imagenet_classes_path`` or ``imagenet_path`` cannot
                be read.
        """
        self._isV2 = isV2
        self._isSI = isSI
        self._imagenet_path = imagenet_path
        with open(imagenet_classes_path, 'r') as f:
            self._imagenet_classes = json.load(f)

        # os.listdir returns entries in arbitrary order; sort so that the
        # index -> sample mapping is reproducible across runs and machines.
        # Non-directory entries in the root (e.g. stray metadata files) are
        # skipped instead of crashing the per-tag listing below.
        self._tag_list = sorted(
            tag
            for tag in os.listdir(self._imagenet_path)
            if os.path.isdir(os.path.join(self._imagenet_path, tag))
        )

        self._all_images = [
            ImageItem(file_name, tag)
            for tag in self._tag_list
            for file_name in sorted(os.listdir(os.path.join(self._imagenet_path, tag)))
        ]

    def __getitem__(self, item):
        """Return ``(image, class_name)`` for the sample at index ``item``.

        Raises:
            IndexError: If ``item`` is out of range.
            KeyError: If the sample's tag is missing from the class mapping
                in use.
            ValueError: If the tag (``isV2``) or mapped value cannot be
                converted to ``int``.
        """
        image_item = self._all_images[item]
        image_path = os.path.join(self._imagenet_path, image_item.tag, image_item.image_name)

        # The context manager guarantees the underlying file handle is
        # released; convert() produces an independent RGB image first.
        with Image.open(image_path) as raw_image:
            image = transform(raw_image.convert('RGB'))

        if self._isV2:
            class_name = int(image_item.tag)
        elif self._isSI:
            class_name = int(label_str_to_imagenet_classes[image_item.tag])
        else:
            class_name = int(self._imagenet_classes[image_item.tag])

        return image, class_name

    def __len__(self):
        """Return the number of indexed images."""
        return len(self._all_images)