SANJAYV10 commited on
Commit
60712d2
1 Parent(s): ca9bd9d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -32
app.py CHANGED
@@ -1,51 +1,47 @@
1
- # app.py
2
-
3
- from fastapi import FastAPI
4
- import torch
5
- import torch.nn as nn
6
  import torch
7
- from torchvision import transforms
8
-
 
9
  from typing import Any, Type
10
 
11
- import torch
12
-
13
  class TorchTensor(torch.Tensor):
14
  pass
15
 
16
  class Prediction:
17
  prediction: TorchTensor
18
 
19
- app = FastAPI()
20
 
21
- model = torch.load("/main/best_model-epoch=01-val_loss=3.00.ckpt")
 
22
 
23
- def preprocess_input(input):
24
-
25
-
26
- input = input.resize((224, 224))
27
-
28
- input = torch.from_numpy(np.array(input)).float()
29
-
30
- input = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(input)
31
-
32
  return input
33
 
 
34
  @app.post("/predict")
35
  async def predict_endpoint(input: Any):
36
-
37
-
38
- image = Image.open(BytesIO(input))
39
-
40
- image = preprocess_input(image)
41
-
42
- prediction = model(image.unsqueeze(0))
43
-
44
- predicted_class = prediction.argmax(1)
45
-
46
- return {"prediction": predicted_class.item()}
47
 
 
 
48
 
49
- if _name_ == "_main_":
50
  import uvicorn
51
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
 
 
 
1
  import torch
2
+ import fastapi
3
+ import numpy as np
4
+ from PIL import Image
5
  from typing import Any, Type
6
 
 
 
7
  class TorchTensor(torch.Tensor):
8
  pass
9
 
10
  class Prediction:
11
  prediction: TorchTensor
12
 
13
+ app = fastapi.FastAPI()
14
 
15
+ # Load the .bin model
16
+ model = torch.load("/main/best_model-epoch=01-val_loss=3.00.ckpt", map_location='cpu')
17
 
18
+ # Define a function to preprocess the input image
19
+ def preprocess_input(input: Any):
20
+ image = Image.open(BytesIO(input))
21
+ image = image.resize((224, 224))
22
+ input = np.array(image)
23
+ input = torch.from_numpy(input).float()
24
+ input = input.permute(2, 0, 1)
25
+ input = input.unsqueeze(0)
 
26
  return input
27
 
28
+ # Define an endpoint to make predictions
29
  @app.post("/predict")
30
  async def predict_endpoint(input: Any):
31
+ """Make a prediction on an image uploaded by the user."""
32
+
33
+ # Preprocess the input image
34
+ input = preprocess_input(input)
35
+
36
+ # Make a prediction
37
+ prediction = model(input)
38
+
39
+ # Get the predicted class
40
+ predicted_class = prediction.argmax(1).item()
 
41
 
42
+ # Return the predicted class in JSON format
43
+ return {"prediction": predicted_class}
44
 
45
+ if __name__ == "__main__":
46
  import uvicorn
47
  uvicorn.run(app, host="0.0.0.0", port=8000)