Vincentqyw commited on
Commit
b864970
·
1 Parent(s): d88270c

update: aspanformer

Browse files
third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py CHANGED
@@ -4,6 +4,7 @@ import torch.nn as nn
4
  from itertools import product
5
  from torch.nn import functional as F
6
 
 
7
 
8
  class layernorm2d(nn.Module):
9
  def __init__(self, dim):
@@ -176,7 +177,7 @@ class HierachicalAttention(Module):
176
  offset_sample = self.sample_offset[None, None] * span_scale
177
  sample_pixel = offset[:, :, None] + offset_sample # B*G*r^2*2
178
  sample_norm = (
179
- sample_pixel / torch.tensor([wk / 2, hk / 2]).cuda()[None, None, None] - 1
180
  )
181
 
182
  q = (
 
4
  from itertools import product
5
  from torch.nn import functional as F
6
 
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
9
  class layernorm2d(nn.Module):
10
  def __init__(self, dim):
 
177
  offset_sample = self.sample_offset[None, None] * span_scale
178
  sample_pixel = offset[:, :, None] + offset_sample # B*G*r^2*2
179
  sample_norm = (
180
+ sample_pixel / torch.tensor([wk / 2, hk / 2]).to(device)[None, None, None] - 1
181
  )
182
 
183
  q = (
third_party/DeDoDe/DeDoDe/utils.py CHANGED
@@ -11,6 +11,7 @@ from einops import rearrange
11
  import torch
12
  from time import perf_counter
13
 
 
14
 
15
  def recover_pose(E, kpts0, kpts1, K0, K1, mask):
16
  best_num_inliers = 0
@@ -54,7 +55,7 @@ def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
54
  return ret
55
 
56
 
57
- def get_grid(B, H, W, device="cuda"):
58
  x1_n = torch.meshgrid(
59
  *[torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device) for n in (B, H, W)]
60
  )
@@ -63,7 +64,7 @@ def get_grid(B, H, W, device="cuda"):
63
 
64
 
65
  @torch.no_grad()
66
- def finite_diff_hessian(f: tuple(["B", "H", "W"]), device="cuda"):
67
  dxx = (
68
  torch.tensor([[0, 0, 0], [1, -2, 1], [0, 0, 0]], device=device)[None, None] / 2
69
  )
@@ -78,7 +79,7 @@ def finite_diff_hessian(f: tuple(["B", "H", "W"]), device="cuda"):
78
  return H
79
 
80
 
81
- def finite_diff_grad(f: tuple(["B", "H", "W"]), device="cuda"):
82
  dx = torch.tensor([[0, 0, 0], [-1, 0, 1], [0, 0, 0]], device=device)[None, None] / 2
83
  dy = dx.mT
84
  gx = F.conv2d(f[:, None], dx, padding=1)
@@ -103,7 +104,7 @@ def fast_inv_2x2(matrix: tuple[..., 2, 2], eps=1e-10):
103
  )
104
 
105
 
106
- def newton_step(f: tuple["B", "H", "W"], inds, device="cuda"):
107
  B, H, W = f.shape
108
  Hess = finite_diff_hessian(f).reshape(B, H * W, 2, 2)
109
  Hess = torch.gather(Hess, dim=1, index=inds[..., None].expand(B, -1, 2, 2))
@@ -118,7 +119,7 @@ def newton_step(f: tuple["B", "H", "W"], inds, device="cuda"):
118
  def sample_keypoints(
119
  scoremap,
120
  num_samples=8192,
121
- device="cuda",
122
  use_nms=True,
123
  sample_topk=False,
124
  return_scoremap=False,
@@ -176,7 +177,7 @@ def sample_keypoints(
176
 
177
 
178
  @torch.no_grad()
179
- def jacobi_determinant(warp, certainty, R=3, device="cuda", dtype=torch.float32):
180
  t = perf_counter()
181
  *dims, _ = warp.shape
182
  warp = warp.to(dtype)
@@ -831,7 +832,7 @@ def homog_transform(Homog, x):
831
  return y
832
 
833
 
834
- def get_homog_warp(Homog, H, W, device="cuda"):
835
  grid = torch.meshgrid(
836
  torch.linspace(-1 + 1 / H, 1 - 1 / H, H, device=device),
837
  torch.linspace(-1 + 1 / W, 1 - 1 / W, W, device=device),
 
11
  import torch
12
  from time import perf_counter
13
 
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
  def recover_pose(E, kpts0, kpts1, K0, K1, mask):
17
  best_num_inliers = 0
 
55
  return ret
56
 
57
 
58
+ def get_grid(B, H, W, device=device):
59
  x1_n = torch.meshgrid(
60
  *[torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device) for n in (B, H, W)]
61
  )
 
64
 
65
 
66
  @torch.no_grad()
67
+ def finite_diff_hessian(f: tuple(["B", "H", "W"]), device=device):
68
  dxx = (
69
  torch.tensor([[0, 0, 0], [1, -2, 1], [0, 0, 0]], device=device)[None, None] / 2
70
  )
 
79
  return H
80
 
81
 
82
+ def finite_diff_grad(f: tuple(["B", "H", "W"]), device=device):
83
  dx = torch.tensor([[0, 0, 0], [-1, 0, 1], [0, 0, 0]], device=device)[None, None] / 2
84
  dy = dx.mT
85
  gx = F.conv2d(f[:, None], dx, padding=1)
 
104
  )
105
 
106
 
107
+ def newton_step(f: tuple["B", "H", "W"], inds, device=device):
108
  B, H, W = f.shape
109
  Hess = finite_diff_hessian(f).reshape(B, H * W, 2, 2)
110
  Hess = torch.gather(Hess, dim=1, index=inds[..., None].expand(B, -1, 2, 2))
 
119
  def sample_keypoints(
120
  scoremap,
121
  num_samples=8192,
122
+ device=device,
123
  use_nms=True,
124
  sample_topk=False,
125
  return_scoremap=False,
 
177
 
178
 
179
  @torch.no_grad()
180
+ def jacobi_determinant(warp, certainty, R=3, device=device, dtype=torch.float32):
181
  t = perf_counter()
182
  *dims, _ = warp.shape
183
  warp = warp.to(dtype)
 
832
  return y
833
 
834
 
835
+ def get_homog_warp(Homog, H, W, device=device):
836
  grid = torch.meshgrid(
837
  torch.linspace(-1 + 1 / H, 1 - 1 / H, H, device=device),
838
  torch.linspace(-1 + 1 / W, 1 - 1 / W, W, device=device),