SkalskiP committed
Commit 0691c7d
1 Parent(s): 2c71d17

make it return 0 or 1 mask

Files changed (1)
  1. app.py +26 -35
app.py CHANGED
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Optional
 
 import gradio as gr
 import spaces
@@ -26,43 +26,34 @@ SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
 @spaces.GPU
 @torch.inference_mode()
 @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
-def process_image(
-    image_input, text_input
-) -> List[Image.Image]:
+def process_image(image_input, text_input) -> Optional[Image.Image]:
     if not image_input:
         gr.Info("Please upload an image.")
-        return []
+        return None
 
     if not text_input:
         gr.Info("Please enter a text prompt.")
-        return []
+        return None
 
-    texts = [prompt.strip() for prompt in text_input.split(",")]
-    detections_list = []
-    for text in texts:
-        _, result = run_florence_inference(
-            model=FLORENCE_MODEL,
-            processor=FLORENCE_PROCESSOR,
-            device=DEVICE,
-            image=image_input,
-            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
-            text=text
-        )
-        detections = sv.Detections.from_lmm(
-            lmm=sv.LMM.FLORENCE_2,
-            result=result,
-            resolution_wh=image_input.size
-        )
-        detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
-        detections_list.append(detections)
-
-    detections = sv.Detections.merge(detections_list)
+    _, result = run_florence_inference(
+        model=FLORENCE_MODEL,
+        processor=FLORENCE_PROCESSOR,
+        device=DEVICE,
+        image=image_input,
+        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
+        text=text_input
+    )
+    detections = sv.Detections.from_lmm(
+        lmm=sv.LMM.FLORENCE_2,
+        result=result,
+        resolution_wh=image_input.size
+    )
+    detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
     detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
-    return [
-        Image.fromarray(mask.astype("uint8") * 255)
-        for mask
-        in detections.mask
-    ]
+    if len(detections) == 0:
+        gr.Info("No objects detected.")
+        return None
+    return Image.fromarray(detections.mask[0].astype("uint8") * 255)
 
 
 with gr.Blocks() as demo:
@@ -72,11 +63,11 @@ with gr.Blocks() as demo:
                 type='pil', label='Upload image')
             text_input_component = gr.Textbox(
                 label='Text prompt',
-                placeholder='Enter comma separated text prompts')
+                placeholder='Enter text prompts')
            submit_button_component = gr.Button(
                 value='Submit', variant='primary')
         with gr.Column():
-            gallery_output_component = gr.Gallery(label='Output masks')
+            image_output_component = gr.Image(label='Output mask')
 
     submit_button_component.click(
         fn=process_image,
@@ -85,7 +76,7 @@ with gr.Blocks() as demo:
             text_input_component
         ],
         outputs=[
-            gallery_output_component,
+            image_output_component,
         ]
     )
     text_input_component.submit(
@@ -95,7 +86,7 @@ with gr.Blocks() as demo:
             text_input_component
         ],
        outputs=[
-            gallery_output_component,
+            image_output_component,
        ]
    )
 
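
The change that gives the commit its name is the return statement: instead of a gallery of masks (one per comma-separated prompt), process_image now returns at most one PIL image, built from the first entry of detections.mask, which supervision stores as a boolean array of shape (num_detections, H, W). A minimal, self-contained sketch of that bool-to-image conversion, with a hand-made 4x4 mask standing in for real SAM output:

import numpy as np
from PIL import Image

# Stand-in for one entry of detections.mask: a boolean array of
# shape (H, W). Real masks come from SAM; this one is invented.
mask = np.array([
    [False, True,  True,  False],
    [True,  True,  True,  True],
    [True,  True,  True,  True],
    [False, True,  True,  False],
])

# astype("uint8") maps False/True to 0/1; multiplying by 255 yields a
# black-and-white grayscale image that gr.Image can display directly.
mask_image = Image.fromarray(mask.astype("uint8") * 255)
print(mask_image.mode, mask_image.size)  # -> L (4, 4)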
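
On the UI side, the gr.Gallery output is swapped for a single gr.Image, which accepts either a PIL image or None (returning None leaves the component empty, matching the early returns). A stripped-down sketch of the new wiring, with a trivial stand-in for the Florence-2 + SAM pipeline; the component names and the Optional return contract follow the diff, everything else is illustrative:

from typing import Optional

import gradio as gr
from PIL import Image


def process_image(image_input, text_input) -> Optional[Image.Image]:
    # Stand-in for the Florence-2 + SAM pipeline: returns a blank
    # "mask" sized like the input, or None, mirroring the commit's
    # zero-or-one-mask contract. gr.Image stays empty on None.
    if image_input is None or not text_input:
        return None
    return Image.new("L", image_input.size, 255)


with gr.Blocks() as demo:
    image_input_component = gr.Image(type='pil', label='Upload image')
    text_input_component = gr.Textbox(label='Text prompt')
    submit_button_component = gr.Button(value='Submit', variant='primary')
    image_output_component = gr.Image(label='Output mask')

    submit_button_component.click(
        fn=process_image,
        inputs=[image_input_component, text_input_component],
        outputs=[image_output_component]
    )

demo.launch()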