SANJAYV10 commited on
Commit
5c503f2
1 Parent(s): 6699b70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -67,9 +67,21 @@ async def predict_endpoint(input: fastapi.UploadFile):
67
 
68
  # Make a prediction
69
  prediction = model(input_data)
 
70
 
71
-
 
 
 
 
72
  # Return the top N class names and their probabilities in JSON format
73
-
 
 
 
 
 
 
 
74
 
75
- return {"predictions": prediction}
 
67
 
68
  # Make a prediction
69
  prediction = model(input_data)
70
+ logits = prediction.logits
71
 
72
+ # Get the top N predictions
73
+ num_top_predictions = 20
74
+ top_predictions = torch.topk(logits, k=num_top_predictions, dim=1)
75
+ top_indices = top_predictions.indices.squeeze().tolist()
76
+ top_probabilities = torch.softmax(top_predictions.values, dim=1).squeeze().tolist()
77
  # Return the top N class names and their probabilities in JSON format
78
+ top_diseases = [class_names[idx] for idx in top_indices]
79
+
80
+ # Create a response dictionary
81
+ response_data = {"predictions": []}
82
+
83
+ # Add the top N predictions and their probabilities to the response
84
+ for disease, probability in zip(top_diseases, top_probabilities):
85
+ response_data["predictions"].append(f"{disease}: {probability * 100:.2f}%")
86
 
87
+ return response_data