Seokju Cho committed 6b9382c (parent: e11cc45): improve speed

Files changed:
- .gitattributes +2 -0
- app.py +6 -2
- locotrack_pytorch/models/cmdtop.py +4 -2
- locotrack_pytorch/models/locotrack_model.py +46 -81
- locotrack_pytorch/models/utils.py +1 -44
- requirements.txt +1 -1
- weights/locotrack_base.ckpt +3 -0
- weights/locotrack_small.ckpt +3 -0
.gitattributes
CHANGED
@@ -42,3 +42,5 @@ examples/libby.mp4 filter=lfs diff=lfs merge=lfs -text
 examples/motocross-jump.mp4 filter=lfs diff=lfs merge=lfs -text
 examples/bmx-trees.mp4 filter=lfs diff=lfs merge=lfs -text
 examples/parkour.mp4 filter=lfs diff=lfs merge=lfs -text
+weights/locotrack_base.ckpt filter=lfs diff=lfs merge=lfs -text
+weights/locotrack_small.ckpt filter=lfs diff=lfs merge=lfs -text
app.py
CHANGED
@@ -19,6 +19,10 @@ PREVIEW_WIDTH = 768 # Width of the preview video
 VIDEO_INPUT_RESO = (256, 256) # Resolution of the input video
 POINT_SIZE = 4 # Size of the query point in the preview video
 FRAME_LIMIT = 300 # Limit the number of frames to process
+WEIGHTS_PATH = {
+    "small": "./weights/locotrack_small.ckpt",
+    "base": "./weights/locotrack_base.ckpt",
+}
 
 
 def get_point(frame_num, video_queried_preview, query_points, query_points_color, query_count, evt: gr.SelectData):
@@ -120,7 +124,7 @@ def extract_feature(video_input, model_size="small"):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     dtype = torch.bfloat16 if device == "cuda" else torch.float16
 
-    model = load_model(model_size=model_size).to(device)
+    model = load_model(WEIGHTS_PATH[model_size], model_size=model_size).to(device)
 
     video_input = (video_input / 255.0) * 2 - 1
     video_input = torch.tensor(video_input).unsqueeze(0).to(device, dtype)
@@ -223,7 +227,7 @@ def track(
     video_input = (video_input / 255.0) * 2 - 1
     video_input = torch.tensor(video_input).unsqueeze(0).to(device, dtype)
 
-    model = load_model(model_size=model_size).to(device)
+    model = load_model(WEIGHTS_PATH[model_size], model_size=model_size).to(device)
     with torch.autocast(device_type=device, dtype=dtype):
         with torch.no_grad():
             output = model(video_input, query_points_tensor, feature_grids=video_feature)
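Note on the app.py change: checkpoints are now resolved from the bundled weights/ directory through the new WEIGHTS_PATH lookup and passed explicitly to load_model. A minimal sketch of that lookup together with the device/dtype policy app.py uses around it; resolve_runtime is an illustrative helper, not a function in app.py.

import torch

WEIGHTS_PATH = {
    "small": "./weights/locotrack_small.ckpt",
    "base": "./weights/locotrack_base.ckpt",
}

def resolve_runtime(model_size: str = "small"):
    # Pick the local LFS checkpoint for the requested size and the same
    # device/dtype choice the app makes (bfloat16 on CUDA, float16 otherwise).
    ckpt_path = WEIGHTS_PATH[model_size]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float16
    return ckpt_path, device, dtype

ckpt_path, device, dtype = resolve_runtime("small")
# app.py then hands the resolved path to its own loader:
#   model = load_model(ckpt_path, model_size=model_size).to(device)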
locotrack_pytorch/models/cmdtop.py
CHANGED
@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from einops import rearrange
+
 from models import utils
 
 
@@ -29,8 +31,8 @@ class CMDTop(nn.Module):
     """
     x: (b, h, w, i, j)
    """
-    out1 = …
-    out2 = …
+    out1 = rearrange(x, 'b h w i j -> b (i j) h w')
+    out2 = rearrange(x, 'b h w i j -> b (h w) i j')
 
    for i in range(len(self.out_channels)):
      out1 = self.conv[i](out1)
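The two rearrange calls above are pure permute/reshape views of the correlation volume. A quick self-contained check against explicit PyTorch ops; the sizes below are arbitrary dummies, not the model's actual window sizes.

import torch
from einops import rearrange

b, h, w, i, j = 2, 8, 8, 7, 7
x = torch.randn(b, h, w, i, j)

out1 = rearrange(x, 'b h w i j -> b (i j) h w')
out2 = rearrange(x, 'b h w i j -> b (h w) i j')

# Equivalent explicit operations.
ref1 = x.permute(0, 3, 4, 1, 2).reshape(b, i * j, h, w)   # group (i j), move to channels
ref2 = x.reshape(b, h * w, i, j)                          # group (h w), already adjacent

assert torch.equal(out1, ref1) and torch.equal(out2, ref2)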
locotrack_pytorch/models/locotrack_model.py
CHANGED
@@ -22,6 +22,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 import numpy as np
+from einops import rearrange
 
 from models import nets, utils
 from models.cmdtop import CMDTop
@@ -57,15 +58,15 @@ def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
   return torch.cat([x] + [four_feat], dim=-1)
 
 
-def get_relative_positions(seq_len, reverse=False):
-  x = torch.arange(seq_len)[None, :]
-  y = torch.arange(seq_len)[:, None]
+def get_relative_positions(seq_len, reverse=False, device='cuda'):
+  x = torch.arange(seq_len, device=device)[None, :]
+  y = torch.arange(seq_len, device=device)[:, None]
   return torch.tril(x - y) if not reverse else torch.triu(y - x)
 
 
-def get_alibi_slope(num_heads):
+def get_alibi_slope(num_heads, device='cuda'):
   x = (24) ** (1 / num_heads)
-  return torch.tensor([1 / x ** (i + 1) for i in range(num_heads)], dtype=torch.float32).view(-1, 1, 1)
+  return torch.tensor([1 / x ** (i + 1) for i in range(num_heads)], device=device, dtype=torch.float32).view(-1, 1, 1)
 
 
 class MultiHeadAttention(nn.Module):
@@ -92,31 +93,22 @@ class MultiHeadAttention(nn.Module):
     key_heads = self._linear_projection(key, self.key_size, self.key_proj)  # [T, H, K]
     value_heads = self._linear_projection(value, self.value_size, self.value_proj)  # [T, H, V]
 
-    bias_forward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(sequence_length)
+    device = query.device
+    bias_forward = get_alibi_slope(self.num_heads // 2, device=device) * get_relative_positions(sequence_length, device=device)
     bias_forward = bias_forward + torch.triu(torch.full_like(bias_forward, -1e9), diagonal=1)
-    bias_backward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(sequence_length, reverse=True)
+    bias_backward = get_alibi_slope(self.num_heads // 2, device=device) * get_relative_positions(sequence_length, reverse=True, device=device)
     bias_backward = bias_backward + torch.tril(torch.full_like(bias_backward, -1e9), diagonal=-1)
-    attn_bias = torch.cat([bias_forward, bias_backward], dim=0)
+    attn_bias = torch.cat([bias_forward, bias_backward], dim=0)
 
-    …
-    if mask is not None:
-      if mask.ndim != attn_logits.ndim:
-        raise ValueError(f"Mask dimensionality {mask.ndim} must match logits dimensionality {attn_logits.ndim}.")
-      attn_logits = torch.where(mask, attn_logits, torch.tensor(-1e30))
-
-    attn_weights = F.softmax(attn_logits, dim=-1)  # [H, T', T]
-
-    attn = torch.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
-    attn = attn.reshape(batch_size, sequence_length, -1)  # [T', H*V]
+    attn = F.scaled_dot_product_attention(query_heads, key_heads, value_heads, attn_mask=attn_bias, scale=1 / np.sqrt(self.key_size))
+    attn = attn.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, -1)
 
     return self.final_proj(attn)  # [T', D']
 
   def _linear_projection(self, x, head_size, proj_layer):
     y = proj_layer(x)
-
-    return y.reshape((…
+    batch_size, sequence_length, _ = x.shape
+    return y.reshape((batch_size, sequence_length, self.num_heads, head_size)).permute(0, 2, 1, 3)
 
 
 class Transformer(nn.Module):
@@ -495,25 +487,25 @@ class LocoTrack(nn.Module):
     ctx = torch.reshape(ctx, [-1, 3]).to(video.device)  # s*s 3
 
     position_support = position_in_grid[..., None, :] + ctx[None, None, ...]  # b n s*s 3
-    position_support = …
+    position_support = rearrange(position_support, 'b n s c -> b (n s) c')
     interp_supp = utils.map_coordinates_3d(
        feature_grid[i], position_support
    )
-    interp_supp = …
+    interp_supp = rearrange(interp_supp, 'b (n h w) c -> b n h w c', h=support_size, w=support_size)
 
    position_support_hires = position_in_grid_hires[..., None, :] + ctx[None, None, ...]
-    position_support_hires = …
+    position_support_hires = rearrange(position_support_hires, 'b n s c -> b (n s) c')
    hires_interp_supp = utils.map_coordinates_3d(
        hires_feats[i], position_support_hires
    )
-    hires_interp_supp = …
+    hires_interp_supp = rearrange(hires_interp_supp, 'b (n h w) c -> b n h w c', h=support_size, w=support_size)
 
    position_support_highest = position_in_grid_highest[..., None, :] + ctx[None, None, ...]
-    position_support_highest = …
+    position_support_highest = rearrange(position_support_highest, 'b n s c -> b (n s) c')
    highest_interp_supp = utils.map_coordinates_3d(
        highest_feats[i], position_support_highest
    )
-    highest_interp_supp = …
+    highest_interp_supp = rearrange(highest_interp_supp, 'b (n h w) c -> b n h w c', h=support_size, w=support_size)
 
    interp_features = interp_supp[..., support_size // 2, support_size // 2, :]
    hires_interp = hires_interp_supp[..., support_size // 2, support_size // 2, :]
@@ -559,7 +551,7 @@ class LocoTrack(nn.Module):
        video.shape[2:4], self.initial_resolution
    )
 
-    all_required_resolutions = […
+    all_required_resolutions = []
    all_required_resolutions.extend(refinement_resolutions)
 
    feature_grid = []
@@ -715,30 +707,14 @@ class LocoTrack(nn.Module):
    )
 
    num_queries = query_features.lowres[0].shape[1]
-    if causal_context is None:
-      perm = torch.randperm(num_queries)
-    else:
-      perm = torch.arange(num_queries)
-
-    inv_perm = torch.zeros_like(perm)
-    inv_perm[perm] = torch.arange(num_queries)
 
    for ch in range(0, num_queries, query_chunk_size):
-      …
-      chunk_hires = query_features.hires[0][:, perm_chunk]
-
-      cc_chunk = []
-      if causal_context is not None:
-        for d in range(len(causal_context)):
-          tmp_dict = {}
-          for k, v in causal_context[d].items():
-            tmp_dict[k] = v[:, perm_chunk]
-          cc_chunk.append(tmp_dict)
+      chunk = query_features.lowres[0][:, ch:ch + query_chunk_size]
+      chunk_hires = query_features.hires[0][:, ch:ch + query_chunk_size]
 
      if query_points_in_video is not None:
        infer_query_points = query_points_in_video[
-            :, …
+            :, ch : ch + query_chunk_size
        ]
        num_frames = feature_grids.lowres[0].shape[1]
        infer_query_points = utils.convert_grid_coordinates(
@@ -765,14 +741,14 @@ class LocoTrack(nn.Module):
      for i in range(num_iters):
        feature_level = -1
        queries = [
-            query_features.hires[feature_level][:, …
-            query_features.lowres[feature_level][:, …
-            query_features.highest[feature_level][:, …
+            query_features.hires[feature_level][:, ch:ch + query_chunk_size],
+            query_features.lowres[feature_level][:, ch:ch + query_chunk_size],
+            query_features.highest[feature_level][:, ch:ch + query_chunk_size],
        ]
        supports = [
-            query_features.hires_supp[feature_level][:, …
-            query_features.lowres_supp[feature_level][:, …
-            query_features.highest_supp[feature_level][:, …
+            query_features.hires_supp[feature_level][:, ch:ch + query_chunk_size],
+            query_features.lowres_supp[feature_level][:, ch:ch + query_chunk_size],
+            query_features.highest_supp[feature_level][:, ch:ch + query_chunk_size],
        ]
        for _ in range(self.pyramid_level):
          queries.append(queries[-1])
@@ -790,7 +766,7 @@ class LocoTrack(nn.Module):
              padding=0,
          )
        )
-        …
+
        refined = self.refine_pips(
            queries,
            supports,
@@ -803,7 +779,6 @@ class LocoTrack(nn.Module):
            last_iter=mixer_feats,
            mixer_iter=i,
            resize_hw=feature_grids.resolutions[feature_level],
-            causal_context=cc,
            get_causal_context=get_causal_context,
            cost_volume=cost_volume
        )
@@ -822,9 +797,9 @@ class LocoTrack(nn.Module):
    points = []
    expd = []
    for i, _ in enumerate(occ_iters):
-      occlusion.append(torch.cat(occ_iters[i], dim=1)…
-      points.append(torch.cat(pts_iters[i], dim=1)…
-      expd.append(torch.cat(expd_iters[i], dim=1)…
+      occlusion.append(torch.cat(occ_iters[i], dim=1))
+      points.append(torch.cat(pts_iters[i], dim=1))
+      expd.append(torch.cat(expd_iters[i], dim=1))
 
    out = dict(
        occlusion=occlusion,
@@ -874,11 +849,11 @@ class LocoTrack(nn.Module):
      coords2 = coords.unsqueeze(3) + ctx.unsqueeze(0).unsqueeze(0).unsqueeze(0)
      neighborhood = utils.map_coordinates_2d(grid, coords2)
 
-      neighborhood = …
+      neighborhood = rearrange(neighborhood, 'b n t (h w) c -> b n t h w c', h=support_size, w=support_size)
      patches_input = torch.einsum('bnthwc,bnijc->bnthwij', neighborhood, supp)
-      patches_input = …
+      patches_input = rearrange(patches_input, 'b n t h w i j -> (b n t) h w i j')
      patches_emb = self.cmdtop[pyridx](patches_input)
-      patches = …
+      patches = rearrange(patches_emb, '(b n t) c -> b n t c', b=neighborhood.shape[0], n=neighborhood.shape[1])
 
      corrs_pyr.append(patches)
    corrs_pyr = torch.concatenate(corrs_pyr, dim=-1)
@@ -913,14 +888,10 @@ class LocoTrack(nn.Module):
      mlp_input_list.append(rel_pos_emb_input)
    mlp_input = torch.cat(mlp_input_list, axis=-1)
 
-    x = …
-
-    if causal_context is not None:
-      for k, v in causal_context.items():
-        causal_context[k] = utils.einshape('bn...->(bn)...', v)
+    x = rearrange(mlp_input, 'b n f c -> (b n) f c')
    res = self.torch_pips_mixer(x)
 
-    res = …
+    res = rearrange(res, '(b n) f c -> b n f c', b=mlp_input.shape[0])
 
    pos_update = utils.convert_grid_coordinates(
        res[..., :2],
@@ -983,20 +954,18 @@ class LocoTrack(nn.Module):
    shape = cost_volume.shape
    batch_size, num_points = cost_volume.shape[1:3]
 
-    interp_cost = …
+    interp_cost = rearrange(cost_volume, 't b n h w -> (t b n) () h w')
    interp_cost = F.interpolate(interp_cost, cost_volume_hires.shape[3:], mode='bilinear', align_corners=False)
-
-    interp_cost = utils.einshape('(tbn)1hw->tbnhw', interp_cost, b=batch_size, n=num_points)
+    interp_cost = rearrange(interp_cost, '(t b n) () h w -> t b n h w', b=batch_size, n=num_points)
    cost_volume_stack = torch.stack(
        [
-            # jax.image.resize(cost_volume, cost_volume_hires.shape, method='bilinear'),
            interp_cost,
            cost_volume_hires,
        ], dim=-1
    )
-    pos = …
+    pos = rearrange(cost_volume_stack, 't b n h w c -> (t b n) c h w')
    pos = self.cost_conv(pos)
-    pos = …
+    pos = rearrange(pos, '(t b n) () h w -> b n t h w', b=batch_size, n=num_points)
 
    pos_sm = pos.reshape(pos.size(0), pos.size(1), pos.size(2), -1)
    softmaxed = F.softmax(pos_sm * self.softmax_temperature, dim=-1)
@@ -1012,14 +981,10 @@ class LocoTrack(nn.Module):
        ], dim=-1
    )
    occlusion = self.occ_linear(occlusion)
-    expected_dist = utils.einshape(
-        'tbn1->bnt', occlusion[..., 1:2]
-    )
-    occlusion = utils.einshape(
-        'tbn1->bnt', occlusion[..., 0:1]
-    )
+    expected_dist = rearrange(occlusion[..., 1:2], 't b n () -> b n t', t=shape[0])
+    occlusion = rearrange(occlusion[..., 0:1], 't b n () -> b n t', t=shape[0])
 
-    return points, occlusion, expected_dist, …
+    return points, occlusion, expected_dist, rearrange(cost_volume, 't b n h w -> b n t h w')
 
  def construct_initial_causal_state(self, num_points, num_resolutions=1):
    """Construct initial causal state."""
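The main speed win in locotrack_model.py is replacing the hand-written softmax attention in MultiHeadAttention with F.scaled_dot_product_attention, passing the concatenated forward/backward ALiBi bias as an additive float attn_mask (the explicit scale= argument needs a recent PyTorch, 2.1 or newer). With a float mask, SDPA adds it to the scaled logits before the softmax, so it matches the removed einsum/softmax path while dispatching to fused kernels. A minimal sketch of that equivalence on dummy tensors; the shapes and the random bias are placeholders, not the model's real heads or ALiBi slopes.

import math
import torch
import torch.nn.functional as F

B, H, T, D = 2, 4, 16, 32
q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)
bias = torch.randn(H, T, T)   # stand-in for torch.cat([bias_forward, bias_backward], dim=0)

# Fused path (what the commit switches to).
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=bias, scale=1 / math.sqrt(D))

# Manual path (what the removed code computed step by step).
logits = torch.einsum('bhtd,bhTd->bhtT', q, k) / math.sqrt(D) + bias
manual = torch.einsum('bhtT,bhTd->bhtd', logits.softmax(dim=-1), v)

assert torch.allclose(fused, manual, atol=1e-4)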
locotrack_pytorch/models/utils.py
CHANGED
@@ -16,8 +16,6 @@
 """Pytorch model utilities."""
 import math
 from typing import Any, Sequence, Union
-from einshape.src import abstract_ops
-from einshape.src import backend
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -101,7 +99,7 @@ def map_coordinates_2d(
 
  n, p, t, s, xy = coordinates.shape
  y = coordinates.permute(0, 2, 1, 3, 4).reshape(n * t, p, s, xy)
-  y = 2 * (y / h) - 1
+  y = 2 * (y / torch.tensor([h, w], device=feats.device)) - 1
  y = torch.flip(y, dims=(-1,)).float()
 
  out = F.grid_sample(
@@ -231,47 +229,6 @@ def convert_grid_coordinates(
  return position_in_grid
 
 
-class _JaxBackend(backend.Backend[torch.Tensor]):
-  """Einshape implementation for PyTorch."""
-
-  # https://github.com/vacancy/einshape/blob/main/einshape/src/pytorch/pytorch_ops.py
-
-  def reshape(self, x: torch.Tensor, op: abstract_ops.Reshape) -> torch.Tensor:
-    return x.reshape(op.shape)
-
-  def transpose(
-      self, x: torch.Tensor, op: abstract_ops.Transpose
-  ) -> torch.Tensor:
-    return x.permute(op.perm)
-
-  def broadcast(
-      self, x: torch.Tensor, op: abstract_ops.Broadcast
-  ) -> torch.Tensor:
-    shape = op.transform_shape(x.shape)
-    for axis_position in sorted(op.axis_sizes.keys()):
-      x = x.unsqueeze(axis_position)
-    return x.expand(shape)
-
-
-def einshape(
-    equation: str, value: Union[torch.Tensor, Any], **index_sizes: int
-) -> torch.Tensor:
-  """Reshapes `value` according to the given Shape Equation.
-
-  Args:
-    equation: The Shape Equation specifying the index regrouping and reordering.
-    value: Input tensor, or tensor-like object.
-    **index_sizes: Sizes of indices, where they cannot be inferred from
-      `input_shape`.
-
-  Returns:
-    Tensor derived from `value` by reshaping as specified by `equation`.
-  """
-  if not isinstance(value, torch.Tensor):
-    value = torch.tensor(value)
-  return _JaxBackend().exec(equation, value, value.shape, **index_sizes)
-
-
 def generate_default_resolutions(full_size, train_size, num_levels=None):
  """Generate a list of logarithmically-spaced resolutions.
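The map_coordinates_2d fix normalizes row coordinates by h and column coordinates by w before grid_sample; the old code divided both by h, which is only correct for square feature maps. A small sketch of the difference on a non-square map; the +0.5 pixel-centre offset and the dummy tensor are illustrative assumptions, not copied from utils.py.

import torch
import torch.nn.functional as F

H, W = 4, 8
feats = torch.arange(H * W, dtype=torch.float32).reshape(1, 1, H, W)

# Sample the centre of pixel (row=1, col=6), whose value is 1 * W + 6 = 14.
coords = torch.tensor([[1.0, 6.0]])

def sample(divisor):
    y = 2 * ((coords + 0.5) / divisor) - 1   # normalize (row, col) to [-1, 1]
    y = torch.flip(y, dims=(-1,))            # grid_sample expects (x, y) order
    grid = y.view(1, 1, 1, 2)
    return F.grid_sample(feats, grid, align_corners=False).item()

print(sample(torch.tensor([H, W])))   # per-axis divisor: returns 14.0
print(sample(torch.tensor([H, H])))   # old behaviour: column lands off the map (0.0 under zero padding)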
requirements.txt
CHANGED
@@ -1,4 +1,4 @@
-…
+einops==0.8.0
 gradio==4.40.0
 mediapy==1.2.2
 opencv-python==4.10.0.84
weights/locotrack_base.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a5adbaeb610d1f06adfbc7c9076b66f727d674c0fd1d668890201cf3339736c
+size 46139570

weights/locotrack_small.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da023594e6d6c05ecad9644efc1467545481cfa899e20730bd9fdce778ffa5ac
+size 33001026