File size: 1,092 Bytes
4ecc856
fbb5992
4ecc856
 
8b7cda3
4ecc856
 
8b7cda3
4ecc856
 
fbb5992
ff48a9b
fbb5992
77684da
 
2a9e83f
4ecc856
70485e0
 
14735ff
 
 
 
4ecc856
29c50fd
4ecc856
29c50fd
ca9a717
4ecc856
 
 
 
8b7cda3
fbb5992
4ecc856
29c50fd
4ecc856
 
29c50fd
4ecc856
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import fastapi
import numpy as np
from PIL import Image

class TorchTensor(torch.Tensor):
    """Thin subclass of ``torch.Tensor`` used as a type annotation.

    It adds no attributes or behavior of its own.
    """

class Prediction:
    """Annotation-only description of a prediction result.

    NOTE(review): this is a plain class, not a dataclass or pydantic model.
    The annotation below declares a field, but no ``__init__`` and no default
    value are generated. The visible ``/predict`` endpoint returns a plain
    dict and never constructs this class, so confirm whether it is still
    needed.
    """

    # Declared type of the prediction tensor. This is an annotation only;
    # no class attribute is assigned.
    prediction: TorchTensor

# Serve the interactive API docs at the site root.
app = fastapi.FastAPI(docs_url="/")

# Load the classifier once at import time; the /predict handler reuses this
# module-level object for every request. An earlier checkpoint,
# "model67.bin", was loaded the same way.
# torch.load unpickles arbitrary objects, so only ever point it at trusted
# checkpoint files.
# NOTE(review): on torch >= 2.6, torch.load defaults to weights_only=True,
# which rejects fully pickled modules. If this checkpoint is a whole
# nn.Module, pass weights_only=False explicitly there.
model = torch.load("best_model.pth", map_location='cpu')
# The service only runs forward passes. Put layers such as dropout and
# batch-norm into inference behavior instead of the default training mode.
model.eval()
print(model)
def preprocess_input(input: fastapi.UploadFile, size=(224, 224)):
    """Decode an uploaded image into a batched float tensor for ``model``.

    Args:
        input: Uploaded image file; only its ``.file`` stream is read.
        size: ``(width, height)`` that the image is resized to. The default
            matches the previously hard-coded 224x224.

    Returns:
        A float32 tensor of shape ``(1, height, width, 3)``. It holds raw
        0-255 RGB pixel values in HWC layout, with a leading batch dimension.

    Raises:
        PIL.UnidentifiedImageError: if the upload is not a decodable image.
    """
    with Image.open(input.file) as image:
        # Normalize the pixel format first. Without this, RGBA, greyscale and
        # palette uploads produce arrays with 4 channels, no channel axis, or
        # palette indices instead of colors. The tensor's shape and meaning
        # would then depend on the uploaded file format.
        # NOTE(review): assumes the model was trained on 3-channel RGB input.
        rgb = image.convert("RGB").resize(size)
    # np.array copies, so the resulting array is writable and safe for
    # torch.from_numpy.
    pixels = np.array(rgb)
    # NOTE(review): torch conv layers normally expect NCHW. This keeps the
    # original HWC layout and applies no /255 or mean/std normalization.
    # Confirm both against the model's training pipeline.
    return torch.from_numpy(pixels).float().unsqueeze(0)

@app.post("/predict")
async def predict_endpoint(input:fastapi.UploadFile):
    """Make a prediction on an image uploaded by the user.

    Returns ``{"prediction": <int>}``. The int is the index of the
    highest-scoring entry along dim 1 of the model output. Responds with
    HTTP 400 if the upload cannot be decoded as an image.
    """
    # The parameter keeps the name ``input`` because FastAPI uses it as the
    # multipart form-field name that clients send. The tensor therefore gets
    # its own local name instead of rebinding the parameter.
    try:
        batch = preprocess_input(input)
    except Image.UnidentifiedImageError as err:
        # A bad upload is a client error, not an internal server error.
        raise fastapi.HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # Inference only: disable autograd so each request does not build, and
    # hold memory for, a gradient graph.
    # NOTE(review): this CPU-bound work runs directly inside an ``async``
    # handler and blocks the event loop while it executes. Consider a plain
    # ``def`` endpoint so that FastAPI runs it in its threadpool.
    with torch.no_grad():
        logits = model(batch)

    # Assumes the model returns (batch, num_classes) scores. The batch size
    # here is 1, so .item() yields a single Python int.
    predicted_class = logits.argmax(1).item()

    return {"prediction": predicted_class}