File size: 999 Bytes
e143977
0f3e8d6
60712d2
0f3e8d6
06ba20f
 
208c137
06ba20f
77aaa7a
e2590b7
d72ad25
bfbcab4
d72ad25
bfbcab4
0f3e8d6
 
 
bfbcab4
60712d2
 
 
 
75a386f
e143977
60712d2
367823f
bfbcab4
60712d2
 
 
 
 
 
 
 
bfbcab4
60712d2
e143977
60712d2
bfbcab4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import numpy as np
import torch
from fastapi import FastAPI, HTTPException, UploadFile
from PIL import Image, UnidentifiedImageError

class TorchTensor(torch.Tensor):
    """Marker subclass of ``torch.Tensor``; it adds no attributes or behavior."""

class Prediction:
    """Type declaring a single ``prediction`` tensor field.

    Only a class-level annotation is declared: there is no ``__init__`` and no
    default value, so instances do not receive a ``prediction`` attribute
    automatically.
    """

    # Annotation only (recorded in __annotations__); nothing is assigned here.
    # NOTE(review): the /predict endpoint in this file returns a plain dict,
    # not this class — confirm whether it is still needed.
    prediction: TorchTensor

# The FastAPI application; its interactive Swagger UI docs are served at "/".
# Uses the imported ``FastAPI`` name directly: the ``fastapi`` module itself
# is not imported, so ``fastapi.FastAPI`` would raise NameError at import time.
app = FastAPI(docs_url="/")

# Load the classifier onto the CPU once, at import time.
# NOTE(review): torch.load unpickles arbitrary Python objects — only load
# checkpoints from a trusted source. On PyTorch >= 2.6 ``weights_only``
# defaults to True, which rejects a fully pickled nn.Module; since this object
# is called directly as a model, pass weights_only=False if loading fails there.
model = torch.load("best_model.pth", map_location='cpu')
# Put dropout/batch-norm layers into inference behavior; otherwise the model
# keeps whatever mode it was pickled in.
model.eval()

def preprocess_input(input: UploadFile):
    """Decode an uploaded image into a batched float tensor for the model.

    The image is converted to 3-channel RGB, resized to 224x224, and turned
    into a float32 tensor with a leading batch dimension of size 1.

    Args:
        input: The uploaded file; its ``.file`` stream is decoded by PIL.

    Returns:
        A float32 tensor of shape (1, 224, 224, 3) holding raw 0-255 pixel
        values in height/width/channel order.

    Raises:
        PIL.UnidentifiedImageError: If the stream is not a recognised image.
    """
    # The context manager drops PIL's hold on the stream even if a later
    # step fails; convert()/resize() return new, fully loaded images.
    with Image.open(input.file) as image:
        # Force RGB so grayscale, palette and RGBA uploads all yield exactly
        # three channels; otherwise the tensor's rank depends on the file.
        resized = image.convert("RGB").resize((224, 224))

    # np.array (not asarray) gives a writable copy, which torch.from_numpy
    # accepts without a non-writable-array warning.
    pixels = np.array(resized)
    # NOTE(review): the tensor stays in HWC order with unscaled 0-255 values.
    # Most PyTorch vision models expect (N, C, H, W) and normalized input
    # (e.g. ``.permute(2, 0, 1)`` and ``/ 255``); confirm against how
    # best_model.pth was trained before changing this.
    tensor = torch.from_numpy(pixels).float()
    return tensor.unsqueeze(0)

# Define an endpoint to make predictions
@app.post("/predict")
async def predict_endpoint(input: UploadFile):
    """Classify an uploaded image and return the predicted class index.

    Args:
        input: The image uploaded by the client. The parameter name doubles
            as the multipart form field name, so it must stay ``input``.

    Returns:
        A JSON-serializable dict ``{"prediction": <int class index>}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    try:
        batch = preprocess_input(input)
    except UnidentifiedImageError as err:
        # A non-image upload is a client error, not an internal server error.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from err

    # Inference only: don't build an autograd graph for the forward pass.
    # NOTE(review): this is an ``async def`` with no awaits, so the forward
    # pass blocks the event loop while it runs.
    with torch.no_grad():
        logits = model(batch)

    # NOTE(review): assumes the model returns (batch, num_classes) scores.
    predicted_class = logits.argmax(1).item()

    # Return the predicted class in JSON format.
    return {"prediction": predicted_class}