Spaces:
Runtime error
Runtime error
# Copyright (c) 2015-present, Facebook, Inc. | |
# All rights reserved. | |
import os | |
import json | |
from torchvision import datasets, transforms | |
from torchvision.datasets.folder import ImageFolder, default_loader | |
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
from timm.data import create_transform | |
class INatDataset(ImageFolder):
    """iNaturalist dataset built from the JSON annotation files under ``root``.

    The constructor reads three files from ``root``:

    * ``train{year}.json`` or ``val{year}.json`` (chosen by ``train``), whose
      ``'images'`` entries provide the sample file names;
    * ``categories.json``, indexed by integer category id, whose entries are
      looked up with the ``category`` key;
    * ``train{year}.json``, whose ``'annotations'`` define the label space.

    Class indices are assigned in order of first appearance of each
    ``category`` value in the *train* annotations, so the train and val
    splits share the same label mapping.

    ``ImageFolder.__init__`` is not called; ``samples``, ``loader``,
    ``transform`` and ``target_transform`` are set directly on the instance.

    Args:
        root: Directory holding the JSON files and the image folders.
        train: Select the train split when true, the val split otherwise.
        year: Dataset year used in the JSON file names.
        transform: Stored as ``self.transform``.
        target_transform: Stored as ``self.target_transform``.
        category: Key of the category record used as the label.
        loader: Stored as ``self.loader``.

    Attributes:
        nb_classes: Number of distinct ``category`` values seen in the train
            annotations.
        samples: List of ``(image_path, class_index)`` tuples.
    """

    @staticmethod
    def _load_json(path):
        """Parse and return the JSON document stored at ``path``."""
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(path, encoding='utf-8') as json_file:
            return json.load(json_file)

    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
                 category='name', loader=default_loader):
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year
        # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']

        split = 'train' if train else 'val'
        data = self._load_json(os.path.join(root, f'{split}{year}.json'))
        data_catg = self._load_json(os.path.join(root, 'categories.json'))

        # The label space always comes from the train annotations. For the
        # train split that file has just been parsed, so reuse it rather than
        # reading and decoding the same JSON a second time.
        if train:
            data_for_targeter = data
        else:
            data_for_targeter = self._load_json(os.path.join(root, f'train{year}.json'))

        # Map each category value to a dense index, in first-seen order.
        targeter = {}
        for elem in data_for_targeter['annotations']:
            label = data_catg[int(elem['category_id'])][category]
            if label not in targeter:
                targeter[label] = len(targeter)
        self.nb_classes = len(targeter)

        self.samples = []
        for elem in data['images']:
            # file_name is split on '/'; component 2 is parsed as the integer
            # category id and component 1 is not used in the on-disk path.
            cut = elem['file_name'].split('/')
            target_current = int(cut[2])
            path_current = os.path.join(root, cut[0], cut[2], cut[3])
            target_current_true = targeter[data_catg[target_current][category]]
            self.samples.append((path_current, target_current_true))

    # __getitem__ and __len__ inherited from ImageFolder
def build_dataset(is_train, args):
    """Create the dataset selected by ``args.data_set``.

    Supported values of ``args.data_set``:

    * ``'CIFAR'``  -- ``datasets.CIFAR100`` rooted at ``args.data_path``
      (100 classes).
    * ``'IMNET'``  -- ``datasets.ImageFolder`` over the ``train`` or ``val``
      sub-directory of ``args.data_path`` (1000 classes).
    * ``'INAT'``   -- ``INatDataset`` for year 2018.
    * ``'INAT19'`` -- ``INatDataset`` for year 2019.

    For the two iNaturalist variants ``args.inat_category`` selects the label
    field and the class count is taken from the dataset itself.

    Args:
        is_train: Build the training split when true, the validation split
            otherwise. Also forwarded to ``build_transform``.
        args: Namespace-like object providing ``data_set``, ``data_path``,
            ``inat_category`` (iNaturalist only) and whatever
            ``build_transform`` reads.

    Returns:
        A ``(dataset, nb_classes)`` tuple.

    Raises:
        ValueError: If ``args.data_set`` is not one of the supported names.
    """
    inat_years = {'INAT': 2018, 'INAT19': 2019}
    supported = ('CIFAR', 'IMNET', *inat_years)
    # Reject unknown names up front; otherwise no branch would bind `dataset`
    # and the function would fail at `return` with an UnboundLocalError.
    if args.data_set not in supported:
        raise ValueError(
            f"Unsupported data_set {args.data_set!r}; expected one of {supported}"
        )

    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    else:
        # Remaining supported names are the iNaturalist variants.
        dataset = INatDataset(args.data_path, train=is_train,
                              year=inat_years[args.data_set],
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes

    return dataset, nb_classes
def build_transform(is_train, args):
    """Assemble the image preprocessing pipeline for one split.

    Evaluation: when ``args.input_size`` exceeds 32, resize to
    ``int(args.input_size / args.eval_crop_ratio)`` and center-crop to
    ``args.input_size``; then convert to a tensor and normalize with the
    ImageNet default mean and std.

    Training: delegate to ``create_transform`` using the augmentation options
    on ``args`` (``color_jitter``, ``aa``, ``train_interpolation``,
    ``reprob``, ``remode``, ``recount``). When ``args.input_size`` is 32 or
    less, the first transform of the result is replaced by a padded
    ``RandomCrop``.

    Args:
        is_train: Select the training pipeline when true.
        args: Namespace-like object carrying the options listed above.

    Returns:
        The composed transform.
    """
    needs_resize = args.input_size > 32

    if not is_train:
        steps = []
        # NOTE(review): the original indentation was lost; Resize and
        # CenterCrop are assumed to belong to the same branch -- confirm.
        if needs_resize:
            resized_edge = int(args.input_size / args.eval_crop_ratio)
            steps += [
                # to maintain same ratio w.r.t. 224 images
                transforms.Resize(resized_edge, interpolation=3),
                transforms.CenterCrop(args.input_size),
            ]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ]
        return transforms.Compose(steps)

    # this should always dispatch to transforms_imagenet_train
    train_pipeline = create_transform(
        input_size=args.input_size,
        is_training=True,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.train_interpolation,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
    )
    if needs_resize:
        return train_pipeline

    # replace RandomResizedCropAndInterpolation with RandomCrop
    train_pipeline.transforms[0] = transforms.RandomCrop(args.input_size, padding=4)
    return train_pipeline