Spaces:
Running
Running
import torch | |
def extract_patches(
    tensor: torch.Tensor,
    required_corners: torch.Tensor,
    ps: int,
) -> "tuple[torch.Tensor, torch.Tensor]":
    """Crop square ``ps x ps`` patches out of a single ``(C, H, W)`` tensor.

    Args:
        tensor: Source tensor of shape ``(C, H, W)``.
        required_corners: ``(N, 2)`` tensor of requested top-left corners in
            ``(x, y)`` order (column, row). Values are cast to int64 and
            clamped so that every patch lies fully inside ``tensor``. The
            caller's tensor is never modified.
        ps: Patch side length in pixels.

    Returns:
        ``(patches, corners)`` where ``patches`` has shape ``(N, C, ps, ps)``
        with ``patches[n, :, i, j] == tensor[:, y_n + i, x_n + j]``, and
        ``corners`` is an ``(N, 2)`` float tensor holding the clamped
        ``(x, y)`` corners that were actually used.

    Raises:
        ValueError: if ``ps`` is not positive or is larger than ``H`` or ``W``.
    """
    _, h, w = tensor.shape
    if ps < 1 or ps > h or ps > w:
        # Otherwise the clamp bound below would go negative and produce
        # wrapped (negative) indices, i.e. silently wrong patches.
        raise ValueError(
            f"patch size {ps} must be in [1, min(H, W)] for a tensor of "
            f"shape {tuple(tensor.shape)}"
        )
    # clone(): .long() returns the very same tensor when it is already int64,
    # so the in-place clamps below would otherwise overwrite caller data.
    corner = required_corners.long().clone()
    # The largest valid top-left coordinate is size - ps: the patch then
    # ends exactly on the last row/column.
    corner[:, 0] = corner[:, 0].clamp(min=0, max=w - ps)
    corner[:, 1] = corner[:, 1].clamp(min=0, max=h - ps)

    n = corner.shape[0]
    offset = torch.arange(ps, device=corner.device)
    # Row/column index grids that broadcast to (ps, ps, N): the patch row
    # offset varies along dim 0, the column offset along dim 1.
    rows = offset.view(ps, 1, 1) + corner[:, 1].view(1, 1, n)
    cols = offset.view(1, ps, 1) + corner[:, 0].view(1, 1, n)
    # Adjacent advanced indices give (C, ps, ps, N); move N to the front.
    # This also works for N == 0, unlike a reshape with an inferred dim.
    sampled = tensor[:, rows, cols].permute(3, 0, 1, 2)
    return sampled, corner.float()
def batch_extract_patches(tensor: torch.Tensor, kpts: torch.Tensor, ps: int):
    """Extract one ``ps x ps`` patch per keypoint for every item in a batch.

    Args:
        tensor: ``(B, C, H, W)`` batch of images / feature maps.
        kpts: ``(B, N, 2)`` keypoints in ``(x, y)`` order. The requested
            top-left corner of each patch is ``kpt - ps / 2 - 1``; it is then
            clamped inside the image by ``extract_patches``.
        ps: Patch side length in pixels.

    Returns:
        ``(patches, corners)`` of shapes ``(B, N, C, ps, ps)`` and
        ``(B, N, 2)``; ``corners`` holds the clamped ``(x, y)`` corners.
        ``corners`` uses ``tensor``'s dtype when it is floating point and
        float32 otherwise.

    Raises:
        ValueError: if ``tensor`` and ``kpts`` have different batch sizes.
    """
    b, c, _, _ = tensor.shape
    kpts_b, n, _ = kpts.shape
    if kpts_b != b:
        raise ValueError(
            f"batch size mismatch: tensor has {b} items, kpts has {kpts_b}"
        )
    out = torch.zeros((b, n, c, ps, ps), dtype=tensor.dtype, device=tensor.device)
    # Corners are pixel coordinates: storing them in an integer image dtype
    # (e.g. uint8) would wrap values >= 256, so use float32 in that case.
    corner_dtype = tensor.dtype if tensor.is_floating_point() else torch.float32
    corners = torch.zeros((b, n, 2), dtype=corner_dtype, device=tensor.device)
    for i in range(b):
        # NOTE(review): the extra "- 1" places the patch one pixel up/left of
        # a kpt-centred crop; kept as-is, confirm it is intentional.
        out[i], corners[i] = extract_patches(tensor[i], kpts[i] - ps / 2 - 1, ps)
    return out, corners
def draw_image_patches(img, patches, corners):
    """Paste every patch into ``img`` in place (returns ``None``).

    ``patches[i, k]`` is written so that its top-left pixel lands at
    ``corners[i, k]``, whose first component is the column (last axis of
    ``img``) and second the row. Corner values are used directly as slice
    bounds, so they must be integer-valued (e.g. an int64 tensor), and each
    patch is assumed to fit inside ``img`` — TODO confirm at call sites.
    """
    # The unpackings below also enforce 4-D img, 5-D patches, 3-D corners.
    _, _, _, _ = img.shape
    _, _, _, _, size = patches.shape
    num_batches, num_patches, _ = corners.shape
    for bi in range(num_batches):
        for ki in range(num_patches):
            col, row = corners[bi, ki]
            img[bi, :, row : row + size, col : col + size] = patches[bi, ki]
def build_heatmap(img, patches, corners):
    """Render ``patches`` onto a zero canvas shaped like ``img`` and binarise it.

    Returns:
        ``(heatmap, mask)``: the drawn canvas with dim 1 squeezed (so
        ``(B, H, W)`` when ``img`` has a single channel) and a float mask that
        is 1.0 wherever the canvas is strictly positive, 0.0 elsewhere.
    """
    canvas = torch.zeros_like(img)
    # Corners are cast to int64 because draw_image_patches uses them as
    # slice bounds.
    draw_image_patches(canvas, patches, corners.long())
    heatmap = canvas.squeeze(1)
    mask = heatmap.gt(0.0).float()
    return heatmap, mask