Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,673 Bytes
0691c7d baea9b2 488d99e baea9b2 2fbf361 b32b0a3 488d99e 08430c8 488d99e 08430c8 488d99e 2fbf361 488d99e 2fbf361 baea9b2 08430c8 0691c7d 576e22a 488d99e 0691c7d 488d99e 0691c7d 488d99e 0691c7d 488d99e baea9b2 b32b0a3 0691c7d b32b0a3 0691c7d b32b0a3 488d99e 576e22a b32b0a3 576e22a 0691c7d 576e22a b32b0a3 488d99e b32b0a3 488d99e 0691c7d 488d99e 576e22a 5ae5bca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
from typing import Optional
import gradio as gr
import spaces
import supervision as sv
import torch
from PIL import Image
from utils.florence import load_florence_model, run_florence_inference, \
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference
# Device that both models are loaded onto and that is passed to Florence
# inference in process_image.
DEVICE = torch.device("cuda")
# CPU alternative left disabled by the author.
# NOTE(review): the autocast and get_device_properties calls below are
# CUDA-specific, so enabling this line alone is presumably not enough to run
# on CPU -- confirm before switching.
# DEVICE = torch.device("cpu")
# Enter a bfloat16 CUDA autocast context at import time. __enter__ is called
# directly and no matching __exit__ is visible here, so the context is left
# open after this statement.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
# Allow TF32 for CUDA matmul and cuDNN when GPU 0 reports a compute
# capability major version of 8 or higher.
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# Load both models once at import time; process_image reuses these globals on
# every call.
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
@spaces.GPU(duration=20)
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(image_input, text_input) -> Optional[Image.Image]:
    """Return a mask image for the first object matching a text prompt.

    Runs Florence open-vocabulary detection on the image with the given
    prompt, converts the result to ``sv.Detections``, passes the detections
    through ``run_sam_inference`` and turns the first detection's mask into a
    PIL image.

    Args:
        image_input: Image to process. The UI supplies a PIL image, or
            ``None`` when nothing has been uploaded.
        text_input: Text prompt describing what to detect.

    Returns:
        A PIL image built from the first detection's mask (cast to uint8 and
        multiplied by 255), or ``None`` when the image or prompt is missing
        or nothing was detected. In the ``None`` cases a ``gr.Info`` notice
        is shown to the user.
    """
    # Explicit None check instead of relying on the truthiness of an image
    # object.
    if image_input is None:
        gr.Info("Please upload an image.")
        return None

    if not text_input:
        gr.Info("Please enter a text prompt.")
        return None

    # Only the second element of the returned pair is used.
    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image_input,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        text=text_input,
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=result,
        resolution_wh=image_input.size,
    )

    # Skip the SAM pass when Florence found nothing: the outcome is the same
    # "no objects" notice, without spending GPU time on an empty set of
    # detections.
    if len(detections) > 0:
        detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)

    if len(detections) == 0:
        gr.Info("No objects detected.")
        return None

    # Only the first detection's mask is returned, even when several objects
    # were detected.
    return Image.fromarray(detections.mask[0].astype("uint8") * 255)
# Two-column page: image, prompt and submit button on the left, the resulting
# mask on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(label='Upload image', type='pil')
            text_input_component = gr.Textbox(
                placeholder='Enter text prompts',
                label='Text prompt',
            )
            submit_button_component = gr.Button(
                variant='primary',
                value='Submit',
            )
        with gr.Column():
            image_output_component = gr.Image(label='Output mask')

    # Clicking the button and pressing Enter in the prompt box run the same
    # handler with the same inputs and output, so the wiring is stated once.
    _event_wiring = dict(
        fn=process_image,
        inputs=[image_input_component, text_input_component],
        outputs=[image_output_component],
    )
    submit_button_component.click(**_event_wiring)
    text_input_component.submit(**_event_wiring)

# Start the app; errors raised in the handler are surfaced in the UI.
demo.launch(show_error=True, debug=False)
|