File size: 1,142 Bytes
4ecc856
fbb5992
4ecc856
 
 
8b7cda3
4ecc856
 
8b7cda3
4ecc856
 
fbb5992
4ecc856
fbb5992
4ecc856
fbb5992
4ecc856
 
 
 
 
 
 
 
 
29c50fd
4ecc856
29c50fd
4ecc856
 
 
 
 
8b7cda3
fbb5992
4ecc856
29c50fd
4ecc856
 
29c50fd
4ecc856
 
29c50fd
4ecc856
 
fbb5992
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from io import BytesIO
from typing import Any, Type

import fastapi
import numpy as np
import torch
from PIL import Image

class TorchTensor(torch.Tensor):
    """Subclass of ``torch.Tensor`` that adds no members of its own."""

class Prediction:
    """Container type declaring a single annotated field, ``prediction``."""

    # Annotation only: no value is assigned, so no class attribute is created.
    prediction: TorchTensor

# ASGI application object; the /predict route is registered on it by decorator.
app = fastapi.FastAPI()

# Load the model once at import time, mapping its tensors onto the CPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source.
# NOTE(review): the result is called directly as model(...), so the file
# presumably holds a whole pickled module rather than a state_dict; recent
# PyTorch releases default to weights_only=True, which rejects that -- confirm
# against the installed version.
# NOTE(review): the model is never switched to eval mode; if it contains
# dropout or batch-norm layers, consider calling model.eval() before serving.
model = torch.load("model67.bin", map_location='cpu')

# Define a function to preprocess the input image
def preprocess_input(input: Any):
    """Decode encoded image bytes into a single-image float tensor batch.

    Args:
        input: Contents of an encoded image file as a bytes-like object
            (anything ``io.BytesIO`` accepts).

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 3, 224, 224)``: one image,
        channels first, resized to 224x224.

    Raises:
        TypeError: If ``input`` is not bytes-like.
        OSError: If Pillow cannot identify or decode the image data.
    """
    image = Image.open(BytesIO(input))
    # Force exactly three channels. Without this, grayscale/palette images
    # produce a 2-D array (so the permute below fails) and RGBA images
    # produce four channels.
    image = image.convert("RGB").resize((224, 224))

    pixels = np.array(image)  # height x width x channel
    batch = torch.from_numpy(pixels).float()
    # NOTE(review): pixel values are left in the 0-255 range with no
    # normalization -- confirm this matches what the model was trained on.
    # Reorder to channel x height x width, then add a leading batch axis.
    return batch.permute(2, 0, 1).unsqueeze(0)

# Define an endpoint to make predictions
@app.post("/predict")
async def predict_endpoint(
    input: bytes = fastapi.Body(..., media_type="application/octet-stream"),
):
    """Classify an image sent as the raw body of a POST request.

    The parameter is declared as a request body: a bare ``Any`` annotation
    is treated by FastAPI as a query-string parameter, which cannot carry
    the image bytes that ``preprocess_input`` needs.

    Args:
        input: Encoded image file contents. Clients should POST the bytes
            directly with a non-JSON content type such as
            ``application/octet-stream``.

    Returns:
        A dict ``{"prediction": <int>}`` holding the index of the
        highest-scoring class.

    Raises:
        fastapi.HTTPException: With status 400 when the body cannot be
            decoded as an image.
    """
    # NOTE(review): confirm on the deployed FastAPI version that non-JSON
    # request bodies are handed to a ``bytes`` Body parameter unparsed.
    try:
        batch = preprocess_input(input)
    except (OSError, TypeError, ValueError) as exc:
        # Bad client input should be a 400, not an unhandled 500.
        raise fastapi.HTTPException(
            status_code=400,
            detail="Request body is not a decodable image.",
        ) from exc

    # Inference only: skip autograd bookkeeping.
    # NOTE(review): this runs synchronously inside an async handler, so the
    # event loop is blocked for the duration of the forward pass.
    with torch.no_grad():
        scores = model(batch)

    # Index of the largest score along dimension 1, as a plain Python int.
    predicted_class = scores.argmax(1).item()

    return {"prediction": predicted_class}

if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    from uvicorn import run as serve

    # Serve the app on port 8000, listening on all network interfaces.
    serve(app, port=8000, host="0.0.0.0")