|
import torch |
|
import torch.nn.functional as F |
|
import random |
|
import kornia |
|
from torchvision.transforms.functional import adjust_brightness, adjust_contrast |
|
|
|
from climategan.tutils import normalize, retrieve_sky_mask |
|
|
|
try: |
|
from kornia.filters import filter2d |
|
except ImportError: |
|
from kornia.filters import filter2D as filter2d |
|
|
|
|
|
def increase_sky_mask(mask, p_w=0, p_h=0):
    """
    Increases sky mask in width and height by a given fraction of its size
    (Purpose: when applying Gaussian blur, there are no artifacts of blue sky
    behind)

    Along each axis the mask is summed with copies of itself shifted by
    1 .. n - 1 pixels in both directions, with n = int(p * axis_size); every
    value >= 1 is then set to 1. For a 0/1 mask this grows the sky region by
    n - 1 pixels on each side along that axis.

    Args:
        mask (torch.Tensor): Sky mask whose last two dimensions are
            (height, width), e.g. (H, W) or (N, C, H, W)
        p_w (float): Fraction of the mask width (0.18 means 18%) by which
            to increase the width of the sky region
        p_h (float): Fraction of the mask height by which to increase
            the height of the sky region
    Returns:
        torch.Tensor: Sky mask increased given p_w and p_h. When both p_w and
            p_h are <= 0 the input tensor itself is returned; otherwise a new
            tensor is returned and ``mask`` is left unmodified.
    """
    if p_h <= 0 and p_w <= 0:
        return mask

    n_lines = int(p_h * mask.shape[-2])
    n_cols = int(p_w * mask.shape[-1])

    # NOTE: sums are accumulated in the mask's own dtype before the final
    # saturation to 1, so a float mask is the safe choice for large growths.

    # Horizontal pass: spread every pixel to its left and right neighbours.
    widened = mask.clone().detach()
    for shift in range(1, n_cols):
        widened[..., shift:] += mask[..., :-shift]
        widened[..., :-shift] += mask[..., shift:]

    # Vertical pass, applied on top of the horizontally widened mask.
    grown = widened.clone().detach()
    for shift in range(1, n_lines):
        grown[..., shift:, :] += widened[..., :-shift, :]
        grown[..., :-shift, :] += widened[..., shift:, :]

    # Overlapping shifted copies add up beyond 1: saturate them.
    grown[grown >= 1] = 1

    return grown
|
|
|
|
|
def paste_filter(x, filter_, mask):
    """
    Pastes a filter over an image given a mask
    Where the mask is 1, the filter is copied as is.
    Where the mask is 0, the current value is preserved.
    Intermediate values will mix the two images together.

    Only the number of dimensions of the three inputs is checked; their
    value ranges are not.

    Args:
        x (torch.Tensor): Input tensor, range must be [0, 255]
        filter_ (torch.Tensor): Filter, range must be [0, 255]
        mask (torch.Tensor): Mask, range must be [0, 1]
    Returns:
        torch.Tensor: New tensor with filter pasted on it
    Raises:
        AssertionError: If x, filter_ and mask do not all have the same
            number of dimensions
    """
    assert len(x.shape) == len(filter_.shape) == len(mask.shape), (
        "x, filter_ and mask must have the same number of dimensions, got "
        f"{len(x.shape)}, {len(filter_.shape)} and {len(mask.shape)}"
    )
    # Per-element linear blend: the mask weights the filter, its complement
    # weights the original values.
    return filter_ * mask + x * (1 - mask)
|
|
|
|
|
def add_fire(x, seg_preds, fire_opts):
    """
    Transforms input tensor given wildfires event

    The image is colour-shifted (channel 0 raised, channels 1 and 2 lowered),
    its contrast is increased and its brightness lowered; a blurred sky mask
    is then used to blend a flat colour filter over the sky region.

    Args:
        x (torch.Tensor): Input tensor, indexed here as (N, C, H, W) with at
            least 3 channels (assumed to be in R, G, B order -- TODO confirm)
        seg_preds (torch.Tensor): Semantic segmentation predictions for input
            tensor, forwarded to ``retrieve_sky_mask``
        fire_opts (dict): Options, read with ``.get``:
            "crop_bottom_sky_mask": if truthy, the sky mask is zeroed from
                row ``2 * H // 3`` downwards
            "kernel_size" (int): side of the square Gaussian kernel that
                smooths the sky mask (default 301)
            "kernel_sigma" (float): standard deviation of that kernel along
                both axes (default 150.5)
    Returns:
        torch.Tensor: Wildfire version of input tensor, as floats. The
            top-left pixel is forced to 255 and the bottom-right pixel to 0
            in every channel.
    """
    # NOTE(review): the in-place edits below assume normalize() returns a
    # tensor that does not share memory with x -- confirm in tutils.
    wildfire_tens = normalize(x, 0, 255)

    # Shift colours: more of channel 0, less of channels 1 and 2.
    wildfire_tens[:, 2, :, :] -= 20
    wildfire_tens[:, 1, :, :] -= 10
    wildfire_tens[:, 0, :, :] += 40
    wildfire_tens.clamp_(0, 255)
    wildfire_tens = wildfire_tens.to(torch.uint8)

    # Increase contrast, then darken.
    wildfire_tens = adjust_contrast(wildfire_tens, contrast_factor=1.5)
    wildfire_tens = adjust_brightness(wildfire_tens, brightness_factor=0.73)

    sky_mask = retrieve_sky_mask(seg_preds).unsqueeze(1)

    if fire_opts.get("crop_bottom_sky_mask"):
        # Drop any "sky" detected in the bottom third of the mask.
        first_cropped_row = 2 * sky_mask.shape[-2] // 3
        sky_mask[..., first_cropped_row:, :] = 0

    # Bring the mask to the image resolution, then widen it so that the blur
    # below does not leave a rim of original sky around the foreground.
    sky_mask = F.interpolate(
        sky_mask.to(torch.float),
        (wildfire_tens.shape[-2], wildfire_tens.shape[-1]),
    )
    sky_mask = increase_sky_mask(sky_mask, 0.18, 0.18)

    kernel_side = fire_opts.get("kernel_size", 301)
    kernel_sigma = fire_opts.get("kernel_sigma", 150.5)
    kernel = kornia.filters.kernels.get_gaussian_kernel2d(
        (kernel_side, kernel_side), (kernel_sigma, kernel_sigma)
    )
    # filter2d expects a batched (B, kH, kW) kernel. Older kornia releases
    # return a plain (kH, kW) kernel while newer ones already include the
    # batch dimension, so only add it when it is missing.
    if kernel.dim() == 2:
        kernel = kernel.unsqueeze(0)
    kernel = kernel.to(x.device)
    border_type = "reflect"
    sky_mask = filter2d(sky_mask, kernel, border_type)

    # Flat colour filter: channel 0 at full intensity, a random channel 1
    # level drawn once per call, and no channel 2.
    filter_ = torch.ones(wildfire_tens.shape, device=x.device)
    filter_[:, 0, :, :] = 255
    filter_[:, 1, :, :] = random.randint(100, 150)
    filter_[:, 2, :, :] = 0

    wildfire_tens = paste_tensor(wildfire_tens, filter_, sky_mask, 200)

    # Final darkening; adjust_brightness is applied on uint8 data, then the
    # result is handed back as float.
    wildfire_tens = adjust_brightness(wildfire_tens.to(torch.uint8), 0.8)
    wildfire_tens = wildfire_tens.to(torch.float)

    # Pin one pixel to each end of the range in every channel (presumably so
    # a later min/max rescaling keeps the full [0, 255] span -- TODO confirm).
    wildfire_tens[:, :, 0, 0] = 255.0
    wildfire_tens[:, :, -1, -1] = 0.0

    return wildfire_tens
|
|
|
|
|
def paste_tensor(source, filter_, mask, transparency):
    """
    Blends a filter over a source image given a mask and a global strength

    The mask is scaled by ``transparency / 255`` and the result is used as
    the weight of ``filter_``; ``source`` gets the remaining weight.
    Despite its name, a larger ``transparency`` makes the filter MORE
    visible: 255 applies the mask as is, 0 leaves the source values.

    Args:
        source (torch.Tensor): Image to paste onto
        filter_ (torch.Tensor): Values to paste, broadcastable with source
        mask (torch.Tensor): Blending weights, expected in [0, 1] and
            broadcastable with source
        transparency (float): Strength of the filter on a 0-255 scale
    Returns:
        torch.Tensor: Blended result; the inputs are not modified
    """
    weight = transparency / 255.0 * mask
    return weight * filter_ + (1.0 - weight) * source
|
|