Spaces:
Runtime error
Runtime error
File size: 2,276 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import torch
from torch import Tensor
from torch.nn.functional import affine_grid, grid_sample
def apply_rgb_change(alpha: Tensor, color_change: Tensor, image: Tensor):
    """Blend ``color_change`` into channels 0-2 of ``image``, keeping channel 3.

    Channels 0-2 of the result are ``color_change * alpha + image * (1 - alpha)``.
    Channel 3 of ``image`` (if it exists) is appended unchanged. Any channels of
    ``image`` beyond index 3 are not carried into the result.

    Args:
        alpha: Blend weight; must broadcast against ``image[:, 0:3, :, :]``.
        color_change: 4-D tensor; only its channels 0-2 along dim 1 are used.
        image: 4-D tensor whose dim 1 is sliced as channels (presumably NCHW
            with RGB + alpha channels -- confirm against callers).

    Returns:
        Tensor with the three blended channels followed by ``image``'s
        channel 3, concatenated along dim 1.
    """
    # Weighted mix of the first three channels only.
    blended = (color_change[:, :3, :, :] * alpha
               + image[:, :3, :, :] * (1 - alpha))
    # Empty along dim 1 when ``image`` has only three channels.
    passthrough = image[:, 3:4, :, :]
    return torch.cat((blended, passthrough), dim=1)
def apply_grid_change(grid_change: Tensor, image: Tensor, align_corners: bool = False) -> Tensor:
    """Warp ``image`` by adding ``grid_change`` to an identity sampling grid.

    Args:
        grid_change: Per-pixel offsets in ``grid_sample``'s normalized
            coordinate space. Must hold ``n * 2 * h * w`` elements laid out as
            ``(n, 2, h, w)``: channel 0 is the x offset, channel 1 the y offset.
            May be non-contiguous. Its dtype should match ``image``'s, because
            the sampling grid inherits it and ``grid_sample`` requires equal
            input/grid dtypes.
        image: 4-D input of shape ``(n, c, h, w)``.
        align_corners: Forwarded to both ``affine_grid`` and ``grid_sample``.
            Defaults to ``False``, the value previously hard-coded here.

    Returns:
        ``image`` bilinearly resampled at the offset grid; sample points that
        fall outside the image are clamped to the border.
    """
    n, c, h, w = image.shape
    # (n, 2, h, w) -> (n, h, w, 2): grid_sample expects the (x, y) pair last.
    # reshape (not view) so non-contiguous offset tensors are accepted; it
    # returns the same view whenever view() would have succeeded.
    grid_change = grid_change.reshape(n, 2, h * w).transpose(1, 2).reshape(n, h, w, 2)
    # Batched identity affine matrix -> base grid of pixel-centre coordinates.
    identity = torch.tensor(
        [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
        dtype=grid_change.dtype,
        device=grid_change.device,
    ).unsqueeze(0).repeat(n, 1, 1)
    base_grid = affine_grid(identity, [n, c, h, w], align_corners=align_corners)
    grid = base_grid + grid_change
    return grid_sample(image, grid, mode='bilinear', padding_mode='border',
                       align_corners=align_corners)
class GridChangeApplier:
    """Warp images by grid offsets, caching the batched identity affine matrix.

    The identity ``theta`` handed to ``affine_grid`` depends only on the batch
    size, device and dtype, so it is rebuilt only when one of those changes
    between consecutive ``apply`` calls.
    """

    def __init__(self):
        # Cache key (batch size, device, dtype) of the most recently built
        # identity matrix, plus the matrix itself; all None until first use.
        self.last_n = None
        self.last_device = None
        self.last_dtype = None
        self.last_identity = None

    def apply(self, grid_change: Tensor, image: Tensor, align_corners: bool = False) -> Tensor:
        """Resample ``image`` at an identity grid shifted by ``grid_change``.

        Args:
            grid_change: Offsets in ``grid_sample``'s normalized coordinates,
                holding ``n * 2 * h * w`` elements laid out as ``(n, 2, h, w)``
                (channel 0 = x offset, channel 1 = y offset). May be
                non-contiguous. Its dtype should match ``image``'s, since the
                sampling grid inherits it and ``grid_sample`` requires equal
                input/grid dtypes.
            image: 4-D input of shape ``(n, c, h, w)``.
            align_corners: Forwarded to both ``affine_grid`` and ``grid_sample``.

        Returns:
            ``image`` bilinearly resampled at the offset grid; sample points
            outside the image are clamped to the border.
        """
        n, c, h, w = image.shape
        device = grid_change.device
        dtype = grid_change.dtype
        # (n, 2, h, w) -> (n, h, w, 2); reshape tolerates non-contiguous input.
        grid_change = grid_change.reshape(n, 2, h * w).transpose(1, 2).reshape(n, h, w, 2)
        # dtype is part of the cache key: reusing e.g. a float64 identity for
        # float32 offsets would promote the grid to float64, and grid_sample
        # rejects a float32 image paired with a float64 grid.
        if n == self.last_n and device == self.last_device and dtype == self.last_dtype:
            identity = self.last_identity
        else:
            identity = torch.tensor(
                [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
                dtype=dtype,
                device=device,
            ).unsqueeze(0).repeat(n, 1, 1)
            self.last_identity = identity
            self.last_n = n
            self.last_device = device
            self.last_dtype = dtype
        base_grid = affine_grid(identity, [n, c, h, w], align_corners=align_corners)
        grid = base_grid + grid_change
        return grid_sample(image, grid, mode='bilinear', padding_mode='border',
                           align_corners=align_corners)
def apply_color_change(alpha, color_change, image: Tensor) -> Tensor:
    """Linearly blend ``color_change`` over ``image`` on every element.

    Computes ``color_change * alpha + image * (1 - alpha)`` with ordinary
    tensor broadcasting; no channel is treated specially.

    Args:
        alpha: Blend weight (tensor or scalar) broadcastable to the operands;
            1 selects ``color_change``, 0 selects ``image``.
        color_change: Values to blend in.
        image: Base values.

    Returns:
        The element-wise blended tensor.
    """
    image_weight = 1 - alpha
    from_change = color_change * alpha
    from_image = image * image_weight
    return from_change + from_image
|