Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -70,18 +70,9 @@ async def predict_endpoint(input: fastapi.UploadFile):
|
|
70 |
logits = prediction.logits
|
71 |
|
72 |
# Get the top N predictions
|
73 |
-
|
74 |
-
top_predictions = torch.topk(logits, k=num_top_predictions, dim=1)
|
75 |
-
top_indices = top_predictions.indices.squeeze().tolist()
|
76 |
-
top_probabilities = torch.softmax(top_predictions.values, dim=1).squeeze().tolist()
|
77 |
-
# Return the top N class names and their probabilities in JSON format
|
78 |
-
top_diseases = [class_names[idx] for idx in top_indices]
|
79 |
|
80 |
# Create a response dictionary
|
81 |
-
|
82 |
|
83 |
-
|
84 |
-
for disease, probability in zip(top_diseases, top_probabilities):
|
85 |
-
response_data["predictions"].append(f"{disease}: {probability * 100:.2f}%")
|
86 |
-
|
87 |
-
return response_data
|
|
|
70 |
logits = prediction.logits
|
71 |
|
72 |
# Get the top N predictions
|
73 |
+
predicted_class = torch.argmax(logits, dim=1).item()
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
# Create a response dictionary
|
76 |
+
return {"prediction": predicted_class}
|
77 |
|
78 |
+
|
|
|
|
|
|
|
|