import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry
from PIL import Image

# Local paths to the official SAM checkpoints, keyed by model type.
models = {
    'vit_b': './checkpoints/sam_vit_b_01ec64.pth',
    'vit_l': './checkpoints/sam_vit_l_0b3195.pth',
    'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
}


def get_sam_predictor(model_type='vit_h', device=None, image=None):
    """Build a SamPredictor for the given model type, optionally pre-setting an image."""
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Load the SAM model from its checkpoint and move it to the target device.
    sam = sam_model_registry[model_type](checkpoint=models[model_type])
    sam = sam.to(device)

    predictor = SamPredictor(sam)
    if image is not None:
        # Precompute the image embedding so later predict() calls are fast.
        predictor.set_image(image)
    return predictor

def sam_seg(predictor, input_img, input_points, input_labels):
    """Segment the image previously set on `predictor` using point prompts and
    return the best mask composited as an RGBA cut-out of `input_img`."""
    masks, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=True,
    )

    # Keep the highest-scoring of the candidate masks.
    opt_idx = np.argmax(scores)
    mask = masks[opt_idx]
    # Compose an RGBA image: the original RGB channels plus the mask as the alpha channel.
    out_image = np.zeros((input_img.shape[0], input_img.shape[1], 4), dtype=np.uint8)
    out_image[:, :, :3] = input_img
    out_image[:, :, 3] = mask.astype(np.uint8) * 255
    # Release cached GPU memory (no-op when running on CPU).
    torch.cuda.empty_cache()
    return Image.fromarray(out_image, mode='RGBA')
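

# Minimal usage sketch (illustrative only): 'example.jpg', the choice of the
# vit_b checkpoint, the centre-point prompt, and 'segmented.png' are placeholder
# assumptions, not part of the original module.
if __name__ == '__main__':
    # SamPredictor.set_image expects an RGB uint8 array of shape (H, W, 3).
    image = np.array(Image.open('example.jpg').convert('RGB'))
    predictor = get_sam_predictor(model_type='vit_b', image=image)
    # A single foreground point (label 1) placed at the image centre, in (x, y) order.
    points = np.array([[image.shape[1] // 2, image.shape[0] // 2]])
    labels = np.array([1])
    cut_out = sam_seg(predictor, image, points, labels)
    cut_out.save('segmented.png')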