Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
# from . import transforms as T | |
import torchvision.transforms as T | |
from PIL import Image | |
from timm.data import create_transform | |
from .torchvision_transforms.transforms import Resize as New_Resize | |
def build_clip_transforms(cfg, is_train=True): | |
if cfg.AUG.USE_TIMM and is_train: | |
print('=> use timm transform for training') | |
timm_cfg = cfg.AUG.TIMM_AUG | |
transforms = create_transform( | |
input_size=cfg.TRAIN.IMAGE_SIZE[0], | |
is_training=True, | |
use_prefetcher=False, | |
no_aug=False, | |
re_prob=timm_cfg.RE_PROB, | |
re_mode=timm_cfg.RE_MODE, | |
re_count=timm_cfg.RE_COUNT, | |
scale=cfg.AUG.SCALE, | |
ratio=cfg.AUG.RATIO, | |
hflip=timm_cfg.HFLIP, | |
vflip=timm_cfg.VFLIP, | |
color_jitter=timm_cfg.COLOR_JITTER, | |
auto_augment=timm_cfg.AUTO_AUGMENT, | |
interpolation=timm_cfg.INTERPOLATION, | |
mean=cfg.MODEL.PIXEL_MEAN, | |
std=cfg.MODEL.PIXEL_STD, | |
) | |
return transforms | |
# normalize_transform = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) | |
# assert isinstance(cfg.DATASET.OUTPUT_SIZE, (list, tuple)), 'DATASET.OUTPUT_SIZE should be list or tuple' | |
# NOTE: normalization is applied in rcnn.py, to keep consistent as Detectron2 | |
# normalize = T.Normalize(mean=cfg.MODEL.PIXEL_MEAN, std=cfg.MODEL.PIXEL_STD) # T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) | |
transforms = None | |
if is_train: | |
aug = cfg.AUG | |
scale = aug.SCALE | |
ratio = aug.RATIO | |
if len(cfg.AUG.TRAIN.IMAGE_SIZE) == 2: # Data Augmentation from MSR-CLIP | |
ts = [ | |
T.RandomResizedCrop( | |
cfg.AUG.TRAIN.IMAGE_SIZE[0], scale=scale, ratio=ratio, | |
interpolation=cfg.AUG.INTERPOLATION | |
), | |
T.RandomHorizontalFlip(), | |
] | |
elif len(cfg.AUG.TRAIN.IMAGE_SIZE) == 1 and cfg.AUG.TRAIN.MAX_SIZE is not None: # designed for pretraining fastrcnn | |
ts = [ | |
New_Resize( | |
cfg.AUG.TRAIN.IMAGE_SIZE[0], max_size=cfg.AUG.TRAIN.MAX_SIZE, | |
interpolation=cfg.AUG.INTERPOLATION | |
), | |
T.RandomHorizontalFlip(), | |
] | |
cj = aug.COLOR_JITTER | |
if cj[-1] > 0.0: | |
ts.append(T.RandomApply([T.ColorJitter(*cj[:-1])], p=cj[-1])) | |
gs = aug.GRAY_SCALE | |
if gs > 0.0: | |
ts.append(T.RandomGrayscale(gs)) | |
gb = aug.GAUSSIAN_BLUR | |
if gb > 0.0: | |
ts.append(T.RandomApply([GaussianBlur([.1, 2.])], p=gb)) | |
ts.append(T.ToTensor()) | |
# NOTE: normalization is applied in rcnn.py, to keep consistent as Detectron2 | |
#ts.append(normalize) | |
transforms = T.Compose(ts) | |
else: | |
# for zeroshot inference of grounding evaluation | |
transforms = T.Compose([ | |
T.Resize( | |
cfg.AUG.TEST.IMAGE_SIZE[0], | |
interpolation=cfg.AUG.TEST.INTERPOLATION | |
), | |
T.ToTensor(), | |
]) | |
return transforms | |
return transforms | |