SkalskiP's picture
go back to GPU ZERO
08430c8
raw
history blame contribute delete
No virus
2.67 kB
from typing import Optional
import gradio as gr
import spaces
import supervision as sv
import torch
from PIL import Image
from utils.florence import load_florence_model, run_florence_inference, \
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference
DEVICE = torch.device("cuda")
# DEVICE = torch.device("cpu")
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
@spaces.GPU(duration=20)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(image_input, text_input) -> Optional[Image.Image]:
if not image_input:
gr.Info("Please upload an image.")
return None
if not text_input:
gr.Info("Please enter a text prompt.")
return None
_, result = run_florence_inference(
model=FLORENCE_MODEL,
processor=FLORENCE_PROCESSOR,
device=DEVICE,
image=image_input,
task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
text=text_input
)
detections = sv.Detections.from_lmm(
lmm=sv.LMM.FLORENCE_2,
result=result,
resolution_wh=image_input.size
)
detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
if len(detections) == 0:
gr.Info("No objects detected.")
return None
return Image.fromarray(detections.mask[0].astype("uint8") * 255)
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
image_input_component = gr.Image(
type='pil', label='Upload image')
text_input_component = gr.Textbox(
label='Text prompt',
placeholder='Enter text prompts')
submit_button_component = gr.Button(
value='Submit', variant='primary')
with gr.Column():
image_output_component = gr.Image(label='Output mask')
submit_button_component.click(
fn=process_image,
inputs=[
image_input_component,
text_input_component
],
outputs=[
image_output_component,
]
)
text_input_component.submit(
fn=process_image,
inputs=[
image_input_component,
text_input_component
],
outputs=[
image_output_component,
]
)
demo.launch(debug=False, show_error=True)