Spaces:
Runtime error
Runtime error
jhj0517
commited on
Commit
•
aa72c98
1
Parent(s):
63be95c
Except with warning
Browse files
segment-anything-2/sam2/sam2_image_predictor.py
CHANGED
@@ -377,6 +377,13 @@ class SAM2ImagePredictor:
|
|
377 |
# we merge "boxes" and "points" into a single "concat_points" input (where
|
378 |
# boxes are added at the beginning) to sam_prompt_encoder
|
379 |
if concat_points is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
380 |
concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
|
381 |
concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
|
382 |
concat_points = (concat_coords, concat_labels)
|
|
|
377 |
# we merge "boxes" and "points" into a single "concat_points" input (where
|
378 |
# boxes are added at the beginning) to sam_prompt_encoder
|
379 |
if concat_points is not None:
|
380 |
+
if concat_points[0].size(1) > 1 or concat_points[1].size(1) > 1:
|
381 |
+
print("Warning: Box and point combination only works if there's "
|
382 |
+
"only one dot and one box. Using only the first one...")
|
383 |
+
concat_points = (concat_points[0][:, :1, :], concat_points[1][:, :1])
|
384 |
+
box_labels = box_labels[:1]
|
385 |
+
box_coords = box_coords[:1]
|
386 |
+
|
387 |
concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
|
388 |
concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
|
389 |
concat_points = (concat_coords, concat_labels)
|