Spaces:
Runtime error
Runtime error
File size: 1,050 Bytes
010fea9 d72ad25 e143977 c2551f6 06ba20f 208c137 06ba20f 77aaa7a e2590b7 d72ad25 e143977 75a386f e143977 367823f 75a386f e143977 367823f 75a386f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
# app.py
from typing import Any, Type

import torch
import torch.nn as nn
from fastapi import FastAPI
from fastapi import Body, HTTPException
from torchvision import transforms
from torchvision.io import ImageReadMode, decode_image
class TorchTensor(torch.Tensor):
    """Named subclass of ``torch.Tensor`` that adds no behaviour of its own.

    It exists only so that other declarations in this module can refer to a
    tensor type by a module-local name.
    """
class Prediction:
    """Container type describing a model prediction.

    Only an annotation is declared; the class defines no ``__init__`` and
    assigns no default, so instances start without a ``prediction`` attribute
    until one is set explicitly.
    """

    # Tensor holding the model output for one request.
    prediction: TorchTensor
app = FastAPI()

# Load the PyTorch model once at import time so every request reuses it.
#
# map_location="cpu": the request path builds its input tensors on the CPU, so
# the weights must live there too; it also lets a checkpoint that was saved on
# a GPU machine load on a CPU-only host.
#
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only ship checkpoints from a trusted source.
# NOTE(review): a ".ckpt" written by a training framework is frequently a dict
# of state rather than a callable nn.Module. If that is the case here, the
# model class must be instantiated and load_state_dict() applied before the
# endpoint can call it. Confirm against the training code.
model = torch.load("best_model-epoch=01-val_loss=3.00.ckpt", map_location="cpu")
if isinstance(model, nn.Module):
    # Switch dropout / batch-norm layers to inference behaviour.
    model.eval()
def preprocess_input(input):
    """Turn a PIL image into a normalized tensor ready for the model.

    Args:
        input: A ``PIL.Image.Image`` (anything exposing ``convert`` and
            ``resize`` with PIL semantics).

    Returns:
        A ``float32`` tensor of shape ``(3, 224, 224)`` in channel-first
        layout, normalized per channel with the mean/std constants below.
    """
    # Force exactly three channels: grayscale, palette and RGBA images would
    # otherwise not match the three-channel statistics used by Normalize.
    image = input.convert("RGB").resize((224, 224))

    # ToTensor converts HWC uint8 [0, 255] into CHW float32 [0.0, 1.0].
    # Normalize works on channel-first tensors, and these mean/std values are
    # the usual ImageNet statistics for inputs scaled to [0, 1].
    # NOTE(review): assumes the model was trained with the same scaling --
    # confirm against the training pipeline.
    tensor = transforms.ToTensor()(image)
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    return normalize(tensor)
@app.post("/predict")
async def predict_endpoint(
    input: bytes = Body(..., media_type="application/octet-stream"),
):
    """Classify an uploaded image and return the predicted class index.

    The encoded image (for example JPEG or PNG) is read from the raw request
    body. Clients should send it with a non-JSON content type such as
    ``application/octet-stream`` so the body is passed through as bytes.

    Args:
        input: Encoded image bytes.

    Returns:
        ``{"prediction": <class index>}`` where the index is the argmax of
        the model output along dimension 1.

    Raises:
        HTTPException: With status 400 when the body is empty or cannot be
            decoded as an image.
    """
    if not input:
        raise HTTPException(status_code=400, detail="Request body is empty.")

    try:
        # Copy into a bytearray: torch.frombuffer warns on read-only buffers
        # such as ``bytes``.
        encoded = torch.frombuffer(bytearray(input), dtype=torch.uint8)
        # ImageReadMode.RGB yields a (3, H, W) uint8 tensor whatever the
        # channel layout of the source file.
        decoded = decode_image(encoded, mode=ImageReadMode.RGB)
    except (RuntimeError, ValueError) as exc:
        raise HTTPException(
            status_code=400,
            detail="Request body is not a decodable image.",
        ) from exc

    # preprocess_input works on PIL images, so convert before handing over.
    image = preprocess_input(transforms.ToPILImage()(decoded))

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension.
        # NOTE(review): assumes the model returns (batch, num_classes) scores.
        prediction = model(image.unsqueeze(0))
    predicted_class = prediction.argmax(1)
    return {"prediction": predicted_class.item()}
# Start a development server only when this file is executed directly.
# The dunder names matter: the single-underscore spelling is an undefined
# name and raises NameError as soon as the module is imported.
if __name__ == "__main__":
    # Imported here so that merely importing this module (for example by an
    # external ASGI server) does not require uvicorn to be importable.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
|