Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
import os | |
import PIL | |
from torchvision import datasets, transforms | |
from timm.data import create_transform | |
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
def build_dataset(is_train, args):
    """Create an ``ImageFolder`` dataset for one split of ``args.data_path``.

    Args:
        is_train: Truthy selects the ``train`` sub-directory and the training
            transform; falsy selects ``val`` and the evaluation transform.
        args: Namespace-like config; ``data_path`` is read here, and the
            whole object is forwarded to ``build_transform``.

    Returns:
        The constructed ``datasets.ImageFolder`` (also printed as a summary).
    """
    # Build the transform first, before touching the filesystem path.
    tfm = build_transform(is_train, args)

    if is_train:
        split = "train"
    else:
        split = "val"
    split_dir = os.path.join(args.data_path, split)

    ds = datasets.ImageFolder(split_dir, transform=tfm)
    print(ds)
    return ds
def build_transform(is_train, args):
    """Build the image preprocessing pipeline for training or evaluation.

    Args:
        is_train: If truthy, return timm's training augmentation pipeline;
            otherwise return a deterministic resize / center-crop pipeline.
        args: Namespace-like config. The training branch reads
            ``input_size``, ``color_jitter``, ``aa``, ``reprob``, ``remode``
            and ``recount``; the evaluation branch reads only ``input_size``.

    Returns:
        A callable transform; both branches normalize with the ImageNet
        default mean/std.
    """
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        return create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation="bicubic",
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )

    # Eval transform. For inputs up to 224 px, resize to input_size / (224/256)
    # before center-cropping, to maintain the same ratio w.r.t. 224 images;
    # larger inputs use crop_pct = 1.0 (resize straight to input_size).
    crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
    resize_size = int(args.input_size / crop_pct)

    try:
        # Documented torchvision API for choosing the resampling mode.
        bicubic = transforms.InterpolationMode.BICUBIC
    except AttributeError:
        # Older torchvision only accepts PIL constants. Import PIL.Image
        # explicitly: a bare `import PIL` does not load the Image submodule.
        from PIL import Image

        bicubic = Image.BICUBIC

    return transforms.Compose(
        [
            transforms.Resize(resize_size, interpolation=bicubic),
            transforms.CenterCrop(args.input_size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )