Seokju Cho
committed on
Commit
•
7798ce5
1
Parent(s):
e4d9159
fix
Browse files
app.py
CHANGED
@@ -133,10 +133,10 @@ def extract_feature(video_input, model_size="small"):
|
|
133 |
feature = model.get_feature_grids(video_input)
|
134 |
|
135 |
feature = FeatureGrids(
|
136 |
-
lowres=(feature.lowres[
|
137 |
-
hires=(feature.hires[
|
138 |
-
highest=(feature.highest[
|
139 |
-
resolutions=feature.resolutions,
|
140 |
)
|
141 |
return feature
|
142 |
|
@@ -203,10 +203,10 @@ def track(
|
|
203 |
dtype = torch.bfloat16 if device == "cuda" else torch.float16
|
204 |
|
205 |
video_feature = FeatureGrids(
|
206 |
-
lowres=(video_feature.lowres[
|
207 |
-
hires=(video_feature.hires[
|
208 |
-
highest=(video_feature.highest[
|
209 |
-
resolutions=video_feature.resolutions,
|
210 |
)
|
211 |
|
212 |
# Convert query points to tensor, normalize to input resolution
|
|
|
133 |
feature = model.get_feature_grids(video_input)
|
134 |
|
135 |
feature = FeatureGrids(
|
136 |
+
lowres=(feature.lowres[-1].cpu(),),
|
137 |
+
hires=(feature.hires[-1].cpu(),),
|
138 |
+
highest=(feature.highest[-1].cpu(),),
|
139 |
+
resolutions=(feature.resolutions[-1],),
|
140 |
)
|
141 |
return feature
|
142 |
|
|
|
203 |
dtype = torch.bfloat16 if device == "cuda" else torch.float16
|
204 |
|
205 |
video_feature = FeatureGrids(
|
206 |
+
lowres=(video_feature.lowres[-1].to(device, dtype),),
|
207 |
+
hires=(video_feature.hires[-1].to(device, dtype),),
|
208 |
+
highest=(video_feature.highest[-1].to(device, dtype),),
|
209 |
+
resolutions=(video_feature.resolutions[-1],),
|
210 |
)
|
211 |
|
212 |
# Convert query points to tensor, normalize to input resolution
|