File size: 965 Bytes
8078d22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

class Predictor:
    """Thin wrapper around SAM2's ``SAM2ImagePredictor`` that tracks image state.

    Args:
        model_cfg: SAM2 model config, passed to ``build_sam2``.
        checkpoint: SAM2 checkpoint path, passed to ``build_sam2``.
        device: Device the model is built on, passed to ``build_sam2``.
    """

    def __init__(self, model_cfg, checkpoint, device):
        self.device = device
        self.model = build_sam2(model_cfg, checkpoint, device=device)
        self.predictor = SAM2ImagePredictor(self.model)
        # Last image successfully handed to the predictor; None until then.
        self.image = None
        # True only after set_image() has completed without raising.
        self.image_set = False

    def set_image(self, image):
        """Set the image for SAM prediction.

        The wrapper's state is committed only after
        ``SAM2ImagePredictor.set_image`` returns successfully. A failed call
        therefore leaves ``image`` as None and ``image_set`` as False, rather
        than pointing at an image the predictor never accepted.

        Args:
            image: Image to segment. Presumably an HWC RGB ``np.ndarray`` or a
                PIL image, as accepted by ``SAM2ImagePredictor.set_image``.
        """
        # Invalidate first so predict() cannot run against stale state if
        # computing the embedding raises.
        self.image = None
        self.image_set = False
        self.predictor.set_image(image)
        self.image = image
        self.image_set = True

    def predict(
        self,
        point_coords=None,
        point_labels=None,
        multimask_output=False,
        *,
        box=None,
        mask_input=None,
        return_logits=False,
        normalize_coords=True,
    ):
        """Run SAM prediction on the currently set image.

        All prompts and options are forwarded unchanged to
        ``SAM2ImagePredictor.predict``.

        Args:
            point_coords: Optional point prompts.
            point_labels: Optional labels for ``point_coords``. SAM uses
                1 for foreground and 0 for background.
            multimask_output: Whether to return several candidate masks.
                Defaults to False here, which differs from SAM2's own default.
            box: Optional box prompt (keyword-only).
            mask_input: Optional low-resolution mask prompt, e.g. logits from
                a previous call (keyword-only).
            return_logits: Return un-thresholded mask logits instead of
                binary masks (keyword-only).
            normalize_coords: Whether SAM2 should normalize prompt coordinates
                relative to the image size (keyword-only).

        Returns:
            The result of ``SAM2ImagePredictor.predict``. In current sam2
            releases this is a tuple of (masks, IoU scores, low-res logits).

        Raises:
            RuntimeError: If no image has been successfully set.
        """
        if not self.image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
        return self.predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=box,
            mask_input=mask_input,
            multimask_output=multimask_output,
            return_logits=return_logits,
            normalize_coords=normalize_coords,
        )