SANJAYV10 commited on
Commit
75a386f
1 Parent(s): e2590b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -41
app.py CHANGED
@@ -25,50 +25,31 @@ model = torch.load("best_model-epoch=01-val_loss=3.00.ckpt")
25
 
26
 
27
  def preprocess_input(input):
28
-  """Preprocess the input image for the PyTorch image classification model.
29
-  Args:
30
-  input: A PIL Image object.
31
-  Returns:
32
-  A PyTorch tensor representing the preprocessed image.
33
-  """
34
-
35
-  # Resize the image to the expected size.
36
-  input = input.resize((224, 224))
37
-
38
-  # Convert the image to a PyTorch tensor.
39
-  input = torch.from_numpy(np.array(input)).float()
40
-
41
-  # Normalize the image.
42
-  input = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(input)
43
-
44
-  # Return the preprocessed image.
45
-  return input
46
 
47
  @app.post("/predict")
48
  async def predict_endpoint(input: Any):
49
-  """Predict the output of the PyTorch image classification model.
50
-  Args:
51
-  input: A file containing the input image.
52
-  Returns:
53
-  A JSON object containing the prediction.
54
-  """
55
-
56
-  # Load the image.
57
-  image = Image.open(BytesIO(input))
58
-
59
-  # Preprocess the image.
60
-  image = preprocess_input(image)
61
-
62
-  # Make a prediction.
63
-  prediction = model(image.unsqueeze(0))
64
-
65
-  # Get the top predicted class.
66
-  predicted_class = prediction.argmax(1)
67
-
68
-  # Return the prediction.
69
-  return {"prediction": predicted_class.item()}
70
 
71
 
72
  if _name_ == "_main_":
73
-  import uvicorn
74
-  uvicorn.run(app, host="0.0.0.0", port=8000)
 
25
 
26
 
27
  def preprocess_input(input):
28
+
29
+
30
+ input = input.resize((224, 224))
31
+
32
+ input = torch.from_numpy(np.array(input)).float()
33
+
34
+ input = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(input)
35
+
36
+ return input
 
 
 
 
 
 
 
 
 
37
 
38
  @app.post("/predict")
39
  async def predict_endpoint(input: Any):
40
+
41
+
42
+ image = Image.open(BytesIO(input))
43
+
44
+ image = preprocess_input(image)
45
+
46
+ prediction = model(image.unsqueeze(0))
47
+
48
+ predicted_class = prediction.argmax(1)
49
+
50
+ return {"prediction": predicted_class.item()}
 
 
 
 
 
 
 
 
 
 
51
 
52
 
53
  if _name_ == "_main_":
54
+ import uvicorn
55
+ uvicorn.run(app, host="0.0.0.0", port=8000)