Martin Tomov commited on
Commit
96805b8
·
verified ·
1 Parent(s): 027d733

Update sam_utils.py

Browse files
Files changed (1) hide show
  1. sam_utils.py +16 -4
sam_utils.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- import spaces
3
  import random
4
  from dataclasses import dataclass
5
  from typing import Any, List, Dict, Optional, Union, Tuple
@@ -13,6 +13,7 @@ from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
13
  import gradio as gr
14
  import json
15
 
 
16
  @dataclass
17
  class BoundingBox:
18
  xmin: int
@@ -23,7 +24,6 @@ class BoundingBox:
23
  @property
24
  def xyxy(self) -> List[float]:
25
  return [self.xmin, self.ymin, self.xmax, self.ymax]
26
-
27
  @dataclass
28
  class DetectionResult:
29
  score: float
@@ -63,10 +63,12 @@ def annotate(image: Union[Image.Image, np.ndarray], detection_results: List[Dete
63
 
64
  return cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
65
 
 
66
  def plot_detections(image: Union[Image.Image, np.ndarray], detections: List[DetectionResult], include_bboxes: bool = True) -> np.ndarray:
67
  annotated_image = annotate(image, detections, include_bboxes)
68
  return annotated_image
69
 
 
70
  def load_image(image: Union[str, Image.Image]) -> Image.Image:
71
  if isinstance(image, str) and image.startswith("http"):
72
  image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
@@ -76,6 +78,7 @@ def load_image(image: Union[str, Image.Image]) -> Image.Image:
76
  image = image.convert("RGB")
77
  return image
78
 
 
79
  def get_boxes(detection_results: List[DetectionResult]) -> List[List[List[float]]]:
80
  boxes = []
81
  for result in detection_results:
@@ -83,6 +86,7 @@ def get_boxes(detection_results: List[DetectionResult]) -> List[List[List[float]
83
  boxes.append(xyxy)
84
  return [boxes]
85
 
 
86
  def mask_to_polygon(mask: np.ndarray) -> np.ndarray:
87
  contours, _ = cv2.findContours(
88
  mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
@@ -91,6 +95,7 @@ def mask_to_polygon(mask: np.ndarray) -> np.ndarray:
91
  largest_contour = max(contours, key=cv2.contourArea)
92
  return largest_contour
93
 
 
94
  def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
95
  masks = masks.cpu().float().permute(0, 2, 3, 1).mean(
96
  axis=-1).numpy().astype(np.uint8)
@@ -103,7 +108,7 @@ def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> L
103
  np.zeros(shape, dtype=np.uint8), [polygon], 1)
104
  return list(masks)
105
 
106
- @spaces.GPU
107
  def detect(image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None) -> List[Dict[str, Any]]:
108
  detector_id = detector_id if detector_id else "IDEA-Research/grounding-dino-base"
109
  object_detector = pipeline(
@@ -113,7 +118,7 @@ def detect(image: Image.Image, labels: List[str], threshold: float = 0.3, detect
113
  image, candidate_labels=labels, threshold=threshold)
114
  return [DetectionResult.from_dict(result) for result in results]
115
 
116
- @spaces.GPU
117
  def segment(image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False, segmenter_id: Optional[str] = None) -> List[DetectionResult]:
118
  segmenter_id = segmenter_id if segmenter_id else "martintmv/InsectSAM"
119
  segmentator = AutoModelForMaskGeneration.from_pretrained(
@@ -130,16 +135,19 @@ def segment(image: Image.Image, detection_results: List[DetectionResult], polygo
130
  detection_result.mask = mask
131
  return detection_results
132
 
 
133
  def grounded_segmentation(image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3, polygon_refinement: bool = False, detector_id: Optional[str] = None, segmenter_id: Optional[str] = None) -> Tuple[np.ndarray, List[DetectionResult]]:
134
  image = load_image(image)
135
  detections = detect(image, labels, threshold, detector_id)
136
  detections = segment(image, detections, polygon_refinement, segmenter_id)
137
  return np.array(image), detections
138
 
 
139
  def mask_to_min_max(mask: np.ndarray) -> Tuple[int, int, int, int]:
140
  y, x = np.where(mask)
141
  return x.min(), y.min(), x.max(), y.max()
142
 
 
143
  def extract_and_paste_insect(original_image: np.ndarray, detection: DetectionResult, background: np.ndarray) -> None:
144
  mask = detection.mask
145
  xmin, ymin, xmax, ymax = mask_to_min_max(mask)
@@ -154,6 +162,7 @@ def extract_and_paste_insect(original_image: np.ndarray, detection: DetectionRes
154
  insect_area = background[y_offset:y_end, x_offset:x_end]
155
  insect_area[mask_crop == 1] = insect[mask_crop == 1]
156
 
 
157
  def create_yellow_background_with_insects(image: np.ndarray) -> np.ndarray:
158
  labels = ["insect"]
159
 
@@ -170,6 +179,7 @@ def create_yellow_background_with_insects(image: np.ndarray) -> np.ndarray:
170
  yellow_background = cv2.cvtColor(yellow_background, cv2.COLOR_BGR2RGB)
171
  return yellow_background
172
 
 
173
  def run_length_encoding(mask):
174
  pixels = mask.flatten()
175
  rle = []
@@ -187,6 +197,7 @@ def run_length_encoding(mask):
187
  rle.append(count)
188
  return rle
189
 
 
190
  def detections_to_json(detections):
191
  detections_list = []
192
  for detection in detections:
@@ -203,6 +214,7 @@ def detections_to_json(detections):
203
  detections_list.append(detection_dict)
204
  return detections_list
205
 
 
206
  def crop_bounding_boxes_with_yellow_background(image: np.ndarray, yellow_background: np.ndarray, detections: List[DetectionResult]) -> List[np.ndarray]:
207
  crops = []
208
  for detection in detections:
 
1
  import os
2
+
3
  import random
4
  from dataclasses import dataclass
5
  from typing import Any, List, Dict, Optional, Union, Tuple
 
13
  import gradio as gr
14
  import json
15
 
16
+
17
  @dataclass
18
  class BoundingBox:
19
  xmin: int
 
24
  @property
25
  def xyxy(self) -> List[float]:
26
  return [self.xmin, self.ymin, self.xmax, self.ymax]
 
27
  @dataclass
28
  class DetectionResult:
29
  score: float
 
63
 
64
  return cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
65
 
66
+
67
  def plot_detections(image: Union[Image.Image, np.ndarray], detections: List[DetectionResult], include_bboxes: bool = True) -> np.ndarray:
68
  annotated_image = annotate(image, detections, include_bboxes)
69
  return annotated_image
70
 
71
+
72
  def load_image(image: Union[str, Image.Image]) -> Image.Image:
73
  if isinstance(image, str) and image.startswith("http"):
74
  image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
 
78
  image = image.convert("RGB")
79
  return image
80
 
81
+
82
  def get_boxes(detection_results: List[DetectionResult]) -> List[List[List[float]]]:
83
  boxes = []
84
  for result in detection_results:
 
86
  boxes.append(xyxy)
87
  return [boxes]
88
 
89
+
90
  def mask_to_polygon(mask: np.ndarray) -> np.ndarray:
91
  contours, _ = cv2.findContours(
92
  mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
95
  largest_contour = max(contours, key=cv2.contourArea)
96
  return largest_contour
97
 
98
+
99
  def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
100
  masks = masks.cpu().float().permute(0, 2, 3, 1).mean(
101
  axis=-1).numpy().astype(np.uint8)
 
108
  np.zeros(shape, dtype=np.uint8), [polygon], 1)
109
  return list(masks)
110
 
111
+
112
  def detect(image: Image.Image, labels: List[str], threshold: float = 0.3, detector_id: Optional[str] = None) -> List[Dict[str, Any]]:
113
  detector_id = detector_id if detector_id else "IDEA-Research/grounding-dino-base"
114
  object_detector = pipeline(
 
118
  image, candidate_labels=labels, threshold=threshold)
119
  return [DetectionResult.from_dict(result) for result in results]
120
 
121
+
122
  def segment(image: Image.Image, detection_results: List[DetectionResult], polygon_refinement: bool = False, segmenter_id: Optional[str] = None) -> List[DetectionResult]:
123
  segmenter_id = segmenter_id if segmenter_id else "martintmv/InsectSAM"
124
  segmentator = AutoModelForMaskGeneration.from_pretrained(
 
135
  detection_result.mask = mask
136
  return detection_results
137
 
138
+
139
  def grounded_segmentation(image: Union[Image.Image, str], labels: List[str], threshold: float = 0.3, polygon_refinement: bool = False, detector_id: Optional[str] = None, segmenter_id: Optional[str] = None) -> Tuple[np.ndarray, List[DetectionResult]]:
140
  image = load_image(image)
141
  detections = detect(image, labels, threshold, detector_id)
142
  detections = segment(image, detections, polygon_refinement, segmenter_id)
143
  return np.array(image), detections
144
 
145
+
146
  def mask_to_min_max(mask: np.ndarray) -> Tuple[int, int, int, int]:
147
  y, x = np.where(mask)
148
  return x.min(), y.min(), x.max(), y.max()
149
 
150
+
151
  def extract_and_paste_insect(original_image: np.ndarray, detection: DetectionResult, background: np.ndarray) -> None:
152
  mask = detection.mask
153
  xmin, ymin, xmax, ymax = mask_to_min_max(mask)
 
162
  insect_area = background[y_offset:y_end, x_offset:x_end]
163
  insect_area[mask_crop == 1] = insect[mask_crop == 1]
164
 
165
+
166
  def create_yellow_background_with_insects(image: np.ndarray) -> np.ndarray:
167
  labels = ["insect"]
168
 
 
179
  yellow_background = cv2.cvtColor(yellow_background, cv2.COLOR_BGR2RGB)
180
  return yellow_background
181
 
182
+
183
  def run_length_encoding(mask):
184
  pixels = mask.flatten()
185
  rle = []
 
197
  rle.append(count)
198
  return rle
199
 
200
+
201
  def detections_to_json(detections):
202
  detections_list = []
203
  for detection in detections:
 
214
  detections_list.append(detection_dict)
215
  return detections_list
216
 
217
+
218
  def crop_bounding_boxes_with_yellow_background(image: np.ndarray, yellow_background: np.ndarray, detections: List[DetectionResult]) -> List[np.ndarray]:
219
  crops = []
220
  for detection in detections: