# File-viewer header residue (not part of the module): mazpie — "Initial commit" (2d9a728), 10.4 kB
import os
from torchvision import transforms
from .transforms import *
from .masking_generator import TubeMaskingGenerator, RandomMaskingGenerator
from .mae import VideoMAE
from .mae_multi import VideoMAE_multi
from .kinetics import VideoClsDataset
from .kinetics_sparse import VideoClsDataset_sparse
from .anet import ANetDataset
from .ssv2 import SSVideoClsDataset, SSRawFrameClsDataset
from .hmdb import HMDBVideoClsDataset, HMDBRawFrameClsDataset
class DataAugmentationForVideoMAE(object):
    """Callable pairing a group-wise frame augmentation pipeline with a mask generator.

    Calling an instance runs ``self.transform`` on its input and returns the
    transformed data together with a freshly generated mask (or ``-1`` when
    no mask generator is configured).

    Args:
        args: Namespace-like object. The attributes read here are
            ``mask_type`` (``'tube'``, ``'random'`` or ``'attention'``),
            ``window_size`` and ``mask_ratio`` (only for ``'tube'`` and
            ``'random'``), ``input_size``, ``color_jitter`` and ``flip``.

    Raises:
        ValueError: If ``args.mask_type`` is not one of the supported values.
    """

    def __init__(self, args):
        # Resolve the masking strategy first so that a misconfigured
        # ``mask_type`` fails immediately, instead of leaving the attribute
        # undefined and surfacing later as an AttributeError in __call__.
        self.masked_position_generator = self._build_mask_generator(args)

        self.input_mean = [0.485, 0.456, 0.406]  # IMAGENET_DEFAULT_MEAN
        self.input_std = [0.229, 0.224, 0.225]  # IMAGENET_DEFAULT_STD
        normalize = GroupNormalize(self.input_mean, self.input_std)
        self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66])

        # Both variants share every step; colour jitter is inserted right
        # after the multi-scale crop only when it is enabled.
        steps = [self.train_augmentation]
        if args.color_jitter > 0:
            steps.append(GroupColorJitter(args.color_jitter))
        steps.extend([
            GroupRandomHorizontalFlip(flip=args.flip),
            Stack(roll=False),
            ToTorchFormatTensor(div=True),
            normalize,
        ])
        self.transform = transforms.Compose(steps)

    @staticmethod
    def _build_mask_generator(args):
        """Return the mask generator for ``args.mask_type`` (``None`` for 'attention')."""
        mask_type = args.mask_type
        if mask_type == 'tube':
            return TubeMaskingGenerator(args.window_size, args.mask_ratio)
        if mask_type == 'random':
            return RandomMaskingGenerator(args.window_size, args.mask_ratio)
        # Exact comparison on purpose: a substring test (``in 'attention'``)
        # would also accept '', 'a', 'tent', ... as valid mask types.
        if mask_type == 'attention':
            # No generator is built here; __call__ returns -1 in its place.
            return None
        raise ValueError(
            "Unsupported mask_type %r; expected 'tube', 'random' or 'attention'"
            % (mask_type,)
        )

    def __call__(self, images):
        """Transform ``images`` and return ``(processed_data, mask)``.

        ``self.transform`` must return a 2-tuple; its second element is
        discarded. ``mask`` is ``-1`` when no mask generator is configured,
        otherwise the result of calling the generator.
        """
        process_data, _ = self.transform(images)
        if self.masked_position_generator is None:
            return process_data, -1
        return process_data, self.masked_position_generator()

    def __repr__(self):
        """Return a multi-line description of the transform and mask generator."""
        # Local deliberately not named ``repr`` to avoid shadowing the builtin.
        text = "(DataAugmentationForVideoMAE,\n"
        text += " transform = %s,\n" % str(self.transform)
        text += " Masked position generator = %s,\n" % str(self.masked_position_generator)
        text += ")"
        return text
def build_pretraining_dataset(args):
    """Create the ``VideoMAE`` pre-training dataset described by ``args``.

    Args:
        args: Namespace-like object. Besides what
            ``DataAugmentationForVideoMAE`` reads, this uses ``data_path``,
            ``prefix``, ``split``, ``num_segments``, ``num_frames``,
            ``sampling_rate``, ``use_decord`` and ``num_sample``.

    Returns:
        The constructed ``VideoMAE`` dataset.
    """
    augmentation = DataAugmentationForVideoMAE(args)

    dataset_kwargs = {
        'root': None,
        'setting': args.data_path,
        'prefix': args.prefix,
        'split': args.split,
        'video_ext': 'mp4',
        'is_color': True,
        'modality': 'rgb',
        'num_segments': args.num_segments,
        'new_length': args.num_frames,
        'new_step': args.sampling_rate,
        'transform': augmentation,
        'temporal_jitter': False,
        'video_loader': True,
        'use_decord': args.use_decord,
        'lazy_init': False,
        'num_sample': args.num_sample,
    }
    pretrain_set = VideoMAE(**dataset_kwargs)

    # Log the augmentation only after the dataset has been built.
    print("Data Aug = %s" % str(augmentation))
    return pretrain_set
def build_multi_pretraining_dataset(args):
    """Create the ``VideoMAE_multi`` pre-training dataset described by ``args``.

    Two augmentation pipelines are built: the regular one, and a second one
    (passed as ``transform_ssv2``) that is constructed with ``args.flip``
    forced to ``False``.
    NOTE(review): presumably flipping is disabled for SSV2 because its labels
    are direction-sensitive; confirm against the dataset implementation.

    Args:
        args: Namespace-like object. Besides what
            ``DataAugmentationForVideoMAE`` reads, this uses ``data_path``,
            ``prefix``, ``split``, ``num_segments``, ``num_frames``,
            ``sampling_rate``, ``use_decord`` and ``num_sample``.
            ``args.flip`` is modified temporarily and always restored.

    Returns:
        The constructed ``VideoMAE_multi`` dataset.
    """
    original_flip = args.flip
    transform = DataAugmentationForVideoMAE(args)

    # ``args`` belongs to the caller: restore ``flip`` even if building the
    # second pipeline raises, so the override never leaks out.
    args.flip = False
    try:
        transform_ssv2 = DataAugmentationForVideoMAE(args)
    finally:
        args.flip = original_flip

    dataset = VideoMAE_multi(
        root=None,
        setting=args.data_path,
        prefix=args.prefix,
        split=args.split,
        is_color=True,
        modality='rgb',
        num_segments=args.num_segments,
        new_length=args.num_frames,
        new_step=args.sampling_rate,
        transform=transform,
        transform_ssv2=transform_ssv2,
        temporal_jitter=False,
        video_loader=True,
        use_decord=args.use_decord,
        lazy_init=False,
        num_sample=args.num_sample)
    print("Data Aug = %s" % str(transform))
    print("Data Aug for SSV2 = %s" % str(transform_ssv2))
    return dataset
def _split_mode_and_annotation(is_train, test_mode, data_path):
    """Map the train/test flags to ``(mode, annotation_csv_path)``."""
    # Identity checks against ``True`` are kept on purpose: truthy non-bool
    # flags fall through to the next branch, exactly as before.
    if is_train is True:
        mode, csv_name = 'train', 'train.csv'
    elif test_mode is True:
        mode, csv_name = 'test', 'test.csv'
    else:
        mode, csv_name = 'validation', 'val.csv'
    return mode, os.path.join(data_path, csv_name)


def build_dataset(is_train, test_mode, args):
    """Build the classification dataset selected by ``args.data_set``.

    Supported names: ``Kinetics``, ``Kinetics_sparse``, ``mitv1_sparse``,
    ``SSV2``, ``UCF101``, ``HMDB51``, ``ANet``, ``HACS``, ``ANet_interval``
    and ``HACS_interval``.

    Args:
        is_train: ``True`` selects mode ``'train'`` and ``train.csv``.
        test_mode: When ``is_train`` is not ``True``, ``True`` selects mode
            ``'test'`` and ``test.csv``; anything else selects
            ``'validation'`` and ``val.csv``. Its truthiness also switches
            ``num_crop`` from 1 to 3.
        args: Namespace-like object carrying the dataset configuration
            (``data_set``, ``data_path``, ``prefix``, ``split``,
            ``num_frames``, ``input_size``, ``nb_classes``, ...).

    Returns:
        Tuple ``(dataset, nb_classes)``.

    Raises:
        NotImplementedError: If ``args.data_set`` is not supported.
        AssertionError: If the dataset's class count differs from
            ``args.nb_classes``.
    """
    print(f'Use Dataset: {args.data_set}')
    name = args.data_set

    # Step 1: choose the dataset class, its class count, and which sampling
    # layout its constructor takes (one clip of ``num_frames`` frames, or
    # ``num_frames`` single-frame segments).
    if name in ('Kinetics', 'Kinetics_sparse', 'mitv1_sparse'):
        dataset_cls = VideoClsDataset_sparse if 'sparse' in name else VideoClsDataset
        nb_classes = args.nb_classes
        segment_sampling = False
    elif name == 'SSV2':
        dataset_cls = SSVideoClsDataset if args.use_decord else SSRawFrameClsDataset
        nb_classes = 174
        segment_sampling = True
    elif name == 'UCF101':
        dataset_cls = VideoClsDataset
        nb_classes = 101
        segment_sampling = False
    elif name == 'HMDB51':
        dataset_cls = HMDBVideoClsDataset if args.use_decord else HMDBRawFrameClsDataset
        nb_classes = 51
        segment_sampling = True
    elif name in ('ANet', 'HACS', 'ANet_interval', 'HACS_interval'):
        dataset_cls = ANetDataset if 'interval' in name else VideoClsDataset_sparse
        nb_classes = args.nb_classes
        segment_sampling = False
    else:
        print(f'Wrong: {args.data_set}')
        raise NotImplementedError()

    # Step 2: assemble the keyword arguments shared by every dataset class.
    mode, anno_path = _split_mode_and_annotation(is_train, test_mode, args.data_path)
    dataset_kwargs = dict(
        anno_path=anno_path,
        prefix=args.prefix,
        split=args.split,
        mode=mode,
        test_num_segment=args.test_num_segment,
        test_num_crop=args.test_num_crop,
        num_crop=1 if not test_mode else 3,
        keep_aspect_ratio=True,
        crop_size=args.input_size,
        short_side_size=args.short_side_size,
        new_height=256,
        new_width=320,
        args=args,
    )
    # Layout-specific arguments are read lazily so that e.g.
    # ``args.filename_tmpl`` is only required for the datasets that use it.
    if segment_sampling:
        dataset_kwargs.update(
            clip_len=1,
            num_segment=args.num_frames,
            filename_tmpl=args.filename_tmpl,
        )
    else:
        dataset_kwargs.update(
            clip_len=args.num_frames,
            frame_sample_rate=args.sampling_rate,
            num_segment=1,
        )
    dataset = dataset_cls(**dataset_kwargs)

    # Raised explicitly (same exception type as the former ``assert``) so the
    # check still runs when Python is started with -O.
    if nb_classes != args.nb_classes:
        raise AssertionError(
            "Dataset %r has %d classes but args.nb_classes is %r"
            % (name, nb_classes, args.nb_classes)
        )
    print("Number of the class = %d" % args.nb_classes)
    return dataset, nb_classes