DawnC commited on
Commit
922691a
·
1 Parent(s): 4a1799c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -51
app.py CHANGED
@@ -95,7 +95,6 @@ class BaseModel(nn.Module):
95
  self.feature_dim = self.backbone.classifier[1].in_features
96
  self.backbone.classifier = nn.Identity()
97
 
98
- # 動態計算 num_heads
99
  self.num_heads = max(1, min(8, self.feature_dim // 64))
100
  self.attention = MultiHeadAttention(self.feature_dim, num_heads=self.num_heads)
101
 
@@ -122,16 +121,16 @@ model = BaseModel(num_classes=num_classes, device=device)
122
  checkpoint = torch.load('best_model_81_dog.pth', map_location=torch.device('cpu'))
123
  model.load_state_dict(checkpoint['model_state_dict'])
124
 
125
- # 將模型設置為評估模式
126
  model.eval()
127
 
128
  # Image preprocessing function
129
  def preprocess_image(image):
130
- # 如果圖片是 numpy.ndarray 轉換為 PIL.Image
131
  if isinstance(image, np.ndarray):
132
  image = Image.fromarray(image)
133
 
134
- # 使用 torchvision.transforms 進行預處理
135
  transform = transforms.Compose([
136
  transforms.Resize((224, 224)),
137
  transforms.ToTensor(),
@@ -140,38 +139,12 @@ def preprocess_image(image):
140
 
141
  return transform(image).unsqueeze(0)
142
 
143
- # def predict(image):
144
- # try:
145
- # image_tensor = preprocess_image(image)
146
- # with torch.no_grad():
147
- # logits, _ = model(image_tensor)
148
- # _, predicted = torch.max(logits, 1)
149
-
150
- # breed = dog_breeds[predicted.item()] # Map label to breed name
151
-
152
- # # Retrieve breed description
153
- # description = get_dog_description(breed)
154
-
155
- # # Formatting the description for better display
156
- # if isinstance(description, dict):
157
- # description_str = f"**Breed**: {description['Breed']}\n\n"
158
- # description_str += f"**Size**: {description['Size']}\n\n"
159
- # description_str += f"**Lifespan**: {description['Lifespan']}\n\n"
160
- # description_str += f"**Temperament**: {description['Temperament']}\n\n"
161
- # description_str += f"**Care Level**: {description['Care Level']}\n\n"
162
- # description_str += f"**Good with Children**: {description['Good with Children']}\n\n"
163
- # description_str += f"**Exercise Needs**: {description['Exercise Needs']}\n\n"
164
- # description_str += f"**Grooming Needs**: {description['Grooming Needs']}\n\n"
165
- # description_str += f"**Description**: {description['Description']}\n\n"
166
- # else:
167
- # description_str = description
168
-
169
- # return description_str
170
- # except Exception as e:
171
- # return f"An error occurred: {e}"
172
 
173
  def get_akc_link(breed):
174
- formatted_breed = breed.lower().replace(" ", "-")
 
 
 
175
  return f"https://www.akc.org/dog-breeds/{formatted_breed}/"
176
 
177
  def predict(image):
@@ -184,30 +157,23 @@ def predict(image):
184
  else:
185
  logits = output
186
  _, predicted = torch.max(logits, 1)
187
- breed = dog_breeds[predicted.item()] # Map label to breed name
188
 
189
- # Retrieve breed description
190
  description = get_dog_description(breed)
191
-
192
- # Generate AKC link
193
  akc_link = get_akc_link(breed)
194
 
195
- # Formatting the description for better display
196
  if isinstance(description, dict):
197
- description_str = f"**Breed**: {description['Breed']}\n\n"
198
- description_str += f"**Size**: {description['Size']}\n\n"
199
- description_str += f"**Lifespan**: {description['Lifespan']}\n\n"
200
- description_str += f"**Temperament**: {description['Temperament']}\n\n"
201
- description_str += f"**Care Level**: {description['Care Level']}\n\n"
202
- description_str += f"**Good with Children**: {description['Good with Children']}\n\n"
203
- description_str += f"**Exercise Needs**: {description['Exercise Needs']}\n\n"
204
- description_str += f"**Grooming Needs**: {description['Grooming Needs']}\n\n"
205
- description_str += f"**Description**: {description['Description']}\n\n"
206
  else:
207
  description_str = description
208
 
209
- # Add AKC link
210
- description_str += f"\n\n[Click here for more information about {breed} on AKC]({akc_link})"
 
 
 
 
 
211
 
212
  return description_str
213
  except Exception as e:
@@ -226,7 +192,6 @@ iface = gr.Interface(
226
  'French_Bulldog.jpeg',
227
  'Samoyed.jpg'],
228
  css = """
229
- /* 新增樣式 */
230
  .container {
231
  max-width: 900px;
232
  margin: 0 auto;
 
95
  self.feature_dim = self.backbone.classifier[1].in_features
96
  self.backbone.classifier = nn.Identity()
97
 
 
98
  self.num_heads = max(1, min(8, self.feature_dim // 64))
99
  self.attention = MultiHeadAttention(self.feature_dim, num_heads=self.num_heads)
100
 
 
121
  checkpoint = torch.load('best_model_81_dog.pth', map_location=torch.device('cpu'))
122
  model.load_state_dict(checkpoint['model_state_dict'])
123
 
124
+ # evaluation mode
125
  model.eval()
126
 
127
  # Image preprocessing function
128
  def preprocess_image(image):
129
+ # If the image is numpy.ndarray turn into PIL.Image
130
  if isinstance(image, np.ndarray):
131
  image = Image.fromarray(image)
132
 
133
+ # Use torchvision.transforms to process images
134
  transform = transforms.Compose([
135
  transforms.Resize((224, 224)),
136
  transforms.ToTensor(),
 
139
 
140
  return transform(image).unsqueeze(0)
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  def get_akc_link(breed):
144
+ # Remove any non-English characters and convert to lowercase
145
+ formatted_breed = ''.join(c for c in breed if ord(c) < 128).lower()
146
+ # Replace spaces with hyphens and remove any remaining special characters
147
+ formatted_breed = '-'.join(word for word in formatted_breed.split() if word.isalnum())
148
  return f"https://www.akc.org/dog-breeds/{formatted_breed}/"
149
 
150
  def predict(image):
 
157
  else:
158
  logits = output
159
  _, predicted = torch.max(logits, 1)
160
+ breed = dog_breeds[predicted.item()]
161
 
 
162
  description = get_dog_description(breed)
 
 
163
  akc_link = get_akc_link(breed)
164
 
 
165
  if isinstance(description, dict):
166
+ description_str = "\n\n".join([f"**{key}**: {value}" for key, value in description.items()])
 
 
 
 
 
 
 
 
167
  else:
168
  description_str = description
169
 
170
+ # Add AKC link as an option
171
+ description_str += f"\n\n**Want to learn more?** [View detailed information about {breed} on the AKC website]({akc_link})"
172
+
173
+ # Add disclaimer
174
+ disclaimer = ("\n\n*Disclaimer: The external link provided leads to the American Kennel Club (AKC) website. "
175
+ "We are not responsible for the content on external sites. Please refer to the AKC's terms of use and privacy policy.*")
176
+ description_str += disclaimer
177
 
178
  return description_str
179
  except Exception as e:
 
192
  'French_Bulldog.jpeg',
193
  'Samoyed.jpg'],
194
  css = """
 
195
  .container {
196
  max-width: 900px;
197
  margin: 0 auto;