SANJAYV10 commited on
Commit
c1a7e91
1 Parent(s): 5c503f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -12
app.py CHANGED
@@ -70,18 +70,9 @@ async def predict_endpoint(input: fastapi.UploadFile):
70
  logits = prediction.logits
71
 
72
  # Get the top N predictions
73
- num_top_predictions = 20
74
- top_predictions = torch.topk(logits, k=num_top_predictions, dim=1)
75
- top_indices = top_predictions.indices.squeeze().tolist()
76
- top_probabilities = torch.softmax(top_predictions.values, dim=1).squeeze().tolist()
77
- # Return the top N class names and their probabilities in JSON format
78
- top_diseases = [class_names[idx] for idx in top_indices]
79
 
80
  # Create a response dictionary
81
- response_data = {"predictions": []}
82
 
83
- # Add the top N predictions and their probabilities to the response
84
- for disease, probability in zip(top_diseases, top_probabilities):
85
- response_data["predictions"].append(f"{disease}: {probability * 100:.2f}%")
86
-
87
- return response_data
 
70
  logits = prediction.logits
71
 
72
  # Get the top N predictions
73
+ predicted_class = torch.argmax(logits, dim=1).item()
 
 
 
 
 
74
 
75
  # Create a response dictionary
76
+ return {"prediction": predicted_class}
77
 
78
+