import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Defaults used by ``load_sam_model`` when the caller does not override them.
# NOTE(review): the checkpoint path is relative, so it is presumably resolved
# against the current working directory -- confirm against how callers launch.
SAM_CHECKPOINT = "checkpoints/sam2_hiera_large.pt"
SAM_CONFIG = "sam2_hiera_l.yaml"


def load_sam_model(
    device: torch.device,
    config: str = SAM_CONFIG,
    checkpoint: str = SAM_CHECKPOINT,
) -> SAM2ImagePredictor:
    """Build a SAM2 model and wrap it in an image predictor.

    Args:
        device: Forwarded to ``build_sam2`` as its ``device`` keyword.
        config: First positional argument to ``build_sam2``; defaults to
            ``SAM_CONFIG``.
        checkpoint: Second positional argument to ``build_sam2``; defaults to
            ``SAM_CHECKPOINT``.

    Returns:
        A ``SAM2ImagePredictor`` constructed around the freshly built model.
    """
    return SAM2ImagePredictor(
        sam_model=build_sam2(config, checkpoint, device=device),
    )