Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,21 @@
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
from torchvision import transforms
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
def preprocess_input(input):
|
5 |
"""Preprocess the input image for the PyTorch image classification model.
|
6 |
|
@@ -49,3 +64,8 @@ async def predict_endpoint(input: fastapi.File):
|
|
49 |
|
50 |
# Return the prediction.
|
51 |
return Prediction(prediction=predicted_class)
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI
|
2 |
+
from pydantic import BaseModel
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
import torch
|
6 |
from torchvision import transforms
|
7 |
|
8 |
+
class Prediction(BaseModel):
|
9 |
+
prediction: torch.Tensor
|
10 |
+
|
11 |
+
app = FastAPI()
|
12 |
+
|
13 |
+
# Load the PyTorch model
|
14 |
+
model = torch.load("best_model-epoch=01-val_loss=3.00.ckpt")
|
15 |
+
|
16 |
+
# Define a function to preprocess the input
|
17 |
+
|
18 |
+
|
19 |
def preprocess_input(input):
|
20 |
"""Preprocess the input image for the PyTorch image classification model.
|
21 |
|
|
|
64 |
|
65 |
# Return the prediction.
|
66 |
return Prediction(prediction=predicted_class)
|
67 |
+
|
68 |
+
|
69 |
+
if __name__ == "__main__":
|
70 |
+
import uvicorn
|
71 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|