import os

from torchvision import transforms

from .transforms import *
from .masking_generator import TubeMaskingGenerator, RandomMaskingGenerator
from .mae import VideoMAE
from .mae_multi import VideoMAE_multi
from .kinetics import VideoClsDataset
from .kinetics_sparse import VideoClsDataset_sparse
from .anet import ANetDataset
from .ssv2 import SSVideoClsDataset, SSRawFrameClsDataset
from .hmdb import HMDBVideoClsDataset, HMDBRawFrameClsDataset


class DataAugmentationForVideoMAE(object):
    """Spatial augmentation, normalization and mask-position generation for pretraining clips."""

    def __init__(self, args):
        self.input_mean = [0.485, 0.456, 0.406]  # IMAGENET_DEFAULT_MEAN
        self.input_std = [0.229, 0.224, 0.225]  # IMAGENET_DEFAULT_STD
        normalize = GroupNormalize(self.input_mean, self.input_std)
        self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66])
        if args.color_jitter > 0:
            self.transform = transforms.Compose([
                self.train_augmentation,
                GroupColorJitter(args.color_jitter),
                GroupRandomHorizontalFlip(flip=args.flip),
                Stack(roll=False),
                ToTorchFormatTensor(div=True),
                normalize,
            ])
        else:
            self.transform = transforms.Compose([
                self.train_augmentation,
                GroupRandomHorizontalFlip(flip=args.flip),
                Stack(roll=False),
                ToTorchFormatTensor(div=True),
                normalize,
            ])
        if args.mask_type == 'tube':
            self.masked_position_generator = TubeMaskingGenerator(
                args.window_size, args.mask_ratio
            )
        elif args.mask_type == 'random':
            self.masked_position_generator = RandomMaskingGenerator(
                args.window_size, args.mask_ratio
            )
        elif args.mask_type == 'attention':
            # No generator for attention-based masking: __call__ then returns
            # -1 as a mask placeholder, and the mask is expected to be produced
            # downstream (e.g. from attention maps).
            self.masked_position_generator = None

    def __call__(self, images):
        process_data, _ = self.transform(images)
        if self.masked_position_generator is None:
            return process_data, -1
        else:
            return process_data, self.masked_position_generator()

    def __repr__(self):
        repr_str = "(DataAugmentationForVideoMAE,\n"
        repr_str += "  transform = %s,\n" % str(self.transform)
        repr_str += "  Masked position generator = %s,\n" % str(self.masked_position_generator)
        repr_str += ")"
        return repr_str
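

# Example (sketch, illustrative only): constructing the augmentation by hand.
# The Namespace below lists just the attributes read in __init__; the concrete
# values (224 crop, tube masking, 0.9 ratio, 8x14x14 token grid) are assumed
# settings, not defaults defined in this file.
#
#   from argparse import Namespace
#   aug_args = Namespace(
#       input_size=224, color_jitter=0.4, flip=True,
#       mask_type='tube', window_size=(8, 14, 14), mask_ratio=0.9,
#   )
#   aug = DataAugmentationForVideoMAE(aug_args)
#   print(aug)  # shows the composed transform and the mask generator
#   # calling aug(clip) returns (stacked frame tensor, mask array or -1)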


def build_pretraining_dataset(args):
    """Build the single-source VideoMAE pretraining dataset listed in args.data_path."""
    transform = DataAugmentationForVideoMAE(args)
    dataset = VideoMAE(
        root=None,
        setting=args.data_path,
        prefix=args.prefix,
        split=args.split,
        video_ext='mp4',
        is_color=True,
        modality='rgb',
        num_segments=args.num_segments,
        new_length=args.num_frames,
        new_step=args.sampling_rate,
        transform=transform,
        temporal_jitter=False,
        video_loader=True,
        use_decord=args.use_decord,
        lazy_init=False,
        num_sample=args.num_sample)
    print("Data Aug = %s" % str(transform))
    return dataset
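

# Example (sketch): plugging the pretraining dataset into a DataLoader. `args`
# is assumed to be the argparse namespace of the pretraining entry script, and
# the batch size / worker count below are illustrative values only.
#
#   import torch
#   pretrain_set = build_pretraining_dataset(args)
#   loader = torch.utils.data.DataLoader(
#       pretrain_set,
#       batch_size=args.batch_size,
#       num_workers=args.num_workers,
#       pin_memory=True,
#       drop_last=True,
#   )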


def build_multi_pretraining_dataset(args):
    """Build the multi-source pretraining dataset with a separate, flip-free transform for SSV2."""
    original_flip = args.flip
    transform = DataAugmentationForVideoMAE(args)
    # SSV2 labels are direction-sensitive (e.g. "pushing something from left to
    # right"), so the SSV2 transform is built with horizontal flipping disabled.
    args.flip = False
    transform_ssv2 = DataAugmentationForVideoMAE(args)
    args.flip = original_flip
    dataset = VideoMAE_multi(
        root=None,
        setting=args.data_path,
        prefix=args.prefix,
        split=args.split,
        is_color=True,
        modality='rgb',
        num_segments=args.num_segments,
        new_length=args.num_frames,
        new_step=args.sampling_rate,
        transform=transform,
        transform_ssv2=transform_ssv2,
        temporal_jitter=False,
        video_loader=True,
        use_decord=args.use_decord,
        lazy_init=False,
        num_sample=args.num_sample)
    print("Data Aug = %s" % str(transform))
    print("Data Aug for SSV2 = %s" % str(transform_ssv2))
    return dataset


def build_dataset(is_train, test_mode, args):
    """Build a classification dataset for fine-tuning or evaluation.

    The split (train / validation / test) is selected by the is_train and
    test_mode flags, using the corresponding csv file under args.data_path.
    """
    print(f'Use Dataset: {args.data_set}')
    if args.data_set in [
        'Kinetics',
        'Kinetics_sparse',
        'mitv1_sparse'
    ]:
        mode = None
        anno_path = None
        if is_train is True:
            mode = 'train'
            anno_path = os.path.join(args.data_path, 'train.csv')
        elif test_mode is True:
            mode = 'test'
            anno_path = os.path.join(args.data_path, 'test.csv')
        else:
            mode = 'validation'
            anno_path = os.path.join(args.data_path, 'val.csv')

        if 'sparse' in args.data_set:
            func = VideoClsDataset_sparse
        else:
            func = VideoClsDataset

        dataset = func(
            anno_path=anno_path,
            prefix=args.prefix,
            split=args.split,
            mode=mode,
            clip_len=args.num_frames,
            frame_sample_rate=args.sampling_rate,
            num_segment=1,
            test_num_segment=args.test_num_segment,
            test_num_crop=args.test_num_crop,
            num_crop=1 if not test_mode else 3,
            keep_aspect_ratio=True,
            crop_size=args.input_size,
            short_side_size=args.short_side_size,
            new_height=256,
            new_width=320,
            args=args)
        nb_classes = args.nb_classes
    elif args.data_set == 'SSV2':
        mode = None
        anno_path = None
        if is_train is True:
            mode = 'train'
            anno_path = os.path.join(args.data_path, 'train.csv')
        elif test_mode is True:
            mode = 'test'
            anno_path = os.path.join(args.data_path, 'test.csv')
        else:
            mode = 'validation'
            anno_path = os.path.join(args.data_path, 'val.csv')

        if args.use_decord:
            func = SSVideoClsDataset
        else:
            func = SSRawFrameClsDataset

        dataset = func(
            anno_path=anno_path,
            prefix=args.prefix,
            split=args.split,
            mode=mode,
            clip_len=1,
            num_segment=args.num_frames,
            test_num_segment=args.test_num_segment,
            test_num_crop=args.test_num_crop,
            num_crop=1 if not test_mode else 3,
            keep_aspect_ratio=True,
            crop_size=args.input_size,
            short_side_size=args.short_side_size,
            new_height=256,
            new_width=320,
            filename_tmpl=args.filename_tmpl,
            args=args)
        nb_classes = 174
    elif args.data_set == 'UCF101':
        mode = None
        anno_path = None
        if is_train is True:
            mode = 'train'
            anno_path = os.path.join(args.data_path, 'train.csv')
        elif test_mode is True:
            mode = 'test'
            anno_path = os.path.join(args.data_path, 'test.csv')
        else:
            mode = 'validation'
            anno_path = os.path.join(args.data_path, 'val.csv')

        dataset = VideoClsDataset(
            anno_path=anno_path,
            prefix=args.prefix,
            split=args.split,
            mode=mode,
            clip_len=args.num_frames,
            frame_sample_rate=args.sampling_rate,
            num_segment=1,
            test_num_segment=args.test_num_segment,
            test_num_crop=args.test_num_crop,
            num_crop=1 if not test_mode else 3,
            keep_aspect_ratio=True,
            crop_size=args.input_size,
            short_side_size=args.short_side_size,
            new_height=256,
            new_width=320,
            args=args)
        nb_classes = 101
    elif args.data_set == 'HMDB51':
        mode = None
        anno_path = None
        if is_train is True:
            mode = 'train'
            anno_path = os.path.join(args.data_path, 'train.csv')
        elif test_mode is True:
            mode = 'test'
            anno_path = os.path.join(args.data_path, 'test.csv')
        else:
            mode = 'validation'
            anno_path = os.path.join(args.data_path, 'val.csv')

        if args.use_decord:
            func = HMDBVideoClsDataset
        else:
            func = HMDBRawFrameClsDataset

        dataset = func(
            anno_path=anno_path,
            prefix=args.prefix,
            split=args.split,
            mode=mode,
            clip_len=1,
            num_segment=args.num_frames,
            test_num_segment=args.test_num_segment,
            test_num_crop=args.test_num_crop,
            num_crop=1 if not test_mode else 3,
            keep_aspect_ratio=True,
            crop_size=args.input_size,
            short_side_size=args.short_side_size,
            new_height=256,
            new_width=320,
            filename_tmpl=args.filename_tmpl,
            args=args)
        nb_classes = 51
    elif args.data_set in [
        'ANet',
        'HACS',
        'ANet_interval',
        'HACS_interval'
    ]:
        mode = None
        anno_path = None
        if is_train is True:
            mode = 'train'
            anno_path = os.path.join(args.data_path, 'train.csv')
        elif test_mode is True:
            mode = 'test'
            anno_path = os.path.join(args.data_path, 'test.csv')
        else:
            mode = 'validation'
            anno_path = os.path.join(args.data_path, 'val.csv')

        if 'interval' in args.data_set:
            func = ANetDataset
        else:
            func = VideoClsDataset_sparse

        dataset = func(
            anno_path=anno_path,
            prefix=args.prefix,
            split=args.split,
            mode=mode,
            clip_len=args.num_frames,
            frame_sample_rate=args.sampling_rate,
            num_segment=1,
            test_num_segment=args.test_num_segment,
            test_num_crop=args.test_num_crop,
            num_crop=1 if not test_mode else 3,
            keep_aspect_ratio=True,
            crop_size=args.input_size,
            short_side_size=args.short_side_size,
            new_height=256,
            new_width=320,
            args=args)
        nb_classes = args.nb_classes
    else:
        print(f'Unsupported dataset: {args.data_set}')
        raise NotImplementedError()

    assert nb_classes == args.nb_classes
    print("Number of classes = %d" % args.nb_classes)
    return dataset, nb_classes
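

# Example (sketch): obtaining the splits for fine-tuning. The three calls only
# differ in the (is_train, test_mode) flags; `args` is assumed to carry the
# fields consumed above (data_set, data_path, nb_classes, ...).
#
#   train_set, nb_classes = build_dataset(is_train=True, test_mode=False, args=args)
#   val_set, _ = build_dataset(is_train=False, test_mode=False, args=args)
#   test_set, _ = build_dataset(is_train=False, test_mode=True, args=args)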