# IMPORTS
import torch
import numpy as np
from PIL import Image
from lang_sam import LangSAM
import gradio as gr
def run_lang_sam(input_image, text_prompt, model, size=(256, 256)):
    """Segment *input_image* with a text prompt and return the masked image.

    The image is converted to RGB and resized to ``size`` (width, height),
    then passed to ``model.predict`` together with ``text_prompt``. All
    returned per-object masks are merged into one boolean mask, which is
    applied multiplicatively: pixels outside every mask become black.

    Args:
        input_image: PIL image (any mode); converted to RGB internally.
        text_prompt: Natural-language description of what to segment.
        model: A LangSAM-style model exposing
            ``predict(image, prompt) -> (masks, boxes, phrases, logits)``
            where ``masks`` is a ``(num_objects, H, W)`` torch tensor.
        size: Target (width, height) for the resize. Defaults to (256, 256),
            matching the original hard-coded behavior.

    Returns:
        PIL.Image.Image: RGB image with non-segmented pixels zeroed out.
    """
    image = input_image.convert("RGB").resize(size)
    masks, _, _, _ = model.predict(
        image,
        text_prompt
    )
    if masks.numel() == 0:
        # No objects matched the prompt: an all-False mask yields a black
        # image instead of crashing on max() over an empty dimension.
        # NOTE: array layout is (height, width), i.e. size reversed.
        unified_mask = torch.zeros(size[1], size[0], dtype=torch.bool)
    else:
        # Union of all instance masks: uint8 max over the object dimension.
        masks_int = masks.to(torch.uint8)
        masks_max, _ = masks_int.max(dim=0, keepdim=True)
        unified_mask = masks_max.squeeze(0).to(torch.bool)
    # Broadcast the (H, W) bool mask over the RGB channels and zero out
    # everything that is not part of the segmentation.
    return Image.fromarray(
        (unified_mask[..., None].numpy() * np.array(image)).astype(np.uint8)
    )
def setup_gradio_interface(model):
    """Build the Gradio Blocks UI for text-prompted segmentation.

    Lays out an input image + prompt textbox on the left, the resulting
    segmentation mask on the right, and wires both the Run button and a
    cached example through :func:`run_lang_sam`.

    Args:
        model: The LangSAM model instance forwarded to ``run_lang_sam``.

    Returns:
        gr.Blocks: The assembled (un-launched) interface.
    """
    # Single handler shared by the button and the Examples component so the
    # two code paths cannot drift apart.
    def segment(image, prompt):
        return run_lang_sam(image, prompt, model)

    block = gr.Blocks()
    with block:
        # BUG FIX: the original passed a plain string literal split across
        # two physical lines, which is a Python syntax error.
        gr.Markdown("Lang SAM")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                text_prompt = gr.Textbox(label="Enter what you want to segment")
                run_button = gr.Button(value="Run")
            with gr.Column():
                output_mask = gr.Image(type="numpy", label="Segmentation Mask")
        run_button.click(
            fn=segment,
            inputs=[input_image, text_prompt],
            outputs=[output_mask],
        )
        gr.Examples(
            examples=[["bw-image.jpeg", "road"]],
            inputs=[input_image, text_prompt],
            outputs=[output_mask],
            fn=segment,
            cache_examples=True,
            label="Try this example input!",
        )
    return block
if __name__ == "__main__":
    # Load the segmentation model once at startup; it is shared by every
    # request handled through the interface.
    sam_model = LangSAM()
    demo = setup_gradio_interface(sam_model)
    # Local-only app: no public share link, hide the API page, surface errors.
    demo.launch(share=False, show_api=False, show_error=True)