curt-park committed
Commit
bf0a9ad
1 Parent(s): b267f43

Enhance accuracy

Files changed (4)
  1. app.py +50 -35
  2. examples/bears.jpg +0 -0
  3. examples/cats.jpg +0 -0
  4. examples/fish.jpg +0 -0
app.py CHANGED
@@ -12,13 +12,13 @@ import PIL
 import torch
 from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
 
-
 CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
 CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
 CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
 MODEL_TYPE = "default"
-MAX_WIDTH = MAX_HEIGHT = 800
-THRESHOLD = 0.05
+MAX_WIDTH = MAX_HEIGHT = 1024
+TOP_K_OBJ = 100
+THRESHOLD = 0.85
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
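Note on the new THRESHOLD scale (0.05 → 0.85): the old get_scores applied softmax across all candidate crops, so each crop's score competed with every other crop and even a strong match received a small probability; the new get_score (next hunk) applies softmax across just two text prompts per crop, so a confident match scores near 1.0. A minimal sketch with hypothetical logits:

import torch

# Old scheme: softmax over N crops for one query; probability mass is split
# across crops, so a low threshold (0.05) was needed.
crop_logits = torch.tensor([28.0, 27.5, 27.0, 26.0])
print(crop_logits.softmax(0))      # ≈ [0.47, 0.29, 0.17, 0.06]

# New scheme: softmax over 2 prompts ("a picture of {query}" vs.
# "a picture of background") for one crop; a good match lands near 1.0,
# so a high threshold (0.85) works.
text_logits = torch.tensor([[28.0, 24.0]])
print(text_logits.softmax(-1))     # ≈ [[0.98, 0.02]]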
 
@@ -55,23 +55,19 @@ def adjust_image_size(image: np.ndarray) -> np.ndarray:
 
 
 @torch.no_grad()
-def get_scores(crops: List[PIL.Image.Image], query: str) -> torch.Tensor:
+def get_score(crop: PIL.Image.Image, texts: List[str]) -> torch.Tensor:
     model, preprocess = load_clip()
-    preprocessed = [preprocess(crop) for crop in crops]
-    preprocessed = torch.stack(preprocessed).to(device)
-    token = clip.tokenize(query).to(device)
-    img_features = model.encode_image(preprocessed)
-    txt_features = model.encode_text(token)
-    img_features /= img_features.norm(dim=-1, keepdim=True)
-    txt_features /= txt_features.norm(dim=-1, keepdim=True)
-    similarity = (100 * img_features @ txt_features.T).softmax(0)
-    return similarity
+    preprocessed = preprocess(crop).unsqueeze(0).to(device)
+    tokens = clip.tokenize(texts).to(device)
+    logits_per_image, _ = model(preprocessed, tokens)
+    similarity = logits_per_image.softmax(-1).cpu()
+    return similarity[0, 0]
 
 
 def crop_image(image: np.ndarray, mask: Dict[str, Any]) -> PIL.Image.Image:
     x, y, w, h = mask["bbox"]
     masked = image * np.expand_dims(mask["segmentation"], -1)
-    crop = masked[y: y + h, x: x + w]
+    crop = masked[y : y + h, x : x + w]
     if h > w:
         top, bottom, left, right = 0, 0, (h - w) // 2, (h - w) // 2
     else:
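A self-contained sketch of the new scoring path, assuming the openai-clip package; load_clip() in app.py is stood in for by a direct clip.load call, and the ViT-B/32 model name is an assumption not shown in this diff:

import clip
import PIL.Image
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, preprocess = clip.load("ViT-B/32", device=device)  # stand-in for load_clip()

crop = PIL.Image.open("examples/dog.jpg")  # any RGB crop
texts = ["a picture of dog", "a picture of background"]

with torch.no_grad():
    preprocessed = preprocess(crop).unsqueeze(0).to(device)
    tokens = clip.tokenize(texts).to(device)
    # model(image, text) returns (logits_per_image, logits_per_text);
    # softmax over the text axis yields P(prompt | crop), and [0, 0]
    # picks the query prompt.
    logits_per_image, _ = model(preprocessed, tokens)
    score = logits_per_image.softmax(-1).cpu()[0, 0]

print(float(score))  # the mask is kept when this exceeds clip_threshold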
@@ -86,11 +82,14 @@ def crop_image(image: np.ndarray, mask: Dict[str, Any]) -> PIL.Image.Image:
         cv2.BORDER_CONSTANT,
         value=(0, 0, 0),
     )
-    crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
     crop = PIL.Image.fromarray(crop)
     return crop
 
 
+def get_texts(query: str) -> List[str]:
+    return [f"a picture of {query}", "a picture of background"]
+
+
 def filter_masks(
     image: np.ndarray,
     masks: List[Dict[str, Any]],
@@ -99,26 +98,19 @@ def filter_masks(
     query: str,
     clip_threshold: float,
 ) -> List[Dict[str, Any]]:
-    cropped_masks: List[PIL.Image.Image] = []
     filtered_masks: List[Dict[str, Any]] = []
 
-    for mask in masks:
+    for mask in sorted(masks, key=lambda mask: mask["area"])[-TOP_K_OBJ:]:
         if (
             mask["predicted_iou"] < predicted_iou_threshold
             or mask["stability_score"] < stability_score_threshold
             or image.shape[:2] != mask["segmentation"].shape[:2]
+            or query
+            and get_score(crop_image(image, mask), get_texts(query)) < clip_threshold
         ):
             continue
-        filtered_masks.append(mask)
-        cropped_masks.append(crop_image(image, mask))
 
-    if query and filtered_masks:
-        scores = get_scores(cropped_masks, query)
-        filtered_masks = [
-            filtered_masks[i]
-            for i, score in enumerate(scores)
-            if score > clip_threshold
-        ]
+        filtered_masks.append(mask)
 
     return filtered_masks
 
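Note on the rewritten condition: Python's `and` binds tighter than `or`, so the new clause parses as `... or (query and get_score(...) < clip_threshold)`. The expensive CLIP call is therefore skipped entirely when no query is given, and short-circuiting also skips it for masks that already failed the cheaper checks; only the TOP_K_OBJ largest masks (by "area") are considered at all. A small equivalence check with hypothetical stand-in values:

def should_skip(cheap_fail: bool, query: str, clip_score: float, clip_threshold: float) -> bool:
    # Mirrors the structure of the diff's `if`; the parentheses make the
    # implicit precedence explicit.
    return cheap_fail or (bool(query) and clip_score < clip_threshold)

assert should_skip(False, "", 0.0, 0.85) is False     # no query: CLIP check never runs
assert should_skip(False, "dog", 0.50, 0.85) is True  # low CLIP score: mask filtered out
assert should_skip(True, "dog", 0.99, 0.85) is True   # cheap checks already failed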
 
@@ -140,7 +132,7 @@ def draw_masks(
         contours, _ = cv2.findContours(
             np.uint8(mask["segmentation"]), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
         )
-        cv2.drawContours(image, contours, -1, (255, 0, 0), 2)
+        cv2.drawContours(image, contours, -1, (0, 0, 255), 2)
     return image
 
 
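Note: the contour color literal changes because drawing now happens on an RGB image (see the segment() hunks below) rather than a BGR one; the rendered border stays blue either way:

import cv2
import numpy as np

bgr = np.full((4, 4, 3), (255, 0, 0), np.uint8)       # old: blue in BGR order
rgb_old = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)        # old pipeline converted at the end

rgb_new = np.full((4, 4, 3), (0, 0, 255), np.uint8)   # new: blue directly in RGB order
assert (rgb_old == rgb_new).all()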
 
@@ -152,8 +144,11 @@ def segment(
     query: str,
 ) -> PIL.ImageFile.ImageFile:
     mask_generator = load_mask_generator()
+    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
     # reduce the size to save gpu memory
-    image = adjust_image_size(cv2.imread(image_path))
+    image = adjust_image_size(image)
     masks = mask_generator.generate(image)
     masks = filter_masks(
         image,
@@ -164,7 +159,6 @@
         clip_threshold,
     )
     image = draw_masks(image, masks)
-    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
     image = PIL.Image.fromarray(image)
     return image
 
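Note: together with the hunks above, color handling collapses into a single up-front conversion, so SAM, the CLIP crops, and the drawn output all operate on one RGB array; previously the image stayed BGR through the pipeline, each crop was converted individually in crop_image, and a second conversion ran at the end. A sketch of the new flow (path hypothetical):

import cv2
import PIL.Image

image = cv2.imread("examples/dog.jpg", cv2.IMREAD_COLOR)  # OpenCV decodes to BGR
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)            # convert once, up front
# ... adjust_image_size, SAM mask generation, CLIP filtering, contour drawing ...
result = PIL.Image.fromarray(image)                       # already RGB: no final cvtColor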
 
@@ -185,31 +179,52 @@ demo = gr.Interface(
         [
             0.9,
             0.8,
-            0.15,
+            0.99,
             os.path.join(os.path.dirname(__file__), "examples/dog.jpg"),
-            "A dog",
+            "dog",
         ],
         [
             0.9,
             0.8,
-            0.001,
+            0.75,
             os.path.join(os.path.dirname(__file__), "examples/city.jpg"),
             "building",
         ],
         [
             0.9,
             0.8,
-            0.05,
+            0.998,
             os.path.join(os.path.dirname(__file__), "examples/food.jpg"),
-            "spoon",
+            "strawberry",
         ],
         [
             0.9,
             0.8,
-            0.05,
+            0.75,
             os.path.join(os.path.dirname(__file__), "examples/horse.jpg"),
             "horse",
         ],
+        [
+            0.9,
+            0.8,
+            0.99,
+            os.path.join(os.path.dirname(__file__), "examples/bears.jpg"),
+            "bear",
+        ],
+        [
+            0.9,
+            0.8,
+            0.99,
+            os.path.join(os.path.dirname(__file__), "examples/cats.jpg"),
+            "cat",
+        ],
+        [
+            0.9,
+            0.8,
+            0.99,
+            os.path.join(os.path.dirname(__file__), "examples/fish.jpg"),
+            "fish",
+        ],
     ],
 )
 
 
examples/bears.jpg ADDED
examples/cats.jpg ADDED
examples/fish.jpg ADDED