jiuface committed
Commit f70c323
1 Parent(s): 03684b8
Files changed (1)
  1. app.py +85 -56
app.py CHANGED
@@ -27,11 +27,34 @@ if torch.cuda.get_device_properties(0).major >= 8:
 FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
 SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
 
+class calculateDuration:
+    def __init__(self, activity_name=""):
+        self.activity_name = activity_name
+
+    def __enter__(self):
+        self.start_time = time.time()
+        self.start_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.start_time))
+        print(f"Activity: {self.activity_name}, Start time: {self.start_time_formatted}")
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.end_time = time.time()
+        self.elapsed_time = self.end_time - self.start_time
+        self.end_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.end_time))
+
+        if self.activity_name:
+            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
+        else:
+            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
+
+        print(f"Activity: {self.activity_name}, End time: {self.end_time_formatted}")
+
+
 
 @spaces.GPU(duration=20)
 @torch.inference_mode()
 @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
-def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=0, merge_masks=False, return_rectangles=False) -> Optional[Image.Image]:
+def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=0, merge_masks=False, return_rectangles=False, progress=gr.Progress(track_tqdm=True)) -> Optional[Image.Image]:
     if not image_input:
         gr.Info("Please upload an image.")
         return None
@@ -40,62 +63,69 @@ def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=
         return None
 
     if image_url:
-        print("start to fetch image from url", image_url)
-        response = requests.get(image_url)
-        response.raise_for_status()
-        image_input = PIL.Image.open(BytesIO(response.content))
-        print("fetch image success")
+        with calculateDuration("Download Image"):
+            print("start to fetch image from url", image_url)
+            response = requests.get(image_url)
+            response.raise_for_status()
+            image_input = PIL.Image.open(BytesIO(response.content))
+            print("fetch image success")
     # start to parse prompt
-    _, result = run_florence_inference(
-        model=FLORENCE_MODEL,
-        processor=FLORENCE_PROCESSOR,
-        device=DEVICE,
-        image=image_input,
-        task=task_prompt,
-        text=text_prompt
-    )
-    # start to dectect
-    detections = sv.Detections.from_lmm(
-        lmm=sv.LMM.FLORENCE_2,
-        result=result,
-        resolution_wh=image_input.size
-    )
-    json_result = json.dumps({"bbox": detections.xyxy, "data": detections.data})
+    with calculateDuration("run_florence_inference"):
+        _, result = run_florence_inference(
+            model=FLORENCE_MODEL,
+            processor=FLORENCE_PROCESSOR,
+            device=DEVICE,
+            image=image_input,
+            task=task_prompt,
+            text=text_prompt
+        )
+    with calculateDuration("sv.Detections"):
+        # start to detect
+        detections = sv.Detections.from_lmm(
+            lmm=sv.LMM.FLORENCE_2,
+            result=result,
+            resolution_wh=image_input.size
+        )
+    # json_result = json.dumps([])
+    # print(detections)
     images = []
     if return_rectangles:
-        # create mask in rectangle
-        (image_width, image_height) = image_input.size
-        bboxes = detections.xyxy
-        merge_mask_image = np.zeros((image_height, image_width), dtype=np.uint8)
-        for bbox in bboxes:
-            x1, y1, x2, y2 = map(int, bbox)
-            cv2.rectangle(merge_mask_image, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
-            clip_mask = np.zeros((image_height, image_width), dtype=np.uint8)
-            cv2.rectangle(clip_mask, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
-            images.append(clip_mask)
-        if merge_masks:
-            images = [merge_mask_image] + images
+        with calculateDuration("generate rectangle mask"):
+            # create mask in rectangle
+            (image_width, image_height) = image_input.size
+            bboxes = detections.xyxy
+            merge_mask_image = np.zeros((image_height, image_width), dtype=np.uint8)
+
+            for bbox in bboxes:
+                x1, y1, x2, y2 = map(int, bbox)
+                cv2.rectangle(merge_mask_image, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
+                clip_mask = np.zeros((image_height, image_width), dtype=np.uint8)
+                cv2.rectangle(clip_mask, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
+                images.append(clip_mask)
+            if merge_masks:
+                images = [merge_mask_image] + images
     else:
-        # using sam generate segments images
-        detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
-        if len(detections) == 0:
-            gr.Info("No objects detected.")
-            return None
-        print("mask generated:", len(detections.mask))
-        kernel_size = dilate
-        kernel = np.ones((kernel_size, kernel_size), np.uint8)
-
-        for i in range(len(detections.mask)):
-            mask = detections.mask[i].astype(np.uint8) * 255
-            if dilate > 0:
-                mask = cv2.dilate(mask, kernel, iterations=1)
-            images.append(mask)
-
-        if merge_masks:
-            merged_mask = np.zeros_like(images[0], dtype=np.uint8)
-            for mask in images:
-                merged_mask = cv2.bitwise_or(merged_mask, mask)
-            images = [merged_mask]
+        with calculateDuration("generate segment mask"):
+            # use SAM to generate segment masks
+            detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
+            if len(detections) == 0:
+                gr.Info("No objects detected.")
+                return None
+            print("mask generated:", len(detections.mask))
+            kernel_size = dilate
+            kernel = np.ones((kernel_size, kernel_size), np.uint8)
+
+            for i in range(len(detections.mask)):
+                mask = detections.mask[i].astype(np.uint8) * 255
+                if dilate > 0:
+                    mask = cv2.dilate(mask, kernel, iterations=1)
+                images.append(mask)
+
+            if merge_masks:
+                merged_mask = np.zeros_like(images[0], dtype=np.uint8)
+                for mask in images:
+                    merged_mask = cv2.bitwise_or(merged_mask, mask)
+                images = [merged_mask]
 
     return [images, json_result]
 
@@ -115,13 +145,12 @@ with gr.Blocks() as demo:
         submit_button = gr.Button(value='Submit', variant='primary')
     with gr.Column():
         image_gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=[3], rows=[1], object_fit="contain", height="auto")
-        json_result = gr.Code(label="JSON Result", language="json")
+        # json_result = gr.Code(label="JSON Result", language="json")
 
-    print(image, image_url, task_prompt, text_prompt, image_gallery)
     submit_button.click(
        fn=process_image,
        inputs=[image, image_url, task_prompt, text_prompt, dilate, merge_masks, return_rectangles],
-       outputs=[image_gallery, json_result],
+       outputs=[image_gallery],
        show_api=False
    )
 
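The main addition is the calculateDuration context manager, which brackets each pipeline stage (image download, Florence-2 inference, detection parsing, mask generation) with wall-clock timing logs. A minimal standalone sketch of the same pattern, runnable outside the Space (the time.sleep call is a stand-in for the real work):

    import time

    class calculateDuration:
        # Same context-manager pattern the commit adds: record the entry
        # time, then log the elapsed wall-clock seconds on exit.
        def __init__(self, activity_name=""):
            self.activity_name = activity_name

        def __enter__(self):
            self.start_time = time.time()
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            elapsed = time.time() - self.start_time
            print(f"Elapsed time for {self.activity_name or 'activity'}: {elapsed:.6f} seconds")

    with calculateDuration("Download Image"):
        time.sleep(0.2)  # stand-in for requests.get(...) plus PIL decoding

Because __exit__ returns None, an exception raised inside the timed block is still propagated after the elapsed-time line is printed, so a failed download behaves as before, just with a timing log.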
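The other signature change, progress=gr.Progress(track_tqdm=True), is Gradio's hook for progress reporting: when a handler declares a gr.Progress parameter, Gradio injects a tracker at call time, and track_tqdm=True additionally mirrors tqdm bars created during the call to the UI. A minimal sketch of the pattern (slow_task is illustrative, not part of app.py):

    import time
    import gradio as gr

    def slow_task(steps, progress=gr.Progress(track_tqdm=True)):
        # progress.tqdm wraps an iterable and reports each step to the UI.
        for _ in progress.tqdm(range(int(steps))):
            time.sleep(0.05)
        return f"done after {int(steps)} steps"

    demo = gr.Interface(fn=slow_task, inputs=gr.Number(value=20), outputs="text")
    # demo.launch()

process_image itself never touches the progress object; the point of track_tqdm=True is that any tqdm loops inside the Florence/SAM helpers should surface in the UI without further wiring.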
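For reference on the dilate option retained in both branches: cv2.dilate grows the white regions of each binary mask by roughly half the kernel size on every side, a common way to pad masks for downstream compositing or inpainting. Note that the kernel is built from dilate even when it is 0, but cv2.dilate only runs when dilate > 0, so the empty kernel is never used. A self-contained toy example of the same call:

    import cv2
    import numpy as np

    # 9x9 mask with a single lit pixel in the center.
    mask = np.zeros((9, 9), dtype=np.uint8)
    mask[4, 4] = 255

    kernel = np.ones((3, 3), np.uint8)  # same construction app.py uses
    dilated = cv2.dilate(mask, kernel, iterations=1)
    print(cv2.countNonZero(mask), "->", cv2.countNonZero(dilated))  # 1 -> 9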
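One loose end worth flagging: the commit comments out the json_result assignment and removes json_result from the click outputs, yet process_image still ends with return [images, json_result], so the handler would raise a NameError on its first call. A hypothetical follow-up patch (not part of commit f70c323) that matches the single outputs=[image_gallery] component:

    -    return [images, json_result]
    +    return images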