Seokju Cho committed
Commit 6b9382c
Parent: e11cc45

improve speed

.gitattributes CHANGED
@@ -42,3 +42,5 @@ examples/libby.mp4 filter=lfs diff=lfs merge=lfs -text
 examples/motocross-jump.mp4 filter=lfs diff=lfs merge=lfs -text
 examples/bmx-trees.mp4 filter=lfs diff=lfs merge=lfs -text
 examples/parkour.mp4 filter=lfs diff=lfs merge=lfs -text
+weights/locotrack_base.ckpt filter=lfs diff=lfs merge=lfs -text
+weights/locotrack_small.ckpt filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -19,6 +19,10 @@ PREVIEW_WIDTH = 768 # Width of the preview video
 VIDEO_INPUT_RESO = (256, 256) # Resolution of the input video
 POINT_SIZE = 4 # Size of the query point in the preview video
 FRAME_LIMIT = 300 # Limit the number of frames to process
+WEIGHTS_PATH = {
+    "small": "./weights/locotrack_small.ckpt",
+    "base": "./weights/locotrack_base.ckpt",
+}
 
 
 def get_point(frame_num, video_queried_preview, query_points, query_points_color, query_count, evt: gr.SelectData):
@@ -120,7 +124,7 @@ def extract_feature(video_input, model_size="small"):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     dtype = torch.bfloat16 if device == "cuda" else torch.float16
 
-    model = load_model(model_size=model_size).to(device)
+    model = load_model(WEIGHTS_PATH[model_size], model_size=model_size).to(device)
 
     video_input = (video_input / 255.0) * 2 - 1
     video_input = torch.tensor(video_input).unsqueeze(0).to(device, dtype)
@@ -223,7 +227,7 @@ def track(
     video_input = (video_input / 255.0) * 2 - 1
     video_input = torch.tensor(video_input).unsqueeze(0).to(device, dtype)
 
-    model = load_model(model_size=model_size).to(device)
+    model = load_model(WEIGHTS_PATH[model_size], model_size=model_size).to(device)
     with torch.autocast(device_type=device, dtype=dtype):
        with torch.no_grad():
            output = model(video_input, query_points_tensor, feature_grids=video_feature)
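
Note on the app.py change: extract_feature and track now resolve the checkpoint through the WEIGHTS_PATH table before calling load_model, so the demo reads the weights added under ./weights/ in this commit. A minimal sketch of the call pattern, assuming the LFS checkpoints have been pulled and that load_model takes the checkpoint path as its first positional argument (as the diff shows):

import torch

model_size = "small"                                    # or "base"
ckpt_path = WEIGHTS_PATH[model_size]                    # "./weights/locotrack_small.ckpt"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Same call as in extract_feature/track above; assumes the .ckpt files exist
# locally (e.g. fetched via Git LFS) and load_model is imported as in app.py.
model = load_model(ckpt_path, model_size=model_size).to(device)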
locotrack_pytorch/models/cmdtop.py CHANGED
@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from einops import rearrange
+
 from models import utils
 
 
@@ -29,8 +31,8 @@ class CMDTop(nn.Module):
         """
         x: (b, h, w, i, j)
         """
-        out1 = utils.einshape('bhwij->b(ij)hw', x)
-        out2 = utils.einshape('bhwij->b(hw)ij', x)
+        out1 = rearrange(x, 'b h w i j -> b (i j) h w')
+        out2 = rearrange(x, 'b h w i j -> b (h w) i j')
 
         for i in range(len(self.out_channels)):
             out1 = self.conv[i](out1)
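
The einshape-to-einops swap here is mechanical: the axis pattern gains spaces and the tensor moves to the first positional argument, while the grouping semantics stay the same. A small sanity sketch with made-up shapes (not repository code) showing the rearranged output matches a plain permute/reshape:

import torch
from einops import rearrange

x = torch.randn(2, 5, 7, 3, 4)                    # (b, h, w, i, j)

out1 = rearrange(x, 'b h w i j -> b (i j) h w')   # shape (2, 12, 5, 7)
out2 = rearrange(x, 'b h w i j -> b (h w) i j')   # shape (2, 35, 3, 4)

# Reference result computed with permute/reshape only.
ref1 = x.permute(0, 3, 4, 1, 2).reshape(2, 3 * 4, 5, 7)
print(out1.shape, out2.shape, torch.equal(out1, ref1))  # ... True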
locotrack_pytorch/models/locotrack_model.py CHANGED
@@ -22,6 +22,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 import numpy as np
+from einops import rearrange
 
 from models import nets, utils
 from models.cmdtop import CMDTop
@@ -57,15 +58,15 @@ def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
     return torch.cat([x] + [four_feat], dim=-1)
 
 
-def get_relative_positions(seq_len, reverse=False):
-    x = torch.arange(seq_len)[None, :]
-    y = torch.arange(seq_len)[:, None]
+def get_relative_positions(seq_len, reverse=False, device='cuda'):
+    x = torch.arange(seq_len, device=device)[None, :]
+    y = torch.arange(seq_len, device=device)[:, None]
     return torch.tril(x - y) if not reverse else torch.triu(y - x)
 
 
-def get_alibi_slope(num_heads):
+def get_alibi_slope(num_heads, device='cuda'):
     x = (24) ** (1 / num_heads)
-    return torch.tensor([1 / x ** (i + 1) for i in range(num_heads)], dtype=torch.float32).view(-1, 1, 1)
+    return torch.tensor([1 / x ** (i + 1) for i in range(num_heads)], device=device, dtype=torch.float32).view(-1, 1, 1)
 
 
 class MultiHeadAttention(nn.Module):
@@ -92,31 +93,22 @@ class MultiHeadAttention(nn.Module):
         key_heads = self._linear_projection(key, self.key_size, self.key_proj)  # [T, H, K]
         value_heads = self._linear_projection(value, self.value_size, self.value_proj)  # [T, H, V]
 
-        bias_forward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(sequence_length)
+        device = query.device
+        bias_forward = get_alibi_slope(self.num_heads // 2, device=device) * get_relative_positions(sequence_length, device=device)
         bias_forward = bias_forward + torch.triu(torch.full_like(bias_forward, -1e9), diagonal=1)
-        bias_backward = get_alibi_slope(self.num_heads // 2) * get_relative_positions(sequence_length, reverse=True)
+        bias_backward = get_alibi_slope(self.num_heads // 2, device=device) * get_relative_positions(sequence_length, reverse=True, device=device)
         bias_backward = bias_backward + torch.tril(torch.full_like(bias_backward, -1e9), diagonal=-1)
-        attn_bias = torch.cat([bias_forward, bias_backward], dim=0).to(query.device)
+        attn_bias = torch.cat([bias_forward, bias_backward], dim=0)
 
-        attn_logits = torch.einsum("...thd,...Thd->...htT", query_heads, key_heads)
-        attn_logits = attn_logits / np.sqrt(self.key_size) + attn_bias
-
-        if mask is not None:
-            if mask.ndim != attn_logits.ndim:
-                raise ValueError(f"Mask dimensionality {mask.ndim} must match logits dimensionality {attn_logits.ndim}.")
-            attn_logits = torch.where(mask, attn_logits, torch.tensor(-1e30))
-
-        attn_weights = F.softmax(attn_logits, dim=-1)  # [H, T', T]
-
-        attn = torch.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
-        attn = attn.reshape(batch_size, sequence_length, -1)  # [T', H*V]
+        attn = F.scaled_dot_product_attention(query_heads, key_heads, value_heads, attn_mask=attn_bias, scale=1 / np.sqrt(self.key_size))
+        attn = attn.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, -1)
 
         return self.final_proj(attn)  # [T', D']
 
     def _linear_projection(self, x, head_size, proj_layer):
         y = proj_layer(x)
-        *leading_dims, _ = x.shape
-        return y.reshape((*leading_dims, self.num_heads, head_size))
+        batch_size, sequence_length, _ = x.shape
+        return y.reshape((batch_size, sequence_length, self.num_heads, head_size)).permute(0, 2, 1, 3)
 
 
 class Transformer(nn.Module):
@@ -495,25 +487,25 @@ class LocoTrack(nn.Module):
             ctx = torch.reshape(ctx, [-1, 3]).to(video.device)  # s*s 3
 
             position_support = position_in_grid[..., None, :] + ctx[None, None, ...]  # b n s*s 3
-            position_support = utils.einshape('bnsc->b(ns)c', position_support)
+            position_support = rearrange(position_support, 'b n s c -> b (n s) c')
            interp_supp = utils.map_coordinates_3d(
                feature_grid[i], position_support
            )
-            interp_supp = utils.einshape('b(nhw)c->bnhwc', interp_supp, h=support_size, w=support_size)
+            interp_supp = rearrange(interp_supp, 'b (n h w) c -> b n h w c', h=support_size, w=support_size)
 
            position_support_hires = position_in_grid_hires[..., None, :] + ctx[None, None, ...]
-            position_support_hires = utils.einshape('bnsc->b(ns)c', position_support_hires)
+            position_support_hires = rearrange(position_support_hires, 'b n s c -> b (n s) c')
            hires_interp_supp = utils.map_coordinates_3d(
                hires_feats[i], position_support_hires
            )
-            hires_interp_supp = utils.einshape('b(nhw)c->bnhwc', hires_interp_supp, h=support_size, w=support_size)
+            hires_interp_supp = rearrange(hires_interp_supp, 'b (n h w) c -> b n h w c', h=support_size, w=support_size)
 
            position_support_highest = position_in_grid_highest[..., None, :] + ctx[None, None, ...]
-            position_support_highest = utils.einshape('bnsc->b(ns)c', position_support_highest)
+            position_support_highest = rearrange(position_support_highest, 'b n s c -> b (n s) c')
            highest_interp_supp = utils.map_coordinates_3d(
                highest_feats[i], position_support_highest
            )
-            highest_interp_supp = utils.einshape('b(nhw)c->bnhwc', highest_interp_supp, h=support_size, w=support_size)
+            highest_interp_supp = rearrange(highest_interp_supp, 'b (n h w) c -> b n h w c', h=support_size, w=support_size)
 
            interp_features = interp_supp[..., support_size // 2, support_size // 2, :]
            hires_interp = hires_interp_supp[..., support_size // 2, support_size // 2, :]
@@ -559,7 +551,7 @@ class LocoTrack(nn.Module):
            video.shape[2:4], self.initial_resolution
        )
 
-        all_required_resolutions = [self.initial_resolution]
+        all_required_resolutions = []
        all_required_resolutions.extend(refinement_resolutions)
 
        feature_grid = []
@@ -715,30 +707,14 @@ class LocoTrack(nn.Module):
        )
 
        num_queries = query_features.lowres[0].shape[1]
-        if causal_context is None:
-            perm = torch.randperm(num_queries)
-        else:
-            perm = torch.arange(num_queries)
-
-        inv_perm = torch.zeros_like(perm)
-        inv_perm[perm] = torch.arange(num_queries)
 
        for ch in range(0, num_queries, query_chunk_size):
-            perm_chunk = perm[ch : ch + query_chunk_size]
-            chunk = query_features.lowres[0][:, perm_chunk]
-            chunk_hires = query_features.hires[0][:, perm_chunk]
-
-            cc_chunk = []
-            if causal_context is not None:
-                for d in range(len(causal_context)):
-                    tmp_dict = {}
-                    for k, v in causal_context[d].items():
-                        tmp_dict[k] = v[:, perm_chunk]
-                    cc_chunk.append(tmp_dict)
+            chunk = query_features.lowres[0][:, ch:ch + query_chunk_size]
+            chunk_hires = query_features.hires[0][:, ch:ch + query_chunk_size]
 
            if query_points_in_video is not None:
                infer_query_points = query_points_in_video[
-                    :, perm[ch : ch + query_chunk_size]
+                    :, ch : ch + query_chunk_size
                ]
                num_frames = feature_grids.lowres[0].shape[1]
                infer_query_points = utils.convert_grid_coordinates(
@@ -765,14 +741,14 @@ class LocoTrack(nn.Module):
            for i in range(num_iters):
                feature_level = -1
                queries = [
-                    query_features.hires[feature_level][:, perm_chunk],
-                    query_features.lowres[feature_level][:, perm_chunk],
-                    query_features.highest[feature_level][:, perm_chunk],
+                    query_features.hires[feature_level][:, ch:ch + query_chunk_size],
+                    query_features.lowres[feature_level][:, ch:ch + query_chunk_size],
+                    query_features.highest[feature_level][:, ch:ch + query_chunk_size],
                ]
                supports = [
-                    query_features.hires_supp[feature_level][:, perm_chunk],
-                    query_features.lowres_supp[feature_level][:, perm_chunk],
-                    query_features.highest_supp[feature_level][:, perm_chunk],
+                    query_features.hires_supp[feature_level][:, ch:ch + query_chunk_size],
+                    query_features.lowres_supp[feature_level][:, ch:ch + query_chunk_size],
+                    query_features.highest_supp[feature_level][:, ch:ch + query_chunk_size],
                ]
                for _ in range(self.pyramid_level):
                    queries.append(queries[-1])
@@ -790,7 +766,7 @@ class LocoTrack(nn.Module):
                        padding=0,
                    )
                )
-                cc = cc_chunk[i] if causal_context is not None else None
+
                refined = self.refine_pips(
                    queries,
                    supports,
@@ -803,7 +779,6 @@ class LocoTrack(nn.Module):
                    last_iter=mixer_feats,
                    mixer_iter=i,
                    resize_hw=feature_grids.resolutions[feature_level],
-                    causal_context=cc,
                    get_causal_context=get_causal_context,
                    cost_volume=cost_volume
                )
@@ -822,9 +797,9 @@ class LocoTrack(nn.Module):
        points = []
        expd = []
        for i, _ in enumerate(occ_iters):
-            occlusion.append(torch.cat(occ_iters[i], dim=1)[:, inv_perm])
-            points.append(torch.cat(pts_iters[i], dim=1)[:, inv_perm])
-            expd.append(torch.cat(expd_iters[i], dim=1)[:, inv_perm])
+            occlusion.append(torch.cat(occ_iters[i], dim=1))
+            points.append(torch.cat(pts_iters[i], dim=1))
+            expd.append(torch.cat(expd_iters[i], dim=1))
 
        out = dict(
            occlusion=occlusion,
@@ -874,11 +849,11 @@ class LocoTrack(nn.Module):
            coords2 = coords.unsqueeze(3) + ctx.unsqueeze(0).unsqueeze(0).unsqueeze(0)
            neighborhood = utils.map_coordinates_2d(grid, coords2)
 
-            neighborhood = utils.einshape('bnt(hw)c->bnthwc', neighborhood, h=support_size, w=support_size)
+            neighborhood = rearrange(neighborhood, 'b n t (h w) c -> b n t h w c', h=support_size, w=support_size)
            patches_input = torch.einsum('bnthwc,bnijc->bnthwij', neighborhood, supp)
-            patches_input = utils.einshape('bnthwij->(bnt)hwij', patches_input)
+            patches_input = rearrange(patches_input, 'b n t h w i j -> (b n t) h w i j')
            patches_emb = self.cmdtop[pyridx](patches_input)
-            patches = utils.einshape('(bnt)c->bntc', patches_emb, b=neighborhood.shape[0], n=neighborhood.shape[1])
+            patches = rearrange(patches_emb, '(b n t) c -> b n t c', b=neighborhood.shape[0], n=neighborhood.shape[1])
 
            corrs_pyr.append(patches)
        corrs_pyr = torch.concatenate(corrs_pyr, dim=-1)
@@ -913,14 +888,10 @@ class LocoTrack(nn.Module):
        mlp_input_list.append(rel_pos_emb_input)
        mlp_input = torch.cat(mlp_input_list, axis=-1)
 
-        x = utils.einshape('bnfc->(bn)fc', mlp_input)
-
-        if causal_context is not None:
-            for k, v in causal_context.items():
-                causal_context[k] = utils.einshape('bn...->(bn)...', v)
+        x = rearrange(mlp_input, 'b n f c -> (b n) f c')
        res = self.torch_pips_mixer(x)
 
-        res = utils.einshape('(bn)fc->bnfc', res, b=mlp_input.shape[0])
+        res = rearrange(res, '(b n) f c -> b n f c', b=mlp_input.shape[0])
 
        pos_update = utils.convert_grid_coordinates(
            res[..., :2],
@@ -983,20 +954,18 @@ class LocoTrack(nn.Module):
        shape = cost_volume.shape
        batch_size, num_points = cost_volume.shape[1:3]
 
-        interp_cost = utils.einshape('tbnhw->(tbn)1hw', cost_volume)
+        interp_cost = rearrange(cost_volume, 't b n h w -> (t b n) () h w')
        interp_cost = F.interpolate(interp_cost, cost_volume_hires.shape[3:], mode='bilinear', align_corners=False)
-        # TODO: not sure if this is correct
-        interp_cost = utils.einshape('(tbn)1hw->tbnhw', interp_cost, b=batch_size, n=num_points)
+        interp_cost = rearrange(interp_cost, '(t b n) () h w -> t b n h w', b=batch_size, n=num_points)
        cost_volume_stack = torch.stack(
            [
-                # jax.image.resize(cost_volume, cost_volume_hires.shape, method='bilinear'),
                interp_cost,
                cost_volume_hires,
            ], dim=-1
        )
-        pos = utils.einshape('tbnhwc->(tbn)chw', cost_volume_stack)
+        pos = rearrange(cost_volume_stack, 't b n h w c -> (t b n) c h w')
        pos = self.cost_conv(pos)
-        pos = utils.einshape('(tbn)1hw->bnthw', pos, b=batch_size, n=num_points)
+        pos = rearrange(pos, '(t b n) () h w -> b n t h w', b=batch_size, n=num_points)
 
        pos_sm = pos.reshape(pos.size(0), pos.size(1), pos.size(2), -1)
        softmaxed = F.softmax(pos_sm * self.softmax_temperature, dim=-1)
@@ -1012,14 +981,10 @@ class LocoTrack(nn.Module):
            ], dim=-1
        )
        occlusion = self.occ_linear(occlusion)
-        expected_dist = utils.einshape(
-            'tbn1->bnt', occlusion[..., 1:2]
-        )
-        occlusion = utils.einshape(
-            'tbn1->bnt', occlusion[..., 0:1]
-        )
+        expected_dist = rearrange(occlusion[..., 1:2], 't b n () -> b n t', t=shape[0])
+        occlusion = rearrange(occlusion[..., 0:1], 't b n () -> b n t', t=shape[0])
 
-        return points, occlusion, expected_dist, utils.einshape('tbnhw->bnthw', cost_volume)
+        return points, occlusion, expected_dist, rearrange(cost_volume, 't b n h w -> b n t h w')
 
    def construct_initial_causal_state(self, num_points, num_resolutions=1):
        """Construct initial causal state."""
locotrack_pytorch/models/utils.py CHANGED
@@ -16,8 +16,6 @@
 """Pytorch model utilities."""
 import math
 from typing import Any, Sequence, Union
-from einshape.src import abstract_ops
-from einshape.src import backend
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -101,7 +99,7 @@ def map_coordinates_2d(
 
     n, p, t, s, xy = coordinates.shape
     y = coordinates.permute(0, 2, 1, 3, 4).reshape(n * t, p, s, xy)
-    y = 2 * (y / h) - 1
+    y = 2 * (y / torch.tensor([h, w], device=feats.device)) - 1
     y = torch.flip(y, dims=(-1,)).float()
 
     out = F.grid_sample(
@@ -231,47 +229,6 @@ def convert_grid_coordinates(
     return position_in_grid
 
 
-class _JaxBackend(backend.Backend[torch.Tensor]):
-    """Einshape implementation for PyTorch."""
-
-    # https://github.com/vacancy/einshape/blob/main/einshape/src/pytorch/pytorch_ops.py
-
-    def reshape(self, x: torch.Tensor, op: abstract_ops.Reshape) -> torch.Tensor:
-        return x.reshape(op.shape)
-
-    def transpose(
-        self, x: torch.Tensor, op: abstract_ops.Transpose
-    ) -> torch.Tensor:
-        return x.permute(op.perm)
-
-    def broadcast(
-        self, x: torch.Tensor, op: abstract_ops.Broadcast
-    ) -> torch.Tensor:
-        shape = op.transform_shape(x.shape)
-        for axis_position in sorted(op.axis_sizes.keys()):
-            x = x.unsqueeze(axis_position)
-        return x.expand(shape)
-
-
-def einshape(
-    equation: str, value: Union[torch.Tensor, Any], **index_sizes: int
-) -> torch.Tensor:
-    """Reshapes `value` according to the given Shape Equation.
-
-    Args:
-        equation: The Shape Equation specifying the index regrouping and reordering.
-        value: Input tensor, or tensor-like object.
-        **index_sizes: Sizes of indices, where they cannot be inferred from
-            `input_shape`.
-
-    Returns:
-        Tensor derived from `value` by reshaping as specified by `equation`.
-    """
-    if not isinstance(value, torch.Tensor):
-        value = torch.tensor(value)
-    return _JaxBackend().exec(equation, value, value.shape, **index_sizes)
-
-
 def generate_default_resolutions(full_size, train_size, num_levels=None):
     """Generate a list of logarithmically-spaced resolutions.
 
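Besides dropping the now-unused einshape backend, the map_coordinates_2d hunk looks like a correctness fix folded into this commit: grid_sample expects each coordinate normalized by its own axis length, so rows should be divided by h and columns by w rather than both by h. A toy illustration with an assumed non-square feature map:

import torch

h, w = 32, 64
coords = torch.tensor([[31.0, 63.0]])          # (row, col) near the bottom-right corner

old = 2 * (coords / h) - 1                     # col maps to ~2.94, outside grid_sample's [-1, 1] range
new = 2 * (coords / torch.tensor([h, w])) - 1  # both coordinates stay within [-1, 1]
print(old, new)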
requirements.txt CHANGED
@@ -1,4 +1,4 @@
-einshape==1.0
+einops==0.8.0
 gradio==4.40.0
 mediapy==1.2.2
 opencv-python==4.10.0.84
weights/locotrack_base.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a5adbaeb610d1f06adfbc7c9076b66f727d674c0fd1d668890201cf3339736c
+size 46139570
weights/locotrack_small.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da023594e6d6c05ecad9644efc1467545481cfa899e20730bd9fdce778ffa5ac
+size 33001026