# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Transforms and data augmentation for video frames and their object segments (masks).
"""
import logging
import random
from typing import Iterable

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
import torchvision.transforms.v2.functional as Fv2
from PIL import Image as PILImage
from torchvision.transforms import InterpolationMode

from training.utils.data_utils import VideoDatapoint


def hflip(datapoint, index):
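    """Horizontally flip the frame at `index` and all of its object segments."""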
datapoint.frames[index].data = F.hflip(datapoint.frames[index].data)
for obj in datapoint.frames[index].objects:
if obj.segment is not None:
obj.segment = F.hflip(obj.segment)
return datapoint


def get_size_with_aspect_ratio(image_size, size, max_size=None):
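    """Return an output size (h, w) that maps the shorter side of `image_size`
    (given as (w, h)) to `size` while keeping the aspect ratio; if `max_size`
    is given, the longer side is capped and `size` is scaled down accordingly.
    """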
w, h = image_size
if max_size is not None:
min_original_size = float(min((w, h)))
max_original_size = float(max((w, h)))
if max_original_size / min_original_size * size > max_size:
size = max_size * min_original_size / max_original_size
if (w <= h and w == size) or (h <= w and h == size):
return (h, w)
if w < h:
ow = int(round(size))
oh = int(round(size * h / w))
else:
oh = int(round(size))
ow = int(round(size * w / h))
return (oh, ow)


def resize(datapoint, index, size, max_size=None, square=False, v2=False):
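    """Resize the frame at `index` and its object segments.

    `size` is either a shorter-side scalar (optionally capped by `max_size`)
    or a (w, h) tuple; `square=True` forces a `size` x `size` output. Updates
    `frames[index].size` to the new (h, w).
    """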
# size can be min_size (scalar) or (w, h) tuple
def get_size(image_size, size, max_size=None):
if isinstance(size, (list, tuple)):
return size[::-1]
else:
return get_size_with_aspect_ratio(image_size, size, max_size)
if square:
size = size, size
else:
cur_size = (
datapoint.frames[index].data.size()[-2:][::-1]
if v2
else datapoint.frames[index].data.size
)
size = get_size(cur_size, size, max_size)
    if v2:
        datapoint.frames[index].data = Fv2.resize(
            datapoint.frames[index].data, size, antialias=True
        )
    else:
        datapoint.frames[index].data = F.resize(datapoint.frames[index].data, size)
for obj in datapoint.frames[index].objects:
if obj.segment is not None:
obj.segment = F.resize(obj.segment[None, None], size).squeeze()
h, w = size
datapoint.frames[index].size = (h, w)
return datapoint


def pad(datapoint, index, padding, v2=False):
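    """Pad the frame at `index` and its object segments.

    A 2-element `padding` pads only the right and bottom edges; a 4-element
    `padding` is (left, top, right, bottom). Updates `frames[index].size`.
    """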
old_h, old_w = datapoint.frames[index].size
h, w = old_h, old_w
if len(padding) == 2:
        # assumes that we only pad on the right and bottom edges
datapoint.frames[index].data = F.pad(
datapoint.frames[index].data, (0, 0, padding[0], padding[1])
)
h += padding[1]
w += padding[0]
else:
# left, top, right, bottom
datapoint.frames[index].data = F.pad(
datapoint.frames[index].data,
(padding[0], padding[1], padding[2], padding[3]),
)
h += padding[1] + padding[3]
w += padding[0] + padding[2]
datapoint.frames[index].size = (h, w)
for obj in datapoint.frames[index].objects:
if obj.segment is not None:
if v2:
if len(padding) == 2:
obj.segment = Fv2.pad(obj.segment, (0, 0, padding[0], padding[1]))
else:
obj.segment = Fv2.pad(obj.segment, tuple(padding))
else:
if len(padding) == 2:
obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1]))
else:
obj.segment = F.pad(obj.segment, tuple(padding))
return datapoint


class RandomHorizontalFlip:
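    """Horizontally flip frames with probability `p`: one flip decision for the
    whole clip when `consistent_transform` is True, otherwise per frame.
    """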
def __init__(self, consistent_transform, p=0.5):
self.p = p
self.consistent_transform = consistent_transform

    def __call__(self, datapoint, **kwargs):
if self.consistent_transform:
if random.random() < self.p:
for i in range(len(datapoint.frames)):
datapoint = hflip(datapoint, i)
return datapoint
for i in range(len(datapoint.frames)):
if random.random() < self.p:
datapoint = hflip(datapoint, i)
return datapoint


class RandomResizeAPI:
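    """Resize frames to a size randomly chosen from `sizes`: one draw for the
    whole clip when `consistent_transform` is True, otherwise per frame.
    """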
def __init__(
self, sizes, consistent_transform, max_size=None, square=False, v2=False
):
if isinstance(sizes, int):
sizes = (sizes,)
assert isinstance(sizes, Iterable)
self.sizes = list(sizes)
self.max_size = max_size
self.square = square
self.consistent_transform = consistent_transform
self.v2 = v2

    def __call__(self, datapoint, **kwargs):
if self.consistent_transform:
size = random.choice(self.sizes)
for i in range(len(datapoint.frames)):
datapoint = resize(
datapoint, i, size, self.max_size, square=self.square, v2=self.v2
)
return datapoint
for i in range(len(datapoint.frames)):
size = random.choice(self.sizes)
datapoint = resize(
datapoint, i, size, self.max_size, square=self.square, v2=self.v2
)
return datapoint


class ToTensorAPI:
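    """Convert each frame's image to a tensor (the torchvision v2 image type
    when `v2` is True); object segments are left unchanged.
    """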
def __init__(self, v2=False):
self.v2 = v2

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
for img in datapoint.frames:
if self.v2:
img.data = Fv2.to_image_tensor(img.data)
else:
img.data = F.to_tensor(img.data)
return datapoint


class NormalizeAPI:
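    """Normalize each frame with per-channel `mean` and `std`; in v2 mode the
    image is first converted to float32.
    """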
def __init__(self, mean, std, v2=False):
self.mean = mean
self.std = std
self.v2 = v2

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
for img in datapoint.frames:
if self.v2:
img.data = Fv2.convert_image_dtype(img.data, torch.float32)
img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std)
else:
img.data = F.normalize(img.data, mean=self.mean, std=self.std)
return datapoint


class ComposeAPI:
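    """Compose several datapoint-level transforms and apply them in order."""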
def __init__(self, transforms):
self.transforms = transforms

    def __call__(self, datapoint, **kwargs):
for t in self.transforms:
datapoint = t(datapoint, **kwargs)
return datapoint

    def __repr__(self):
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += "\n"
format_string += " {0}".format(t)
format_string += "\n)"
return format_string


class RandomGrayscale:
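    """Convert frames to 3-channel grayscale with probability `p`: one decision
    for the whole clip when `consistent_transform` is True, otherwise per frame.
    """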
def __init__(self, consistent_transform, p=0.5):
self.p = p
self.consistent_transform = consistent_transform
self.Grayscale = T.Grayscale(num_output_channels=3)

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
if self.consistent_transform:
if random.random() < self.p:
for img in datapoint.frames:
img.data = self.Grayscale(img.data)
return datapoint
for img in datapoint.frames:
if random.random() < self.p:
img.data = self.Grayscale(img.data)
return datapoint


class ColorJitter:
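    """Randomly jitter brightness, contrast, saturation, and hue.

    Scalar factors are expanded to `[max(0, 1 - x), 1 + x]` ranges (`[-hue, hue]`
    for hue); when `consistent_transform` is True the same jitter parameters are
    sampled once and applied to every frame.
    """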
def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
self.consistent_transform = consistent_transform
self.brightness = (
brightness
if isinstance(brightness, list)
else [max(0, 1 - brightness), 1 + brightness]
)
self.contrast = (
contrast
if isinstance(contrast, list)
else [max(0, 1 - contrast), 1 + contrast]
)
self.saturation = (
saturation
if isinstance(saturation, list)
else [max(0, 1 - saturation), 1 + saturation]
)
self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue])

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        if self.consistent_transform:
            # Sample one set of color jitter parameters and reuse it on every frame
(
fn_idx,
brightness_factor,
contrast_factor,
saturation_factor,
hue_factor,
) = T.ColorJitter.get_params(
self.brightness, self.contrast, self.saturation, self.hue
)
for img in datapoint.frames:
if not self.consistent_transform:
(
fn_idx,
brightness_factor,
contrast_factor,
saturation_factor,
hue_factor,
) = T.ColorJitter.get_params(
self.brightness, self.contrast, self.saturation, self.hue
)
for fn_id in fn_idx:
if fn_id == 0 and brightness_factor is not None:
img.data = F.adjust_brightness(img.data, brightness_factor)
elif fn_id == 1 and contrast_factor is not None:
img.data = F.adjust_contrast(img.data, contrast_factor)
elif fn_id == 2 and saturation_factor is not None:
img.data = F.adjust_saturation(img.data, saturation_factor)
elif fn_id == 3 and hue_factor is not None:
img.data = F.adjust_hue(img.data, hue_factor)
return datapoint


class RandomAffine:
def __init__(
self,
degrees,
consistent_transform,
scale=None,
translate=None,
shear=None,
image_mean=(123, 116, 103),
log_warning=True,
num_tentatives=1,
image_interpolation="bicubic",
):
"""
The mask is required for this transform.
if consistent_transform if True, then the same random affine is applied to all frames and masks.
"""
self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees])
self.scale = scale
self.shear = (
shear if isinstance(shear, list) else ([-shear, shear] if shear else None)
)
self.translate = translate
self.fill_img = image_mean
self.consistent_transform = consistent_transform
self.log_warning = log_warning
self.num_tentatives = num_tentatives
if image_interpolation == "bicubic":
self.image_interpolation = InterpolationMode.BICUBIC
elif image_interpolation == "bilinear":
self.image_interpolation = InterpolationMode.BILINEAR
else:
raise NotImplementedError

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
for _tentative in range(self.num_tentatives):
res = self.transform_datapoint(datapoint)
if res is not None:
return res
if self.log_warning:
logging.warning(
f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives"
)
return datapoint

    def transform_datapoint(self, datapoint: VideoDatapoint):
_, height, width = F.get_dimensions(datapoint.frames[0].data)
img_size = [width, height]
        if self.consistent_transform:
            # Sample a single random affine transformation for the whole clip
affine_params = T.RandomAffine.get_params(
degrees=self.degrees,
translate=self.translate,
scale_ranges=self.scale,
shears=self.shear,
img_size=img_size,
)
for img_idx, img in enumerate(datapoint.frames):
this_masks = [
obj.segment.unsqueeze(0) if obj.segment is not None else None
for obj in img.objects
]
if not self.consistent_transform:
                # Not consistent: sample new affine params for every frame & mask pair
affine_params = T.RandomAffine.get_params(
degrees=self.degrees,
translate=self.translate,
scale_ranges=self.scale,
shears=self.shear,
img_size=img_size,
)
            transformed_masks = []
            for i in range(len(img.objects)):
                if this_masks[i] is None:
                    # No segment for this object; keep a placeholder entry
                    transformed_masks.append(None)
else:
transformed_mask = F.affine(
this_masks[i],
*affine_params,
interpolation=InterpolationMode.NEAREST,
fill=0.0,
)
if img_idx == 0 and transformed_mask.max() == 0:
# We are dealing with a video and the object is not visible in the first frame
# Return the datapoint without transformation
return None
transformed_masks.append(transformed_mask.squeeze())
for i in range(len(img.objects)):
img.objects[i].segment = transformed_masks[i]
img.data = F.affine(
img.data,
*affine_params,
interpolation=self.image_interpolation,
fill=self.fill_img,
)
return datapoint


def random_mosaic_frame(
datapoint,
index,
grid_h,
grid_w,
target_grid_y,
target_grid_x,
should_hflip,
):
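    """Rebuild the frame at `index` as a `grid_h` x `grid_w` mosaic of downsized
    copies of itself (with per-cell horizontal flips given by `should_hflip`),
    keeping each object segment only inside the target cell
    (`target_grid_y`, `target_grid_x`).
    """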
# Step 1: downsize the images and paste them into a mosaic
image_data = datapoint.frames[index].data
is_pil = isinstance(image_data, PILImage.Image)
if is_pil:
H_im = image_data.height
W_im = image_data.width
image_data_output = PILImage.new("RGB", (W_im, H_im))
else:
H_im = image_data.size(-2)
W_im = image_data.size(-1)
image_data_output = torch.zeros_like(image_data)
downsize_cache = {}
for grid_y in range(grid_h):
for grid_x in range(grid_w):
y_offset_b = grid_y * H_im // grid_h
x_offset_b = grid_x * W_im // grid_w
y_offset_e = (grid_y + 1) * H_im // grid_h
x_offset_e = (grid_x + 1) * W_im // grid_w
H_im_downsize = y_offset_e - y_offset_b
W_im_downsize = x_offset_e - x_offset_b
if (H_im_downsize, W_im_downsize) in downsize_cache:
image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)]
else:
image_data_downsize = F.resize(
image_data,
size=(H_im_downsize, W_im_downsize),
interpolation=InterpolationMode.BILINEAR,
antialias=True, # antialiasing for downsizing
)
downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize
if should_hflip[grid_y, grid_x].item():
image_data_downsize = F.hflip(image_data_downsize)
if is_pil:
image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b))
else:
image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = (
image_data_downsize
)
datapoint.frames[index].data = image_data_output
# Step 2: downsize the masks and paste them into the target grid of the mosaic
for obj in datapoint.frames[index].objects:
if obj.segment is None:
continue
assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8
segment_output = torch.zeros_like(obj.segment)
target_y_offset_b = target_grid_y * H_im // grid_h
target_x_offset_b = target_grid_x * W_im // grid_w
target_y_offset_e = (target_grid_y + 1) * H_im // grid_h
target_x_offset_e = (target_grid_x + 1) * W_im // grid_w
target_H_im_downsize = target_y_offset_e - target_y_offset_b
target_W_im_downsize = target_x_offset_e - target_x_offset_b
segment_downsize = F.resize(
obj.segment[None, None],
size=(target_H_im_downsize, target_W_im_downsize),
interpolation=InterpolationMode.BILINEAR,
antialias=True, # antialiasing for downsizing
)[0, 0]
if should_hflip[target_grid_y, target_grid_x].item():
segment_downsize = F.hflip(segment_downsize[None, None])[0, 0]
segment_output[
target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e
] = segment_downsize
obj.segment = segment_output
return datapoint


class RandomMosaicVideoAPI:
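    """With probability `prob`, turn every frame of the video into a mosaic of
    downsized copies of itself; the target masks are placed in one randomly
    chosen grid cell, shared across all frames.
    """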
def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
self.prob = prob
self.grid_h = grid_h
self.grid_w = grid_w
self.use_random_hflip = use_random_hflip

    def __call__(self, datapoint, **kwargs):
if random.random() > self.prob:
return datapoint
# select a random location to place the target mask in the mosaic
target_grid_y = random.randint(0, self.grid_h - 1)
target_grid_x = random.randint(0, self.grid_w - 1)
# whether to flip each grid in the mosaic horizontally
if self.use_random_hflip:
should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5
else:
should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool)
for i in range(len(datapoint.frames)):
datapoint = random_mosaic_frame(
datapoint,
i,
grid_h=self.grid_h,
grid_w=self.grid_w,
target_grid_y=target_grid_y,
target_grid_x=target_grid_x,
should_hflip=should_hflip,
)
return datapoint
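

# Example usage (a minimal sketch, not taken from any particular training
# recipe): these transforms are designed to be stacked with `ComposeAPI` and
# called on a `VideoDatapoint`. The sizes and probabilities below are
# illustrative assumptions; the mean/std are the standard ImageNet statistics.
if __name__ == "__main__":
    example_transforms = ComposeAPI(
        transforms=[
            RandomHorizontalFlip(consistent_transform=True, p=0.5),
            ColorJitter(
                consistent_transform=True,
                brightness=0.1,
                contrast=0.03,
                saturation=0.03,
                hue=None,
            ),
            RandomResizeAPI(sizes=512, consistent_transform=True, square=True),
            ToTensorAPI(),
            NormalizeAPI(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    print(example_transforms)  # uses ComposeAPI.__repr__
    # Applying it: `example_transforms(datapoint)` where `datapoint` is a
    # `VideoDatapoint` whose frames hold PIL images and uint8 segment tensors.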