SANJAYV10 commited on
Commit
9cd2acf
1 Parent(s): d1056de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -7
app.py CHANGED
@@ -10,15 +10,25 @@ class Prediction:
10
  prediction: TorchTensor
11
 
12
  app = fastapi.FastAPI(docs_url="/")
 
13
 
14
- # Load the pre-trained model
15
- pre_trained_model = torch.load('best_model.pth', map_location=torch.device('cpu'))
16
 
 
 
 
 
 
 
 
 
17
  # Define a function to preprocess the input image
18
  def preprocess_input(input: fastapi.UploadFile):
19
  image = Image.open(input.file)
20
- image = image.resize((224, 224))
21
  input = np.array(image)
 
22
  input = torch.from_numpy(input).float()
23
  input = input.unsqueeze(0)
24
  return input
@@ -35,7 +45,17 @@ async def predict_endpoint(input:fastapi.UploadFile):
35
  prediction = model(input)
36
 
37
 
38
- predicted_class = prediction.argmax(1).item()
39
-
40
- # Return the predicted class in JSON format
41
- return {"prediction": predicted_class}
 
 
 
 
 
 
 
 
 
 
 
10
  prediction: TorchTensor
11
 
12
  app = fastapi.FastAPI(docs_url="/")
13
+ from transformers import ViTForImageClassification
14
 
15
+ # Define the number of classes in your custom dataset
16
+ num_classes = 20
17
 
18
+ # Initialize the ViTForImageClassification model
19
+ model = ViTForImageClassification.from_pretrained(
20
+ 'google/vit-base-patch16-224-in21k',
21
+ num_labels=num_classes # Specify the number of classes
22
+ )
23
+
24
+
25
+ model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
26
  # Define a function to preprocess the input image
27
  def preprocess_input(input: fastapi.UploadFile):
28
  image = Image.open(input.file)
29
+ image = image.resize((224, 224)).convert("RGB")
30
  input = np.array(image)
31
+ input = np.transpose(input, (2, 0, 1))
32
  input = torch.from_numpy(input).float()
33
  input = input.unsqueeze(0)
34
  return input
 
45
  prediction = model(input)
46
 
47
 
48
+ logits = prediction.logits
49
+ num_top_predictions = 3
50
+ top_predictions = torch.topk(logits, k=num_top_predictions, dim=1)
51
+
52
+ # Get the top 3 class indices and their probabilities
53
+ top_indices = top_predictions.indices.squeeze().tolist()
54
+ top_probabilities = torch.softmax(top_predictions.values, dim=1).squeeze().tolist()
55
+
56
+ # Get the disease names for the top 3 predictions
57
+ disease_names = [disease_names[idx] for idx in top_indices]
58
+
59
+ # Return the top 3 disease names and their probabilities in JSON format
60
+ response_data = [{"disease_name": name, "probability": prob} for name, prob in zip(disease_names, top_probabilities)]
61
+ return {"predictions": response_data}