SANJAYV10 commited on
Commit
367823f
1 Parent(s): 2a4dd4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -43
app.py CHANGED
@@ -9,10 +9,10 @@ from typing import Any, Type
9
  import torch
10
 
11
  class TorchTensor(torch.Tensor):
12
- pass
13
 
14
- class Prediction(BaseModel):
15
- prediction: TorchTensor
16
 
17
  app = FastAPI()
18
 
@@ -23,55 +23,50 @@ model = torch.load("best_model-epoch=01-val_loss=3.00.ckpt")
23
 
24
 
25
  def preprocess_input(input):
26
- """Preprocess the input image for the PyTorch image classification model.
 
 
 
 
 
27
 
28
- Args:
29
- input: A PIL Image object.
30
 
31
- Returns:
32
- A PyTorch tensor representing the preprocessed image.
33
- """
34
 
35
- # Resize the image to the expected size.
36
- input = input.resize((224, 224))
37
 
38
- # Convert the image to a PyTorch tensor.
39
- input = torch.from_numpy(np.array(input)).float()
40
 
41
- # Normalize the image.
42
- input = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(input)
 
 
 
 
 
 
43
 
44
- # Return the preprocessed image.
45
- return input
46
 
47
- @app.post("/predict", response_model=Prediction)
48
- async def predict_endpoint(input: fastapi.File):
49
- """Predict the output of the PyTorch image classification model.
50
 
51
- Args:
52
- input: A file containing the input image.
53
 
54
- Returns:
55
- A JSON object containing the prediction.
56
- """
57
 
58
- # Load the image.
59
- image = await input.read()
60
- image = Image.open(BytesIO(image))
61
 
62
- # Preprocess the image.
63
- image = preprocess_input(image)
64
 
65
- # Make a prediction.
66
- prediction = model(image.unsqueeze(0))
67
-
68
- # Get the top predicted class.
69
- predicted_class = prediction.argmax(1)
70
-
71
- # Return the prediction.
72
- return {"prediction": predicted_class.item()}
73
-
74
-
75
- if __name__ == "__main__":
76
- import uvicorn
77
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
9
  import torch
10
 
11
  class TorchTensor(torch.Tensor):
12
+  pass
13
 
14
+ class Prediction():
15
+  prediction: TorchTensor
16
 
17
  app = FastAPI()
18
 
 
23
 
24
 
25
  def preprocess_input(input):
26
+  """Preprocess the input image for the PyTorch image classification model.
27
+  Args:
28
+   input: A PIL Image object.
29
+  Returns:
30
+   A PyTorch tensor representing the preprocessed image.
31
+  """
32
 
33
+  # Resize the image to the expected size.
34
+  input = input.resize((224, 224))
35
 
36
+  # Convert the image to a PyTorch tensor.
37
+  input = torch.from_numpy(np.array(input)).float()
 
38
 
39
+  # Normalize the image.
40
+  input = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(input)
41
 
42
+  # Return the preprocessed image.
43
+  return input
44
 
45
+ @app.post("/predict")
46
+ async def predict_endpoint(input: Any):
47
+  """Predict the output of the PyTorch image classification model.
48
+  Args:
49
+   input: A file containing the input image.
50
+  Returns:
51
+   A JSON object containing the prediction.
52
+  """
53
 
54
+  # Load the image.
55
+  image = Image.open(BytesIO(input))
56
 
57
+  # Preprocess the image.
58
+  image = preprocess_input(image)
 
59
 
60
+  # Make a prediction.
61
+  prediction = model(image.unsqueeze(0))
62
 
63
+  # Get the top predicted class.
64
+  predicted_class = prediction.argmax(1)
 
65
 
66
+  # Return the prediction.
67
+  return {"prediction": predicted_class.item()}
 
68
 
 
 
69
 
70
+ if _name_ == "_main_":
71
+  import uvicorn
72
+  uvicorn.run(app, host="0.0.0.0", port=8000)