Realcat
add: GIM (https://github.com/xuelunshen/gim)
c0283b3
raw
history blame
1.74 kB
from typing import Tuple

import torch
def extract_patches(
    tensor: torch.Tensor,
    required_corners: torch.Tensor,
    ps: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Crop square ``ps`` x ``ps`` patches out of a single image.

    Args:
        tensor: Image of shape ``(c, h, w)``.
        required_corners: ``(n, 2)`` requested top-left corners, one
            ``(x, y)`` pair per row. Values are converted to int64 and
            clamped to ``[0, w - 1 - ps]`` (x) and ``[0, h - 1 - ps]`` (y).
            The caller's tensor is never modified.
        ps: Side length of each patch, in pixels.

    Returns:
        A tuple ``(patches, corners)``:

        * ``patches`` has shape ``(n, c, ps, ps)`` and satisfies
          ``patches[k, :, a, b] == tensor[:, y_k + a, x_k + b]``.
        * ``corners`` has shape ``(n, 2)``, dtype float32, and holds the
          clamped ``(x, y)`` corners that were actually used.

    Note:
        If ``ps > min(h, w) - 1`` the upper clamp bound is negative, so the
        resulting indices wrap around or fall out of range. Callers are
        expected to pass a patch size that fits inside the image.
    """
    _, h, w = tensor.shape
    corner = required_corners.long()
    # Clamp out of place: `.long()` hands back the very same tensor when the
    # input is already int64, so in-place writes would corrupt caller data.
    left = corner[:, 0].clamp(min=0, max=w - 1 - ps)
    top = corner[:, 1].clamp(min=0, max=h - 1 - ps)
    corner = torch.stack((left, top), dim=-1)

    # Broadcast per-patch row/column indices instead of building a meshgrid;
    # this avoids the version-dependent `indexing=` keyword and also works
    # when there are zero corners.
    offset = torch.arange(ps, device=corner.device)
    rows = top[:, None, None] + offset[None, :, None]  # (n, ps, 1)
    cols = left[:, None, None] + offset[None, None, :]  # (n, 1, ps)
    sampled = tensor[:, rows, cols]  # (c, n, ps, ps)
    return sampled.permute(1, 0, 2, 3), corner.float()
def batch_extract_patches(tensor: torch.Tensor, kpts: torch.Tensor, ps: int):
    """Extract ``ps`` x ``ps`` patches around keypoints for a batch of images.

    Each image in the batch is processed independently with
    ``extract_patches``; the requested corner for a keypoint is
    ``kpt - ps / 2 - 1``.

    Args:
        tensor: Batch of images, shape ``(b, c, h, w)``.
        kpts: Keypoints, shape ``(b, n, 2)``.
        ps: Patch side length in pixels.

    Returns:
        ``(patches, corners)`` where ``patches`` has shape
        ``(b, n, c, ps, ps)`` and ``corners`` has shape ``(b, n, 2)``.
        Both are allocated with ``tensor``'s dtype and device.
    """
    _, channels, _, _ = tensor.shape
    batch_size, num_kpts, _ = kpts.shape

    # NOTE(review): corners are stored in the image dtype, so an integer or
    # low-precision image tensor will also quantise the corner coordinates.
    like = {"dtype": tensor.dtype, "device": tensor.device}
    patches = torch.zeros((batch_size, num_kpts, channels, ps, ps), **like)
    used_corners = torch.zeros((batch_size, num_kpts, 2), **like)

    for idx in range(batch_size):
        requested = kpts[idx] - ps / 2 - 1
        image_patches, image_corners = extract_patches(tensor[idx], requested, ps)
        patches[idx] = image_patches
        used_corners[idx] = image_corners
    return patches, used_corners
def draw_image_patches(img, patches, corners):
    """Paste square patches into ``img`` in place.

    Args:
        img: 4-D tensor ``(b, c, h, w)`` that is written to.
        patches: 5-D tensor ``(b, n, c, p, p)`` of patches to paste.
        corners: Integer tensor ``(b, n, 2)``. For each patch,
            ``corners[..., 0]`` is the start index along the last image
            axis (column) and ``corners[..., 1]`` the start index along the
            second-to-last axis (row).

    Returns:
        None. ``img`` is modified; later patches overwrite earlier ones
        where they overlap.
    """
    batch, _, _, _ = img.shape
    batch, count, _, _, side = patches.shape
    batch, count, _ = corners.shape

    for image_idx in range(batch):
        for patch_idx in range(count):
            col, row = corners[image_idx, patch_idx]
            row_span = slice(row, row + side)
            col_span = slice(col, col + side)
            img[image_idx, :, row_span, col_span] = patches[image_idx, patch_idx]
def build_heatmap(img, patches, corners):
    """Render patches onto a blank canvas shaped like ``img``.

    Args:
        img: Reference tensor; only its shape, dtype and device are used.
        patches: Patches to paste, as accepted by ``draw_image_patches``.
        corners: Patch corners; converted to int64 before drawing.

    Returns:
        ``(heatmap, mask)``. ``heatmap`` is the canvas with dimension 1
        squeezed out (only removed when that dimension has size 1, giving
        b x h x w). ``mask`` is a float tensor that is 1 where the heatmap
        is strictly positive and 0 elsewhere.
    """
    canvas = torch.zeros_like(img)
    integer_corners = corners.long()
    draw_image_patches(canvas, patches, integer_corners)

    heat = canvas.squeeze(1)
    covered = (heat > 0.0).float()
    return heat, covered