AnishKumbhar commited on
Commit
4ecc856
1 Parent(s): 12f6e2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -25
app.py CHANGED
@@ -1,39 +1,46 @@
 
1
  import fastapi
2
- from transformers import pipeline
3
- import pickle
 
4
 
5
- # Load the model from the pickle file
6
- with open("model2.ckpt", "rb") as f:
7
- model = pickle.load(f)
8
 
9
- # Define a function to preprocess the image
10
- def preprocess_image(image):
11
- # Resize the image to a fixed size
12
- image = image.resize((224, 224))
13
 
14
- # Convert the image to a NumPy array
15
- image = np.array(image)
16
 
17
- # Normalize the image
18
- image = image / 255.0
19
 
20
- # Return the image
21
- return image
 
 
 
 
 
 
 
22
 
23
- # Define an endpoint to predict the output
24
  @app.post("/predict")
25
- async def predict_endpoint(image: fastapi.File):
26
- # Preprocess the image
27
- image = preprocess_image(image)
 
 
28
 
29
  # Make a prediction
30
- prediction = model(image)
31
 
32
- # Return the prediction
33
- return {"prediction": prediction}
34
 
35
- # Start the FastAPI app
36
- if _name_ == "_main_":
37
- import uvicorn
38
 
 
 
39
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ import torch
2
  import fastapi
3
+ import numpy as np
4
+ from PIL import Image
5
+ from typing import Any, Type
6
 
7
+ class TorchTensor(torch.Tensor):
8
+ pass
 
9
 
10
+ class Prediction:
11
+ prediction: TorchTensor
 
 
12
 
13
+ app = fastapi.FastAPI()
 
14
 
15
+ model = torch.load("model67.bin", map_location='cpu')
 
16
 
17
+ # Define a function to preprocess the input image
18
+ def preprocess_input(input: Any):
19
+ image = Image.open(BytesIO(input))
20
+ image = image.resize((224, 224))
21
+ input = np.array(image)
22
+ input = torch.from_numpy(input).float()
23
+ input = input.permute(2, 0, 1)
24
+ input = input.unsqueeze(0)
25
+ return input
26
 
27
+ # Define an endpoint to make predictions
28
  @app.post("/predict")
29
+ async def predict_endpoint(input: Any):
30
+ """Make a prediction on an image uploaded by the user."""
31
+
32
+ # Preprocess the input image
33
+ input = preprocess_input(input)
34
 
35
  # Make a prediction
36
+ prediction = model(input)
37
 
38
+ # Get the predicted class
39
+ predicted_class = prediction.argmax(1).item()
40
 
41
+ # Return the predicted class in JSON format
42
+ return {"prediction": predicted_class}
 
43
 
44
+ if __name__ == "__main__":
45
+ import uvicorn
46
  uvicorn.run(app, host="0.0.0.0", port=8000)