DawnC commited on
Commit
c2b1ae2
·
1 Parent(s): a84fce6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -76
app.py CHANGED
@@ -243,14 +243,12 @@ def _predict_single_dog(image):
243
  # print(error_msg) # 添加日誌輸出
244
  # return error_msg, None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), None
245
 
246
- async def detect_multiple_dogs(image, conf_threshold=0.1, iou_threshold=0.5, merge_threshold=0.2):
247
  results = model_yolo(image, conf=conf_threshold, iou=iou_threshold)[0]
248
  dogs = []
249
- boxes = []
250
- confidences = []
251
 
252
  image_area = image.width * image.height
253
- min_area_ratio = 0.005 # 最小檢測面積佔整個圖像的比例
254
 
255
  for box in results.boxes:
256
  if box.cls == 16: # COCO 數據集中狗的類別是 16
@@ -258,72 +256,34 @@ async def detect_multiple_dogs(image, conf_threshold=0.1, iou_threshold=0.5, mer
258
  area = (xyxy[2] - xyxy[0]) * (xyxy[3] - xyxy[1])
259
  if area / image_area >= min_area_ratio:
260
  confidence = box.conf.item()
261
- boxes.append(torch.tensor(xyxy))
262
- confidences.append(confidence)
263
 
264
- if boxes:
265
- boxes = torch.stack(boxes)
266
- confidences = torch.tensor(confidences)
 
 
 
267
 
268
- # 應用軟 NMS
269
- keep = soft_nms(boxes, confidences, iou_threshold=iou_threshold, sigma=0.5)
270
-
271
- for i in keep:
272
- xyxy = boxes[i].tolist()
273
- confidence = confidences[i].item()
274
- cropped_image = image.crop((xyxy[0], xyxy[1], xyxy[2], xyxy[3]))
275
- dogs.append((cropped_image, confidence, xyxy))
276
-
277
- # 合併鄰近的邊界框
278
- merged_dogs = []
279
- while dogs:
280
- base_dog = dogs.pop(0)
281
- base_box = torch.tensor(base_dog[2])
282
- to_merge = [base_dog]
283
-
284
- i = 0
285
- while i < len(dogs):
286
- compare_box = torch.tensor(dogs[i][2])
287
- iou = box_iou(base_box.unsqueeze(0), compare_box.unsqueeze(0)).item()
288
- if iou > merge_threshold:
289
- to_merge.append(dogs.pop(i))
290
- else:
291
- i += 1
292
-
293
- if len(to_merge) == 1:
294
- merged_dogs.append(base_dog)
295
  else:
296
- merged_box = torch.cat([torch.tensor(dog[2]).unsqueeze(0) for dog in to_merge]).mean(0)
297
- merged_confidence = max(dog[1] for dog in to_merge)
298
- merged_image = image.crop(merged_box.tolist())
299
- merged_dogs.append((merged_image, merged_confidence, merged_box.tolist()))
300
-
301
- # 後處理:限制檢測到的狗狗數量
302
- if len(merged_dogs) > 5:
303
- merged_dogs = sorted(merged_dogs, key=lambda x: x[1], reverse=True)[:5]
304
 
305
- return merged_dogs
306
-
307
- # 如果沒有檢測到狗狗,使用備用分類器
308
- return await fallback_classifier(image)
309
-
310
- async def fallback_classifier(image):
311
- # 使用預訓練的 ResNet 或其他適合的分類器
312
- transform = transforms.Compose([
313
- transforms.Resize((224, 224)),
314
- transforms.ToTensor(),
315
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
316
- ])
317
- img_tensor = transform(image).unsqueeze(0)
318
-
319
- with torch.no_grad():
320
- output = fallback_model(img_tensor)
321
- confidence, predicted = torch.max(output, 1)
322
 
323
- if confidence.item() > 0.5: # 設置一個合適的閾值
324
- return [(image, confidence.item(), [0, 0, image.width, image.height])]
325
- else:
326
- return []
327
 
328
  async def predict(image):
329
  if image is None:
@@ -333,15 +293,12 @@ async def predict(image):
333
  if isinstance(image, np.ndarray):
334
  image = Image.fromarray(image)
335
 
336
- dogs = await detect_multiple_dogs(image, conf_threshold=0.25, iou_threshold=0.4, merge_threshold=0.3)
337
-
338
- if len(dogs) == 0:
339
- return "No dogs detected in the image.", image, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), None
340
 
341
  if len(dogs) == 1:
342
- return await process_single_dog(dogs[0][0]) # Pass the cropped image of the single detected dog
343
 
344
- # Multi-dog scenario
345
  color_list = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#00FFFF', '#FF00FF', '#800080', '#FFA500']
346
  explanations = []
347
  buttons = []
@@ -349,7 +306,7 @@ async def predict(image):
349
  draw = ImageDraw.Draw(annotated_image)
350
  font = ImageFont.load_default()
351
 
352
- for i, (cropped_image, confidence, box) in enumerate(dogs):
353
  top1_prob, topk_breeds, topk_probs_percent = await predict_single_dog(cropped_image)
354
  color = color_list[i % len(color_list)]
355
  draw.rectangle(box, outline=color, width=3)
@@ -360,13 +317,11 @@ async def predict(image):
360
  description = get_dog_description(breed)
361
  formatted_description = format_description(description, breed)
362
  explanations.append(f"Dog {i+1}: {formatted_description}")
363
- elif top1_prob >= 0.2:
364
  dog_explanation = f"Dog {i+1}: Top 3 possible breeds:\n"
365
  dog_explanation += "\n".join([f"{j+1}. **{breed}** ({prob} confidence)" for j, (breed, prob) in enumerate(zip(topk_breeds[:3], topk_probs_percent[:3]))])
366
  explanations.append(dog_explanation)
367
  buttons.extend([gr.update(visible=True, value=f"Dog {i+1}: More about {breed}") for breed in topk_breeds[:3]])
368
- else:
369
- explanations.append(f"Dog {i+1}: The image is unclear or the breed is not in the dataset.")
370
 
371
  final_explanation = "\n\n".join(explanations)
372
  if buttons:
@@ -392,7 +347,7 @@ async def predict(image):
392
 
393
  except Exception as e:
394
  error_msg = f"An error occurred: {str(e)}"
395
- print(error_msg) # Add log output
396
  return error_msg, None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), None
397
 
398
 
 
243
  # print(error_msg) # 添加日誌輸出
244
  # return error_msg, None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), None
245
 
246
+ async def detect_multiple_dogs(image, conf_threshold=0.2, iou_threshold=0.4, merge_threshold=0.3):
247
  results = model_yolo(image, conf=conf_threshold, iou=iou_threshold)[0]
248
  dogs = []
 
 
249
 
250
  image_area = image.width * image.height
251
+ min_area_ratio = 0.01 # 最小檢測面積佔整個圖像的比例
252
 
253
  for box in results.boxes:
254
  if box.cls == 16: # COCO 數據集中狗的類別是 16
 
256
  area = (xyxy[2] - xyxy[0]) * (xyxy[3] - xyxy[1])
257
  if area / image_area >= min_area_ratio:
258
  confidence = box.conf.item()
259
+ cropped_image = image.crop((xyxy[0], xyxy[1], xyxy[2], xyxy[3]))
260
+ dogs.append((cropped_image, confidence, xyxy))
261
 
262
+ # 合併重疊的框
263
+ merged_dogs = []
264
+ while dogs:
265
+ base_dog = dogs.pop(0)
266
+ base_box = torch.tensor(base_dog[2])
267
+ to_merge = [base_dog]
268
 
269
+ i = 0
270
+ while i < len(dogs):
271
+ compare_box = torch.tensor(dogs[i][2])
272
+ iou = box_iou(base_box.unsqueeze(0), compare_box.unsqueeze(0)).item()
273
+ if iou > merge_threshold:
274
+ to_merge.append(dogs.pop(i))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  else:
276
+ i += 1
 
 
 
 
 
 
 
277
 
278
+ if len(to_merge) == 1:
279
+ merged_dogs.append(base_dog)
280
+ else:
281
+ merged_box = torch.tensor([dog[2] for dog in to_merge]).mean(0)
282
+ merged_confidence = max(dog[1] for dog in to_merge)
283
+ merged_image = image.crop(merged_box.tolist())
284
+ merged_dogs.append((merged_image, merged_confidence, merged_box.tolist()))
 
 
 
 
 
 
 
 
 
 
285
 
286
+ return merged_dogs if merged_dogs else [(image, 1.0, [0, 0, image.width, image.height])]
 
 
 
287
 
288
  async def predict(image):
289
  if image is None:
 
293
  if isinstance(image, np.ndarray):
294
  image = Image.fromarray(image)
295
 
296
+ dogs = await detect_multiple_dogs(image)
 
 
 
297
 
298
  if len(dogs) == 1:
299
+ return await process_single_dog(dogs[0][0])
300
 
301
+ # 多狗情境
302
  color_list = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#00FFFF', '#FF00FF', '#800080', '#FFA500']
303
  explanations = []
304
  buttons = []
 
306
  draw = ImageDraw.Draw(annotated_image)
307
  font = ImageFont.load_default()
308
 
309
+ for i, (cropped_image, _, box) in enumerate(dogs):
310
  top1_prob, topk_breeds, topk_probs_percent = await predict_single_dog(cropped_image)
311
  color = color_list[i % len(color_list)]
312
  draw.rectangle(box, outline=color, width=3)
 
317
  description = get_dog_description(breed)
318
  formatted_description = format_description(description, breed)
319
  explanations.append(f"Dog {i+1}: {formatted_description}")
320
+ else:
321
  dog_explanation = f"Dog {i+1}: Top 3 possible breeds:\n"
322
  dog_explanation += "\n".join([f"{j+1}. **{breed}** ({prob} confidence)" for j, (breed, prob) in enumerate(zip(topk_breeds[:3], topk_probs_percent[:3]))])
323
  explanations.append(dog_explanation)
324
  buttons.extend([gr.update(visible=True, value=f"Dog {i+1}: More about {breed}") for breed in topk_breeds[:3]])
 
 
325
 
326
  final_explanation = "\n\n".join(explanations)
327
  if buttons:
 
347
 
348
  except Exception as e:
349
  error_msg = f"An error occurred: {str(e)}"
350
+ print(error_msg) # 添加日誌輸出
351
  return error_msg, None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), None
352
 
353