jiuface committed on
Commit 6d7bbcd
1 Parent(s): 2a31f6e

add invert_mask

Files changed (1): app.py (+49 -9)
app.py CHANGED
@@ -29,6 +29,15 @@ if torch.cuda.get_device_properties(0).major >= 8:
 FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
 SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
 
+def fetch_image_from_url(image_url):
+    try:
+        response = requests.get(image_url)
+        response.raise_for_status()
+        img = Image.open(BytesIO(response.content))
+        return img
+    except Exception as e:
+        return None
+
 class calculateDuration:
     def __init__(self, activity_name=""):
         self.activity_name = activity_name
@@ -55,7 +64,7 @@ class calculateDuration:
 @spaces.GPU()
 @torch.inference_mode()
 @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
-def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=0, merge_masks=False, return_rectangles=False) -> Optional[Image.Image]:
+def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=0, merge_masks=False, return_rectangles=False, invert_mask=False) -> Optional[Image.Image]:
 
     if not image_input:
         gr.Info("Please upload an image.")
@@ -68,9 +77,7 @@ def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=
     if image_url:
         with calculateDuration("Download Image"):
             print("start to fetch image from url", image_url)
-            response = requests.get(image_url)
-            response.raise_for_status()
-            image_input = PIL.Image.open(BytesIO(response.content))
+            image_input = fetch_image_from_url(image_url)
             print("fetch image success")
 
     # start to parse prompt
@@ -131,10 +138,30 @@ def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=
         for mask in images:
             merged_mask = cv2.bitwise_or(merged_mask, mask)
         images = [merged_mask]
+    if invert_mask:
+        with calculateDuration("invert mask colors"):
+            images = [cv2.bitwise_not(mask) for mask in images]
 
     return images
 
 
+def update_task_info(task_prompt):
+    task_info = {
+        '<OD>': "Object Detection: Detect objects in the image.",
+        '<CAPTION_TO_PHRASE_GROUNDING>': "Phrase Grounding: Link phrases in captions to corresponding regions in the image.",
+        '<DENSE_REGION_CAPTION>': "Dense Region Captioning: Generate captions for different regions in the image.",
+        '<REGION_PROPOSAL>': "Region Proposal: Propose potential regions of interest in the image.",
+        '<OCR_WITH_REGION>': "OCR with Region: Extract text and its bounding regions from the image.",
+        '<REFERRING_EXPRESSION_SEGMENTATION>': "Referring Expression Segmentation: Segment the region referred to by a natural language expression.",
+        '<REGION_TO_SEGMENTATION>': "Region to Segmentation: Convert region proposals into detailed segmentations.",
+        '<OPEN_VOCABULARY_DETECTION>': "Open Vocabulary Detection: Detect objects based on open vocabulary concepts.",
+        '<REGION_TO_CATEGORY>': "Region to Category: Assign categories to proposed regions.",
+        '<REGION_TO_DESCRIPTION>': "Region to Description: Generate descriptive text for specified regions."
+    }
+    return task_info.get(task_prompt, "Select a task to see its description.")
+
+
+
 with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column():
@@ -143,18 +170,31 @@ with gr.Blocks() as demo:
             task_prompt = gr.Dropdown(
                 ['<OD>', '<CAPTION_TO_PHRASE_GROUNDING>', '<DENSE_REGION_CAPTION>', '<REGION_PROPOSAL>', '<OCR_WITH_REGION>', '<REFERRING_EXPRESSION_SEGMENTATION>', '<REGION_TO_SEGMENTATION>', '<OPEN_VOCABULARY_DETECTION>', '<REGION_TO_CATEGORY>', '<REGION_TO_DESCRIPTION>'], value="<CAPTION_TO_PHRASE_GROUNDING>", label="Task Prompt", info="task prompts"
             )
-            dilate = gr.Slider(label="dilate mask", minimum=0, maximum=50, value=10, step=1)
-            merge_masks = gr.Checkbox(label="Merge masks", value=False)
-            return_rectangles = gr.Checkbox(label="Return Rectangles", value=False)
+            task_info = gr.Textbox(label='Task Info', value=update_task_info("<CAPTION_TO_PHRASE_GROUNDING>"), interactive=False)
+            dilate = gr.Slider(label="dilate mask", minimum=0, maximum=50, value=10, step=1, info="The dilate parameter controls the expansion of the mask's white areas by a specified number of pixels. Increasing this value will enlarge the white regions, which can help in smoothing out the mask's edges or covering more area in the segmentation.")
+            merge_masks = gr.Checkbox(label="Merge masks", value=False, info="The merge_masks parameter combines all the individual masks into a single mask. When enabled, the separate masks generated for different objects or regions will be merged into one unified mask, which can simplify further processing or visualization.")
+            return_rectangles = gr.Checkbox(label="Return Rectangles", value=False, info="The return_rectangles parameter, when enabled, generates masks as filled white rectangles corresponding to the bounding boxes of detected objects, rather than detailed contours or segments. This option is useful for simpler, box-based visualizations.")
+            invert_mask = gr.Checkbox(label="invert mask", value=False, info="The invert_mask option allows you to reverse the colors of the generated mask, changing black areas to white and white areas to black. This can be useful for visualizing or processing the mask in a different context.")
             text_prompt = gr.Textbox(label='Text prompt', placeholder='Enter text prompts')
             submit_button = gr.Button(value='Submit', variant='primary')
         with gr.Column():
             image_gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=[3], rows=[1], object_fit="contain", height="auto")
             # json_result = gr.Code(label="JSON Result", language="json")
-
+
+    task_prompt.change(
+        fn=update_task_info,
+        inputs=[task_prompt],
+        outputs=[task_info]
+    )
+    image_url.change(
+        fn=fetch_image_from_url,
+        inputs=[image_url],
+        outputs=[image]
+    )
+
     submit_button.click(
         fn=process_image,
-        inputs=[image, image_url, task_prompt, text_prompt, dilate, merge_masks, return_rectangles],
+        inputs=[image, image_url, task_prompt, text_prompt, dilate, merge_masks, return_rectangles, invert_mask],
        outputs=[image_gallery],
        show_api=False
    )
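
For context, a minimal, self-contained sketch (not part of the commit) of the two mask operations the new UI options describe: the invert step uses cv2.bitwise_not exactly as in the hunk above, while the dilate lines are only an assumed illustration of what the "dilate mask" slider controls, since that part of process_image lies outside this diff.

import cv2
import numpy as np

# A toy single-channel mask: 0 = background, 255 = foreground.
mask = np.zeros((8, 8), dtype=np.uint8)
mask[3:5, 3:5] = 255

# invert_mask: flip foreground and background, as done in process_image above.
inverted = cv2.bitwise_not(mask)
assert inverted[0, 0] == 255 and inverted[3, 3] == 0

# dilate (assumed illustration, not shown in this diff): grow the white region
# by a few pixels, roughly what the "dilate mask" slider is described as doing.
kernel = np.ones((3, 3), np.uint8)
dilated = cv2.dilate(mask, kernel, iterations=2)
assert dilated.sum() > mask.sum()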