SANJAYV10 commited on
Commit
7263161
1 Parent(s): 2eea5a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -14
app.py CHANGED
@@ -1,31 +1,36 @@
1
  import fastapi
2
- from transformers import pipeline
 
3
  import h5py
4
 
5
- # Load the model from the h5 file
6
- model = h5py.File("model_finetuned.h5", "r")["model"]
 
7
 
8
- # Define a function to preprocess the image
9
- def preprocess_image(image):
10
- # Resize the image to a fixed size
11
- image = image.resize((224, 224))
12
- # Convert the image to a NumPy array
13
- image = np.array(image)
14
- # Normalize the image
15
- image = image / 255.0
16
- # Return the image
17
- return image
18
 
19
  # Define an endpoint to predict the output
20
- @app.post("/predict")
21
  async def predict_endpoint(image: fastapi.File):
22
  # Preprocess the image
23
  image = preprocess_image(image)
 
24
  # Make a prediction
25
  prediction = model.predict(image)
 
26
  # Return the prediction
27
  return {"prediction": prediction}
28
 
 
 
 
 
 
 
 
 
29
  # Start the FastAPI app
30
  if __name__ == "__main__":
31
  import uvicorn
 
1
  import fastapi
2
+ from transformers
3
+ import pipeline
4
  import h5py
5
 
6
+ # Create a context manager to load the model from the HDF5 file
7
+ with h5py.File("model_finetuned.h5", "r") as f:
8
+ model = f["model"]
9
 
10
+ # Define a dependency to inject the model into the `predict_endpoint()` function
11
+ def get_model():
12
+ return model
 
 
 
 
 
 
 
13
 
14
  # Define an endpoint to predict the output
15
+ @app.post("/predict", dependencies=[Depends(get_model)])
16
  async def predict_endpoint(image: fastapi.File):
17
  # Preprocess the image
18
  image = preprocess_image(image)
19
+
20
  # Make a prediction
21
  prediction = model.predict(image)
22
+
23
  # Return the prediction
24
  return {"prediction": prediction}
25
 
26
+ # Define a FastAPI exception handler to handle any errors that occur while making predictions
27
+ @app.exception_handler(Exception)
28
+ async def handle_exception(request: Request, exc: Exception):
29
+ return JSONResponse(
30
+ status_code=exc.status_code,
31
+ content={"message": str(exc)},
32
+ )
33
+
34
  # Start the FastAPI app
35
  if __name__ == "__main__":
36
  import uvicorn