File size: 2,593 Bytes
d72ad25
 
 
 
e143977
 
c2551f6
06ba20f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d72ad25
 
 
 
 
 
 
 
 
 
 
e143977
 
c4de414
e143977
 
c4de414
e143977
 
 
c5f6a7a
e143977
 
c5f6a7a
e143977
 
 
 
 
 
 
 
c5f6a7a
c4de414
e143977
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d72ad25
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from io import BytesIO
from typing import Any, Type

import pydantic
import torch
import torch
import torch
import torch.nn as nn
from fastapi import FastAPI
from fastapi import File, HTTPException, UploadFile
from PIL import Image
from pydantic import BaseModel
from torchvision import transforms

class TensorSchema(pydantic.BaseModel):
    """Serializable description of a tensor's metadata.

    Attributes:
        shape: Size of each dimension.
        dtype: Name of the tensor's data type.
        requires_grad: Whether the tensor tracks gradients.
    """

    # Pydantic derives the schema from these annotations on its own: an object
    # with integer-array "shape", string "dtype" and boolean "requires_grad",
    # all required. No __get_pydantic_core_schema__ override is needed (and
    # overriding it on a BaseModel would replace the model's own schema).
    shape: list[int]
    dtype: str
    requires_grad: bool

class TorchTensor(torch.Tensor):
    """Annotation type for ``torch.Tensor`` fields in pydantic (v2) models.

    Use it as a field type (``values: TorchTensor``). Validation keeps existing
    tensors unchanged and converts anything else with ``torch.as_tensor``;
    serialization emits the values via ``Tensor.tolist()`` (a plain number for
    a 0-d tensor, nested lists otherwise).
    """

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: Any) -> Any:
        """Build the pydantic-core schema: custom validator plus list serializer."""
        # pydantic_core ships with pydantic v2, the version that defines this hook.
        from pydantic_core import core_schema

        return core_schema.no_info_plain_validator_function(
            cls._validate,
            serialization=core_schema.plain_serializer_function_ser_schema(
                cls._serialize
            ),
        )

    @classmethod
    def __get_pydantic_json_schema__(cls, schema: Any, handler: Any) -> dict[str, Any]:
        """Describe the serialized form for JSON schema / OpenAPI generation."""
        # A plain-validator core schema has no JSON schema of its own, so
        # describe what _serialize produces instead of delegating to handler.
        return {"description": "Tensor values: a number, or nested arrays of numbers."}

    @staticmethod
    def _validate(value: Any) -> torch.Tensor:
        """Return ``value`` as a tensor.

        Raises:
            ValueError: If ``torch.as_tensor`` cannot convert ``value``
                (pydantic turns ValueError into a ValidationError).
        """
        if isinstance(value, torch.Tensor):
            return value
        try:
            return torch.as_tensor(value)
        except (TypeError, ValueError, RuntimeError) as err:
            raise ValueError(
                f"cannot convert {type(value).__name__} to a tensor"
            ) from err

    @staticmethod
    def _serialize(tensor: torch.Tensor) -> Any:
        """Convert ``tensor`` to plain Python numbers/lists for JSON output."""
        return tensor.tolist()


class Prediction(BaseModel):
    """Response body of ``/predict``: predicted class indices.

    ``prediction`` accepts a ``torch.Tensor`` (the endpoint passes the
    ``argmax`` result) and stores it as a flat list of Python ints so FastAPI
    can serialize it to JSON.
    """

    prediction: list[int]

    @pydantic.field_validator("prediction", mode="before")
    @classmethod
    def _tensor_to_list(cls, value: Any) -> Any:
        """Convert tensor input to a flat Python list before int validation."""
        # A raw torch.Tensor field would make pydantic refuse to build the model
        # (arbitrary type) and could not be JSON-encoded, so convert here.
        # reshape(-1) also turns a 0-d tensor into a one-element list.
        if isinstance(value, torch.Tensor):
            return value.reshape(-1).tolist()
        return value

app = FastAPI()

# Load the PyTorch model once, at import time, onto the CPU (the endpoint feeds
# it CPU tensors).
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
# torch 2.6+ defaults to weights_only=True, which refuses a pickled nn.Module;
# pass weights_only=False if that is what this file contains.
model = torch.load("best_model-epoch=01-val_loss=3.00.ckpt", map_location="cpu")
if not isinstance(model, nn.Module):
    # A "*.ckpt" file may be a training checkpoint (e.g. a dict holding a
    # state_dict) rather than a model; fail fast with a clear message instead
    # of "object is not callable" on the first request.
    raise TypeError(
        "Expected the checkpoint to contain a torch.nn.Module, got "
        f"{type(model).__name__}; rebuild the model and load its state_dict instead."
    )
# Switch dropout / batch-norm layers (if any) to inference behavior.
model.eval()

# Define a function to preprocess the input


def preprocess_input(input):
    """Preprocess an input image for the PyTorch image classification model.

    Args:
        input: A PIL Image object.

    Returns:
        A float tensor of shape (3, 224, 224), normalized with the per-channel
        mean/std below.
    """
    # Force exactly three channels: grayscale, RGBA or palette images would not
    # match the 3-element mean/std used by Normalize.
    image = input.convert("RGB")

    # Resize the image to the expected size.
    image = image.resize((224, 224))

    # ToTensor turns the (H, W, C) 0-255 image into a (C, H, W) float tensor in
    # [0, 1] -- the layout and range Normalize expects.
    tensor = transforms.ToTensor()(image)

    # Normalize each channel with the mean/std constants.
    return transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )(tensor)

@app.post("/predict", response_model=Prediction)
async def predict_endpoint(input: UploadFile = File(...)):
    """Predict the class of an uploaded image with the PyTorch model.

    Args:
        input: The uploaded image file (multipart form field ``input``).

    Returns:
        A ``Prediction`` holding the arg-max class index of the model output.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    raw_bytes = await input.read()

    # Image.open only reads the header; load() forces a full decode so corrupt
    # or non-image uploads surface here as a client error instead of a 500.
    try:
        image = Image.open(BytesIO(raw_bytes))
        image.load()
    except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # Preprocess, then add a leading batch dimension of size 1.
    batch = preprocess_input(image).unsqueeze(0)

    # Inference only: no_grad skips building the autograd graph.
    # NOTE(review): this synchronous model call blocks the event loop inside an
    # async handler; consider fastapi.concurrency.run_in_threadpool under load.
    with torch.no_grad():
        scores = model(batch)

    # Index of the highest score along dim 1
    # (assumes a (batch, num_classes) output -- TODO confirm).
    predicted_class = scores.argmax(1)

    return Prediction(prediction=predicted_class)


if __name__ == "__main__":
    # Serve the app with uvicorn when this module is executed as a script.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)