# SAM2_Cutie / app.py
# Author: Satyajithchary. Last commit: 0e13941 "Update app.py" (verified), 1.51 kB.
# Interactive point-prompt mask prediction with SAM 2, served through Gradio.
import numpy as np
import torch
import gradio as gr
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Local SAM 2 Hiera-Large weights and the matching model config name.
# NOTE(review): the config name is presumably resolved by build_sam2 from the
# installed sam2 package's bundled configs; confirm against the sam2 version used.
checkpoint = "checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

# Build the model once, directly on `device` (build_sam2 takes the target
# device), and wrap it in the image predictor used by process_image. The
# previous extra `from_pretrained` load was discarded and only doubled the
# startup time and memory use.
model = build_sam2(model_cfg, checkpoint, device=device)
predictor = SAM2ImagePredictor(model)
def process_image(image, input_points, input_labels):
    """Predict a SAM 2 mask for a single point prompt on ``image``.

    Args:
        image: Uploaded ``PIL.Image.Image``; ``None`` when nothing was uploaded.
        input_points: The prompt point in pixel coordinates of ``image``,
            either as an ``"X,Y"`` string (what a text box sends) or as a
            2-element ``(x, y)`` sequence.
        input_labels: Label for the point: ``1`` marks the object and ``0``
            marks background. Numeric strings such as ``"1"`` are accepted.

    Returns:
        A single-channel ``PIL.Image.Image`` with the chosen mask drawn as
        255 on a 0 background.

    Raises:
        gr.Error: If the image is missing or the point or label cannot be
            parsed, so that Gradio shows the message to the user.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Accept "X,Y" text as well as an (x, y) pair; a lone number is not a point.
    if isinstance(input_points, str):
        coords = input_points.split(",")
    elif isinstance(input_points, (list, tuple, np.ndarray)):
        coords = list(input_points)
    else:
        coords = [input_points]
    try:
        x, y = (float(c) for c in coords)
    except (TypeError, ValueError):
        raise gr.Error(f"Expected a point as 'X,Y', got {input_points!r}.") from None

    try:
        label = int(input_labels)
    except (TypeError, ValueError):
        label = None
    if label not in (0, 1):
        raise gr.Error("Please choose a label: 0 for background, 1 for object.")

    # predict() works on the image embedding computed by set_image(), so the
    # image must be set first. Converting to RGB turns RGBA or greyscale
    # uploads into an HxWx3 array.
    predictor.set_image(np.array(image.convert("RGB")))

    # SAM expects point_coords with shape (N, 2) and point_labels with shape (N,).
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[x, y]]),
        point_labels=np.array([label]),
        multimask_output=True,
    )

    # Several candidate masks come back, and they are not guaranteed to be
    # ordered by quality, so keep the one with the highest predicted score.
    best = masks[int(np.argmax(scores))]
    # Scale the 0/1 mask to 0/255 so it is visible when displayed as an image.
    return Image.fromarray((best > 0).astype(np.uint8) * 255)
# Define the Gradio interface: an image, a prompt point typed as "X,Y", and a
# background/object label. The output is the predicted mask image.
# (gr.inputs.* was removed in Gradio 4; the components live directly on gr.)
image_input = gr.Image(type="pil", label="Image")
# A Textbox, not a Number, because the point is entered as "X,Y".
point_input = gr.Textbox(
    label="Point X,Y (comma-separated)",
    placeholder="e.g. 120,80",
)
label_input = gr.Radio(
    [0, 1],
    value=1,  # default to marking the object, so a label is always sent
    label="Label (0 for background, 1 for object)",
)
iface = gr.Interface(
    fn=process_image,
    inputs=[image_input, point_input, label_input],
    outputs=gr.Image(type="pil", label="Predicted mask"),
    # NOTE(review): this file runs SAM 2 only; nothing visible here uses CUTIE.
    description="Interactive tool for mask prediction with Segment Anything 2 and CUTIE",
)

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    iface.launch()