DawnC commited on
Commit
220a9cc
·
1 Parent(s): 2d3e457

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -26
app.py CHANGED
@@ -177,7 +177,6 @@ def get_akc_breeds_link():
177
  # except Exception as e:
178
  # return f"An error occurred: {e}"
179
 
180
- # Prediction function
181
  def predict(image):
182
  try:
183
  image_tensor = preprocess_image(image)
@@ -187,37 +186,43 @@ def predict(image):
187
  logits = output[0]
188
  else:
189
  logits = output
190
-
 
191
  probabilities = F.softmax(logits, dim=1)
192
- top_confidence, top_index = torch.max(probabilities, 1)
193
- top_confidence = top_confidence.item()
194
- top_breed = dog_breeds[top_index.item()]
195
 
196
- # If confidence is higher than 60%, return the top prediction directly
197
- if top_confidence >= 0.60:
198
- description = get_dog_description(top_breed)
199
- akc_link = get_akc_breeds_link()
200
- description_str = f"**Breed**: {top_breed}\n\n**Description**: {description}\n"
201
- description_str += f"\n\n**Want to learn more about dog breeds?** [Visit the AKC dog breeds page]({akc_link}) and search for {top_breed}."
202
- return description_str
203
 
204
- # If confidence is lower than 60%, return top 3 results and explain why
205
- else:
206
- top3_confidences, top3_indices = torch.topk(probabilities, 3, dim=1)
207
- top3_breeds = [dog_breeds[idx] for idx in top3_indices.squeeze().tolist()]
208
- top3_confidences = top3_confidences.squeeze().tolist()
 
209
 
210
- top3_info = "\n\n".join([f"{i+1}. {breed} ({conf*100:.2f}% confidence)"
211
- for i, (breed, conf) in enumerate(zip(top3_breeds, top3_confidences))])
212
 
213
- # Return top 3 breeds and an explanation
214
- message = (f"The model couldn't confidently identify the breed. Here are the top 3 possible breeds:\n\n"
215
- f"{top3_info}\n\n"
216
- "This can happen if the image quality is low or the breed is rare in the dataset. "
217
- "Please try uploading a clearer image or a different angle of the dog. "
218
- "For more accurate results, ensure the dog is the main subject of the photo.")
219
- return message
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  except Exception as e:
222
  return f"An error occurred: {e}"
223
 
 
177
  # except Exception as e:
178
  # return f"An error occurred: {e}"
179
 
 
180
  def predict(image):
181
  try:
182
  image_tensor = preprocess_image(image)
 
186
  logits = output[0]
187
  else:
188
  logits = output
189
+
190
+ # 取得預測的top k結果
191
  probabilities = F.softmax(logits, dim=1)
192
+ topk_probs, topk_indices = torch.topk(probabilities, k=3)
 
 
193
 
194
+ # 檢查最高的預測機率
195
+ top1_prob = topk_probs[0][0].item()
 
 
 
 
 
196
 
197
+ if top1_prob >= 0.5:
198
+ # 正確辨識時,返回該品種資訊
199
+ predicted = topk_indices[0][0]
200
+ breed = dog_breeds[predicted.item()]
201
+ description = get_dog_description(breed)
202
+ akc_link = get_akc_breeds_link()
203
 
204
+ description_str = f"**Breed**: {breed}\n\n**Description**: {description}"
205
+ description_str += f"\n\n**Want to learn more about dog breeds?** [Visit the AKC dog breeds page]({akc_link}) and search for {breed} to find detailed information."
206
 
207
+ return description_str
 
 
 
 
 
 
208
 
209
+ else:
210
+ # 不確定時,返回top 3的預測結果
211
+ topk_breeds = [dog_breeds[idx.item()] for idx in topk_indices[0]]
212
+ topk_probs_percent = [f"{prob.item() * 100:.2f}%" for prob in topk_probs[0]]
213
+
214
+ # 用粗體返回品種和機率
215
+ topk_results = "\n".join([f"**{i+1}. {breed}** ({prob} confidence)" for i, (breed, prob) in enumerate(zip(topk_breeds, topk_probs_percent))])
216
+
217
+ # 提供說明
218
+ explanation = (
219
+ f"The model couldn't confidently identify the breed. Here are the top 3 possible breeds:\n\n{topk_results}\n\n"
220
+ "This can happen if the image quality is low or the breed is rare in the dataset. "
221
+ "Please try uploading a clearer image or a different angle of the dog. "
222
+ "For more accurate results, ensure the dog is the main subject of the photo."
223
+ )
224
+
225
+ return explanation
226
  except Exception as e:
227
  return f"An error occurred: {e}"
228