jhj0517 commited on
Commit
71fe754
1 Parent(s): cf94415

Catch with warning in video predictor

Browse files
segment-anything-2/sam2/sam2_video_predictor.py CHANGED
@@ -220,6 +220,14 @@ class SAM2VideoPredictor(SAM2Base):
220
  )
221
  if not isinstance(box, torch.Tensor):
222
  box = torch.tensor(box, dtype=torch.float32, device=points.device)
 
 
 
 
 
 
 
 
223
  box_coords = box.reshape(1, 2, 2)
224
  box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
225
  box_labels = box_labels.reshape(1, 2)
 
220
  )
221
  if not isinstance(box, torch.Tensor):
222
  box = torch.tensor(box, dtype=torch.float32, device=points.device)
223
+
224
+ if box.shape[0] > 1:
225
+ box = box[:1, :]
226
+ warnings.warn(
227
+ "Box only works if there's only one. Using only the first one...",
228
+ category=UserWarning,
229
+ )
230
+
231
  box_coords = box.reshape(1, 2, 2)
232
  box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
233
  box_labels = box_labels.reshape(1, 2)