SANJAYV10 commited on
Commit
d72ad25
1 Parent(s): e143977

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -0
app.py CHANGED
@@ -1,6 +1,21 @@
 
 
 
 
1
  import torch
2
  from torchvision import transforms
3
 
 
 
 
 
 
 
 
 
 
 
 
4
  def preprocess_input(input):
5
  """Preprocess the input image for the PyTorch image classification model.
6
 
@@ -49,3 +64,8 @@ async def predict_endpoint(input: fastapi.File):
49
 
50
  # Return the prediction.
51
  return Prediction(prediction=predicted_class)
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ import torch
4
+ import torch.nn as nn
5
  import torch
6
  from torchvision import transforms
7
 
8
+ class Prediction(BaseModel):
9
+ prediction: torch.Tensor
10
+
11
+ app = FastAPI()
12
+
13
+ # Load the PyTorch model
14
+ model = torch.load("best_model-epoch=01-val_loss=3.00.ckpt")
15
+
16
+ # Define a function to preprocess the input
17
+
18
+
19
  def preprocess_input(input):
20
  """Preprocess the input image for the PyTorch image classification model.
21
 
 
64
 
65
  # Return the prediction.
66
  return Prediction(prediction=predicted_class)
67
+
68
+
69
+ if __name__ == "__main__":
70
+ import uvicorn
71
+ uvicorn.run(app, host="0.0.0.0", port=8000)