"""Helpers for building a Segment Anything predictor and cutting out a mask."""

import numpy as np
import torch
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry

# Checkpoint file for each supported SAM backbone. The paths are relative,
# so they are resolved against the current working directory.
models = {
    'vit_b': './checkpoints/sam_vit_b_01ec64.pth',
    'vit_l': './checkpoints/sam_vit_l_0b3195.pth',
    'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
}


def get_sam_predictor(model_type='vit_h', device=None, image=None):
    """Build a ``SamPredictor`` for the requested backbone.

    Args:
        model_type: Key into ``models`` selecting the backbone and its
            checkpoint ('vit_b', 'vit_l' or 'vit_h').
        device: Device the model is moved to. When ``None``, 'cuda' is used
            if ``torch.cuda.is_available()`` and 'cpu' otherwise.
        image: Optional image handed to ``predictor.set_image`` so the
            returned predictor is ready for prompting.

    Returns:
        The constructed ``SamPredictor``.

    Raises:
        KeyError: If ``model_type`` has no entry in ``models``.
    """
    if model_type not in models:
        # Same exception type as the plain dict lookup would raise, but with
        # a message that tells the caller what is actually supported.
        raise KeyError(
            f"unknown SAM model_type {model_type!r}; "
            f"expected one of {sorted(models)}"
        )

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # sam model
    sam = sam_model_registry[model_type](checkpoint=models[model_type])
    sam = sam.to(device)

    predictor = SamPredictor(sam)
    if image is not None:
        predictor.set_image(image)
    return predictor


def sam_seg(predictor, input_img, input_points, input_labels):
    """Segment ``input_img`` from point prompts and return an RGBA cut-out.

    The predictor is asked for multiple candidate masks and the one with the
    highest score is written into the alpha channel of the result.

    Args:
        predictor: A ``SamPredictor``. ``set_image`` is not called here, so
            the image must already have been set on it.
        input_img: Array indexed as (height, width, channel) whose contents
            are copied into the first three channels of the output.
            NOTE(review): presumably the same RGB uint8 image that was given
            to ``set_image`` -- confirm against callers.
        input_points: Point prompt coordinates; an array or nested sequence,
            or ``None``.
        input_labels: One label per point; an array or sequence, or ``None``.

    Returns:
        A ``PIL.Image.Image`` in 'RGBA' mode whose alpha is 255 inside the
        selected mask and 0 elsewhere.
    """
    # The predictor calls array methods on the prompts, so plain Python
    # lists are converted up front; None is passed through unchanged.
    if input_points is not None:
        input_points = np.asarray(input_points)
    if input_labels is not None:
        input_labels = np.asarray(input_labels)

    masks, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=True,
    )

    # Keep only the candidate with the highest score.
    opt_idx = np.argmax(scores)
    mask = masks[opt_idx]

    height, width = input_img.shape[0], input_img.shape[1]
    out_image = np.zeros((height, width, 4), dtype=np.uint8)
    out_image[:, :, :3] = input_img
    # Scale the mask to a 0/255 alpha channel.
    out_image[:, :, 3] = mask.astype(np.uint8) * 255

    # Ask the CUDA caching allocator to release unused cached memory.
    torch.cuda.empty_cache()
    return Image.fromarray(out_image, mode='RGBA')