Spaces:
Sleeping
Sleeping
""" | |
This implementation is based on
https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
published under an Apache License 2.0.
""" | |
import math | |
import random | |
import torch | |
def _get_pixels(
    per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
):
    """Build the tensor used to fill one erased patch.

    Args:
        per_pixel: if truthy, draw an independent standard-normal value for
            every element of ``patch_size``.
        rand_color: consulted only when ``per_pixel`` is falsy; if truthy,
            draw one standard-normal value per channel, shaped
            ``(patch_size[0], 1, 1)`` so it broadcasts over the patch.
        patch_size: shape of the patch; ``patch_size[0]`` is used as the
            channel count for the per-channel / constant variants.
        dtype: dtype of the returned tensor.
        device: device the returned tensor is allocated on.

    Returns:
        A tensor of shape ``patch_size`` (per-pixel noise) or
        ``(patch_size[0], 1, 1)`` (per-channel noise, or zeros when neither
        flag is set).
    """
    # NOTE: normal_() on CUDA was once reported to trigger illegal memory
    # access errors (https://github.com/pytorch/pytorch/issues/19508, since
    # fixed upstream). If it resurfaces, draw the noise on CPU and move it.
    if not per_pixel:
        channel_shape = (patch_size[0], 1, 1)
        if not rand_color:
            return torch.zeros(channel_shape, dtype=dtype, device=device)
        return torch.empty(channel_shape, dtype=dtype, device=device).normal_()
    return torch.empty(patch_size, dtype=dtype, device=device).normal_()
class RandomErasing:
    """Randomly selects a rectangle region in an image and erases its pixels.

    'Random Erasing Data Augmentation' by Zhong et al.
    See https://arxiv.org/pdf/1708.04896.pdf

    This variant of RandomErasing is intended to be applied to either a batch
    or single image tensor after it has been normalized by dataset mean and std.
    The input is modified in place and also returned:

    * a 3-D input is treated as one ``(C, H, W)`` image;
    * a 4-D input is unpacked as ``(N, C, H, W)``. With ``cube=True`` every
      slice along the first axis is erased at the same box locations (one
      probability draw for the whole tensor); with ``cube=False`` each slice
      gets its own probability draw and boxes.

    Args:
        probability: Probability that the Random Erasing operation will be performed.
        min_area: Minimum fraction of erased area wrt input image area.
        max_area: Maximum fraction of erased area wrt input image area.
        min_aspect: Minimum aspect ratio (h / w) of erased area.
        max_aspect: Maximum aspect ratio of erased area; defaults to
            ``1 / min_aspect``. The ratio is sampled log-uniformly between
            the two bounds.
        mode: pixel color mode, one of 'const', 'rand', or 'pixel'
            'const' - erase block is constant color of 0 for all channels
            'rand'  - erase block is same per-channel random (normal) color
            'pixel' - erase block is per-pixel random (normal) color
        min_count: minimum number of erasing blocks per image.
        max_count: maximum number of erasing blocks per image (defaults to
            ``min_count``). The per-image count is drawn uniformly from
            ``[min_count, max_count]``, and the target area of each box is
            divided by that count.
        num_splits: if > 1, the first ``batch_size // num_splits`` slices of
            a 4-D input are left untouched (the clean portion of the batch).
        device: device on which fill values are created. ``None`` means
            "use the device of the tensor being erased".
        cube: for 4-D input, erase identical box locations in every slice.
    """

    def __init__(
        self,
        probability=0.5,
        min_area=0.02,
        max_area=1 / 3,
        min_aspect=0.3,
        max_aspect=None,
        mode="const",
        min_count=1,
        max_count=None,
        num_splits=0,
        device="cuda",
        cube=True,
    ):
        self.probability = probability
        self.min_area = min_area
        self.max_area = max_area
        max_aspect = max_aspect or 1 / min_aspect
        # Aspect ratios are sampled uniformly in log space.
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
        self.min_count = min_count
        self.max_count = max_count or min_count
        self.num_splits = num_splits
        mode = mode.lower()
        self.rand_color = False
        self.per_pixel = False
        self.cube = cube
        if mode == "rand":
            self.rand_color = True  # per block random normal
        elif mode == "pixel":
            self.per_pixel = True  # per pixel random normal
        else:
            # Only 'const' (or an empty mode) is accepted here. This is an
            # assert, so the check is skipped under `python -O`.
            assert not mode or mode == "const"
        self.device = device

    def _resolve_device(self, img):
        """Return the device fill values for ``img`` should be created on."""
        return img.device if self.device is None else self.device

    def _sample_count(self):
        """Draw how many boxes to erase in one image (or one cube)."""
        if self.min_count == self.max_count:
            return self.min_count
        return random.randint(self.min_count, self.max_count)

    def _sample_box(self, img_h, img_w, count, attempts):
        """Sample one box lying strictly inside an ``img_h x img_w`` image.

        Each attempt draws a target area (scaled down by ``count``) and an
        aspect ratio. Attempts whose box would not be strictly smaller than
        the image are discarded.

        Returns:
            ``(top, left, h, w)`` for the first box that fits, or ``None``
            if none fitted within ``attempts`` tries.
        """
        area = img_h * img_w
        for _ in range(attempts):
            target_area = (
                random.uniform(self.min_area, self.max_area) * area / count
            )
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < img_w and h < img_h:
                top = random.randint(0, img_h - h)
                left = random.randint(0, img_w - w)
                return top, left, h, w
        return None

    def _erase(self, img, chan, img_h, img_w, dtype):
        """Erase random boxes of one ``(chan, img_h, img_w)`` image in place.

        With probability ``1 - self.probability`` nothing happens. Otherwise
        up to ``count`` boxes are erased, each with up to 10 sampling attempts.
        """
        if random.random() > self.probability:
            return
        count = self._sample_count()
        device = self._resolve_device(img)
        for _ in range(count):
            box = self._sample_box(img_h, img_w, count, attempts=10)
            if box is None:
                continue  # no fitting box found; skip this one
            top, left, h, w = box
            img[:, top : top + h, left : left + w] = _get_pixels(
                self.per_pixel,
                self.rand_color,
                (chan, h, w),
                dtype=dtype,
                device=device,
            )

    def _erase_cube(
        self,
        img,
        batch_start,
        batch_size,
        chan,
        img_h,
        img_w,
        dtype,
    ):
        """Erase the same boxes in slices ``batch_start..batch_size-1`` of ``img``.

        A single probability draw and a single set of box locations are
        shared by all slices. Fill values come from a separate
        ``_get_pixels`` call per slice. Each box gets up to 100 sampling
        attempts.
        """
        if random.random() > self.probability:
            return
        count = self._sample_count()
        device = self._resolve_device(img)
        for _ in range(count):
            box = self._sample_box(img_h, img_w, count, attempts=100)
            if box is None:
                continue  # no fitting box found; skip this one
            top, left, h, w = box
            for i in range(batch_start, batch_size):
                # img[i] is a view, so writing into it modifies img in place.
                img[i][:, top : top + h, left : left + w] = _get_pixels(
                    self.per_pixel,
                    self.rand_color,
                    (chan, h, w),
                    dtype=dtype,
                    device=device,
                )

    def __call__(self, input):
        """Apply random erasing to ``input`` in place and return it.

        ``input`` is either a 3-D ``(C, H, W)`` tensor or a 4-D
        ``(N, C, H, W)`` tensor; any other rank fails at unpacking.
        """
        if len(input.size()) == 3:
            self._erase(input, *input.size(), input.dtype)
        else:
            batch_size, chan, img_h, img_w = input.size()
            # skip first slice of batch if num_splits is set (for clean portion of samples)
            batch_start = (
                batch_size // self.num_splits if self.num_splits > 1 else 0
            )
            if self.cube:
                self._erase_cube(
                    input,
                    batch_start,
                    batch_size,
                    chan,
                    img_h,
                    img_w,
                    input.dtype,
                )
            else:
                for i in range(batch_start, batch_size):
                    self._erase(input[i], chan, img_h, img_w, input.dtype)
        return input