Spaces:
Paused
Paused
File size: 1,863 Bytes
d8431dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import torch
def convert_flow_to_deformation(flow):
    r"""Turn a flow field into an absolute sampling grid for warping.

    Channel 0 of ``flow`` is divided by ``w - 1`` and channels ``1:`` by
    ``h - 1``, and the result is doubled. This matches the spacing of the
    ``[-1, 1]`` coordinate grid from ``make_coordinate_grid``, and the
    returned deformation is that grid shifted by the scaled flow.
    NOTE(review): the scaling suggests the flow is expressed in pixels; confirm
    against the model that produces it.

    Args:
        flow (tensor): flow field obtained by the model, unpacked as
            ``(b, c, h, w)``; channel 0 is the horizontal and channel 1 the
            vertical component (assumes ``c == 2``).
    Returns:
        deformation (tensor): channels-last ``(b, h, w, 2)`` tensor of
            normalized sampling coordinates, ready for ``grid_sample``.
    """
    _, _, height, width = flow.shape
    # Rescale each displacement component into normalized-coordinate units.
    horizontal = flow[:, :1, ...] / (width - 1)
    vertical = flow[:, 1:, ...] / (height - 1)
    offsets = 2 * torch.cat((horizontal, vertical), dim=1)
    identity = make_coordinate_grid(flow)
    # grid_sample expects the coordinate pair in the last dimension.
    return identity + offsets.permute(0, 2, 3, 1)
def _normalized_axis(size, like):
    r"""Return ``size`` evenly spaced coordinates from -1 to 1, cast like ``like``.

    Element ``i`` is ``2 * (i / (size - 1)) - 1``, so -1 and +1 fall on the
    first and last sample. A length-1 axis has no extent, so it is placed at 0
    (the value ``torch.nn.functional.affine_grid`` uses for a size-1 axis)
    instead of the NaN that ``0 / 0`` would produce.
    """
    coords = torch.arange(size).to(like)
    if size == 1:
        # arange(1) is already [0]; skip the 0 / 0 division.
        return coords
    return 2 * (coords / (size - 1)) - 1


def make_coordinate_grid(flow):
    r"""Obtain an identity coordinate grid with the same spatial size as the flow field.

    Args:
        flow (tensor): flow field obtained by the model, unpacked as
            ``(b, c, h, w)``. Only its shape, dtype and device are used.
    Returns:
        grid (tensor): ``(b, h, w, 2)`` tensor with the dtype and device of
            ``flow``. ``grid[..., 0]`` is the x coordinate (varies along ``w``)
            and ``grid[..., 1]`` the y coordinate (varies along ``h``), both
            spanning ``[-1, 1]`` from the first to the last pixel. The batch
            dimension is an ``expand``-ed view, so all batch entries share
            storage and must not be written to in place.
    """
    b, _, h, w = flow.shape
    x = _normalized_axis(w, flow)
    y = _normalized_axis(h, flow)
    xx = x.view(1, -1).repeat(h, 1)
    yy = y.view(-1, 1).repeat(1, w)
    meshed = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2)
    return meshed.expand(b, -1, -1, -1)
def warp_image(source_image, deformation, *, align_corners=False):
    r"""Warp the input image according to the deformation.

    Args:
        source_image (tensor): images to be warped, unpacked as ``(b, c, h, w)``.
        deformation (tensor): channels-last sampling grid, unpacked as
            ``(b, h', w', 2)``; values are normalized coordinates in ``(-1, 1)``.
            If ``(h', w')`` differs from the image size, the field is bilinearly
            resized to ``(h, w)`` first.
        align_corners (bool): forwarded to ``grid_sample``. ``False`` (the
            default) reproduces the historical behaviour of this function,
            which relied on PyTorch's implicit default. A grid whose -1/+1
            values sit on the centres of the corner pixels only maps exactly
            onto pixel centres when this is ``True``.
    Returns:
        output (tensor): the warped images, ``(b, c, h, w)``.
    """
    _, h_old, w_old, _ = deformation.shape
    _, _, h, w = source_image.shape
    if (h_old, w_old) != (h, w):
        # The field holds normalized coordinates, so only its spatial size
        # changes; the values themselves need no rescaling. align_corners is
        # passed explicitly (False = the previous implicit default).
        deformation = deformation.permute(0, 3, 1, 2)
        deformation = torch.nn.functional.interpolate(
            deformation, size=(h, w), mode='bilinear', align_corners=False)
        deformation = deformation.permute(0, 2, 3, 1)
    # Passing align_corners explicitly silences grid_sample's per-call warning.
    return torch.nn.functional.grid_sample(
        source_image, deformation, align_corners=align_corners)
|