jhj0517 commited on
Commit
aa72c98
1 Parent(s): 63be95c

Except with warning

Browse files
segment-anything-2/sam2/sam2_image_predictor.py CHANGED
@@ -377,6 +377,13 @@ class SAM2ImagePredictor:
377
  # we merge "boxes" and "points" into a single "concat_points" input (where
378
  # boxes are added at the beginning) to sam_prompt_encoder
379
  if concat_points is not None:
 
 
 
 
 
 
 
380
  concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
381
  concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
382
  concat_points = (concat_coords, concat_labels)
 
377
  # we merge "boxes" and "points" into a single "concat_points" input (where
378
  # boxes are added at the beginning) to sam_prompt_encoder
379
  if concat_points is not None:
380
+ if concat_points[0].size(1) > 1 or concat_points[1].size(1) > 1:
381
+ print("Warning: Box and point combination only works if there's "
382
+ "only one dot and one box. Using only the first one...")
383
+ concat_points = (concat_points[0][:, :1, :], concat_points[1][:, :1])
384
+ box_labels = box_labels[:1]
385
+ box_coords = box_coords[:1]
386
+
387
  concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
388
  concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
389
  concat_points = (concat_coords, concat_labels)