# IMPORTS
import torch
import numpy as np
from PIL import Image
from lang_sam import LangSAM
import gradio as gr


def run_lang_sam(input_image, text_prompt, model, size=(256, 256)):
    """Segment ``text_prompt`` in ``input_image`` and return the masked image.

    The input is converted to RGB and resized to ``size`` before being passed
    to ``model.predict``. All masks returned by the model are merged into one
    (pixel-wise OR), and every pixel outside that union is set to black.

    Args:
        input_image: PIL image from the Gradio input component, or ``None``
            when the user has not uploaded one.
        text_prompt: Free-text description of what to segment.
        model: Object whose ``predict(image, prompt)`` returns a 4-tuple whose
            first element holds the predicted masks (e.g. a ``LangSAM``).
        size: ``(width, height)`` the image is resized to. Defaults to the
            256x256 resolution this demo has always used.

    Returns:
        A PIL RGB image of ``size`` where pixels not covered by any predicted
        mask are zeroed. If nothing matched the prompt, the image is all black.

    Raises:
        gr.Error: if no input image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image first.")

    image = input_image.convert("RGB").resize(size)
    pixels = np.array(image)

    masks, _, _, _ = model.predict(image, text_prompt)
    # NOTE(review): assumes masks is a (num_masks, H, W) tensor/array aligned
    # with the resized image — confirm against the installed lang_sam version.
    masks = torch.as_tensor(masks)
    if masks.numel() == 0:
        # Nothing detected: reducing over an empty first dimension would raise,
        # so return a fully masked (black) image instead.
        return Image.fromarray(np.zeros_like(pixels))

    # Union of all instance masks. Move to CPU in case the model returned
    # CUDA tensors, since .numpy() only works on CPU tensors.
    unified_mask = masks.bool().any(dim=0).cpu().numpy()
    masked = np.where(unified_mask[..., None], pixels, 0).astype(np.uint8)
    return Image.fromarray(masked)


def setup_gradio_interface(model):
    """Build the Lang SAM Gradio UI around ``model``.

    Args:
        model: The segmentation model forwarded to ``run_lang_sam``.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """

    # One named callable shared by the Run button and the cached example,
    # instead of two identical lambdas that could drift apart.
    def segment(image, prompt):
        return run_lang_sam(image, prompt, model)

    with gr.Blocks() as block:
        gr.Markdown("<h1><center>Lang SAM</center></h1>")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                text_prompt = gr.Textbox(label="Enter what you want to segment")
                run_button = gr.Button(value="Run")
            with gr.Column():
                output_mask = gr.Image(type="numpy", label="Segmentation Mask")

        run_button.click(
            fn=segment,
            inputs=[input_image, text_prompt],
            outputs=[output_mask],
        )

        # cache_examples=True runs `segment` on this example when the app is
        # built, so "bw-image.jpeg" must exist relative to the working dir.
        gr.Examples(
            examples=[["bw-image.jpeg", "road"]],
            inputs=[input_image, text_prompt],
            outputs=[output_mask],
            fn=segment,
            cache_examples=True,
            label="Try this example input!",
        )
    return block


if __name__ == "__main__":
    model = LangSAM()
    gradio_interface = setup_gradio_interface(model)
    gradio_interface.launch(share=False, show_api=False, show_error=True)