Seokju Cho committed on
Commit
7798ce5
1 Parent(s): e4d9159
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -133,10 +133,10 @@ def extract_feature(video_input, model_size="small"):
133
  feature = model.get_feature_grids(video_input)
134
 
135
  feature = FeatureGrids(
136
- lowres=(feature.lowres[0].cpu(),),
137
- hires=(feature.hires[0].cpu(),),
138
- highest=(feature.highest[0].cpu(),),
139
- resolutions=feature.resolutions,
140
  )
141
  return feature
142
 
@@ -203,10 +203,10 @@ def track(
203
  dtype = torch.bfloat16 if device == "cuda" else torch.float16
204
 
205
  video_feature = FeatureGrids(
206
- lowres=(video_feature.lowres[0].to(device, dtype),),
207
- hires=(video_feature.hires[0].to(device, dtype),),
208
- highest=(video_feature.highest[0].to(device, dtype),),
209
- resolutions=video_feature.resolutions,
210
  )
211
 
212
  # Convert query points to tensor, normalize to input resolution
 
133
  feature = model.get_feature_grids(video_input)
134
 
135
  feature = FeatureGrids(
136
+ lowres=(feature.lowres[-1].cpu(),),
137
+ hires=(feature.hires[-1].cpu(),),
138
+ highest=(feature.highest[-1].cpu(),),
139
+ resolutions=(feature.resolutions[-1],),
140
  )
141
  return feature
142
 
 
203
  dtype = torch.bfloat16 if device == "cuda" else torch.float16
204
 
205
  video_feature = FeatureGrids(
206
+ lowres=(video_feature.lowres[-1].to(device, dtype),),
207
+ hires=(video_feature.hires[-1].to(device, dtype),),
208
+ highest=(video_feature.highest[-1].to(device, dtype),),
209
+ resolutions=(video_feature.resolutions[-1],),
210
  )
211
 
212
  # Convert query points to tensor, normalize to input resolution