File size: 965 Bytes
8078d22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import numpy as np
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
class Predictor:
    """Thin wrapper around SAM2's image predictor for point-prompted masks.

    Builds a SAM2 model with ``build_sam2`` and wraps it in a
    ``SAM2ImagePredictor``. Call :meth:`set_image` for each new image, then
    :meth:`predict` as many times as needed with point prompts.
    """

    def __init__(self, model_cfg, checkpoint, device):
        """Build the SAM2 model and its image predictor.

        Args:
            model_cfg: Model config forwarded to ``build_sam2`` (presumably a
                config name/path — confirm against the installed sam2).
            checkpoint: Checkpoint forwarded to ``build_sam2`` (presumably a
                path to the weights file).
            device: Forwarded as ``build_sam2(..., device=device)`` and kept
                on ``self.device``.
        """
        self.device = device
        self.model = build_sam2(model_cfg, checkpoint, device=device)
        self.predictor = SAM2ImagePredictor(self.model)
        # Most recently set image; None until set_image() succeeds.
        self.image = None
        # True only after set_image() has completed without raising.
        self.image_set = False

    def set_image(self, image):
        """Set the image for SAM prediction.

        Args:
            image: Passed unchanged to ``SAM2ImagePredictor.set_image``
                (assumed to be an HxWxC array or PIL image — confirm).

        Raises:
            Whatever ``SAM2ImagePredictor.set_image`` raises. On failure the
            wrapper is left in the "no image set" state, so a subsequent
            :meth:`predict` call raises ``RuntimeError``.
        """
        # Clear state before delegating so a failed call cannot leave
        # image_set=True paired with an image that was never embedded.
        self.image_set = False
        self.image = None
        self.predictor.set_image(image)
        # Commit only after the underlying predictor accepted the image.
        self.image = image
        self.image_set = True

    def predict(self, point_coords, point_labels, multimask_output=False):
        """Run SAM prediction for the current image using point prompts.

        Args:
            point_coords: Point prompts forwarded to
                ``SAM2ImagePredictor.predict`` (presumably an Nx2 array of
                (x, y) pixel coordinates — confirm).
            point_labels: Labels for ``point_coords``, forwarded unchanged
                (presumably 1 = foreground, 0 = background — confirm).
            multimask_output: Forwarded unchanged; defaults to ``False`` here.

        Returns:
            Whatever ``SAM2ImagePredictor.predict`` returns (presumably a
            ``(masks, scores, low_res_logits)`` tuple — verify).

        Raises:
            RuntimeError: If no image has been successfully set.
        """
        if not self.image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
        return self.predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=multimask_output,
        )
|