SANJAYV10 commited on
Commit
92d628f
1 Parent(s): 6d489a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -13
app.py CHANGED
@@ -2,22 +2,16 @@ import fastapi
2
  from transformers import pipeline
3
  import h5py
4
 
5
- # Load the model from the h5 file
6
- model = h5py.File("model_finetuned.h5", "r")["model"]
 
7
 
8
- # Define a function to preprocess the image
9
- def preprocess_image(image):
10
- # Resize the image to a fixed size
11
- image = image.resize((224, 224))
12
- # Convert the image to a NumPy array
13
- image = np.array(image)
14
- # Normalize the image
15
- image = image / 255.0
16
- # Return the image
17
- return image
18
 
19
  # Define an endpoint to predict the output
20
- @app.post("/predict")
21
  async def predict_endpoint(image: fastapi.File):
22
  # Preprocess the image
23
  image = preprocess_image(image)
@@ -28,6 +22,14 @@ async def predict_endpoint(image: fastapi.File):
28
  # Return the prediction
29
  return {"prediction": prediction}
30
 
 
 
 
 
 
 
 
 
31
  # Start the FastAPI app
32
  if __name__ == "__main__":
33
  import uvicorn
 
2
  from transformers import pipeline
3
  import h5py
4
 
5
+ # Create a context manager to load the model from the HDF5 file
6
+ with h5py.File("model_finetuned.h5", "r") as f:
7
+ model = f["model"]
8
 
9
+ # Define a dependency to inject the model into the `predict_endpoint()` function
10
+ def get_model():
11
+ return model
 
 
 
 
 
 
 
12
 
13
  # Define an endpoint to predict the output
14
+ @app.post("/predict", dependencies=[Depends(get_model)])
15
  async def predict_endpoint(image: fastapi.File):
16
  # Preprocess the image
17
  image = preprocess_image(image)
 
22
  # Return the prediction
23
  return {"prediction": prediction}
24
 
25
+ # Define a FastAPI exception handler to handle any errors that occur while making predictions
26
+ @app.exception_handler(Exception)
27
+ async def handle_exception(request: Request, exc: Exception):
28
+ return JSONResponse(
29
+ status_code=exc.status_code,
30
+ content={"message": str(exc)},
31
+ )
32
+
33
  # Start the FastAPI app
34
  if __name__ == "__main__":
35
  import uvicorn