"""Gradio demo: text-prompted object masking with Florence-2 and SAM.

The user uploads an image and enters a text prompt. Florence-2 runs
open-vocabulary detection for the prompt, SAM turns the detected boxes into
masks, and the first mask is shown as a black/white image.
"""
from typing import Optional

import gradio as gr
import spaces
import supervision as sv
import torch
from PIL import Image

from utils.florence import (
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
    load_florence_model,
    run_florence_inference,
)
from utils.sam import load_sam_image_model, run_sam_inference

DEVICE = torch.device("cuda")
# DEVICE = torch.device("cpu")

# Enter bfloat16 autocast for the whole process (never exited on purpose), so
# that module-level model loading below also runs under autocast.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
# Enable TF32 matmul/cuDNN kernels on GPUs with compute capability >= 8.
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)


# NOTE(review): duration=20 is presumably the GPU allocation window in
# seconds for each call -- confirm against the `spaces` package docs.
@spaces.GPU(duration=20)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(
    image_input: Optional[Image.Image],
    text_input: Optional[str],
) -> Optional[Image.Image]:
    """Return a mask image for the first object matching *text_input*.

    Args:
        image_input: Image uploaded by the user, or ``None`` if nothing was
            uploaded.
        text_input: Text prompt describing the object(s) to detect.

    Returns:
        A PIL image built from the first detection's mask, with mask pixels
        set to 255 and all others to 0. ``None`` is returned (after showing
        a ``gr.Info`` notice) when the image or prompt is missing or when
        nothing was detected.
    """
    if image_input is None:
        gr.Info("Please upload an image.")
        return None

    # A whitespace-only prompt carries no information; treat it as missing.
    if not text_input or not text_input.strip():
        gr.Info("Please enter a text prompt.")
        return None

    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image_input,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        text=text_input
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=result,
        resolution_wh=image_input.size
    )

    # Bail out before segmentation: there is nothing for SAM to segment, and
    # this avoids handing it an empty set of boxes.
    if len(detections) == 0:
        gr.Info("No objects detected.")
        return None

    detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)

    # Re-check after SAM so that indexing mask[0] below is always valid.
    if len(detections) == 0 or detections.mask is None:
        gr.Info("No objects detected.")
        return None

    # Scale the first mask to 0/255 so it renders as a black/white image.
    return Image.fromarray(detections.mask[0].astype("uint8") * 255)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(
                type='pil', label='Upload image')
            text_input_component = gr.Textbox(
                label='Text prompt',
                placeholder='Enter text prompts')
            submit_button_component = gr.Button(
                value='Submit', variant='primary')
        with gr.Column():
            image_output_component = gr.Image(label='Output mask')

    # Clicking the button and pressing Enter in the textbox both run the
    # same pipeline with the same inputs and outputs.
    for register_listener in (
        submit_button_component.click,
        text_input_component.submit,
    ):
        register_listener(
            fn=process_image,
            inputs=[
                image_input_component,
                text_input_component
            ],
            outputs=[
                image_output_component,
            ]
        )

demo.launch(debug=False, show_error=True)