SANJAYV10 commited on
Commit
32a363f
1 Parent(s): 92d628f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -15
app.py CHANGED
@@ -2,16 +2,22 @@ import fastapi
2
  from transformers import pipeline
3
  import h5py
4
 
5
- # Create a context manager to load the model from the HDF5 file
6
- with h5py.File("model_finetuned.h5", "r") as f:
7
- model = f["model"]
8
 
9
- # Define a dependency to inject the model into the `predict_endpoint()` function
10
- def get_model():
11
- return model
 
 
 
 
 
 
 
12
 
13
  # Define an endpoint to predict the output
14
- @app.post("/predict", dependencies=[Depends(get_model)])
15
  async def predict_endpoint(image: fastapi.File):
16
  # Preprocess the image
17
  image = preprocess_image(image)
@@ -22,14 +28,6 @@ async def predict_endpoint(image: fastapi.File):
22
  # Return the prediction
23
  return {"prediction": prediction}
24
 
25
- # Define a FastAPI exception handler to handle any errors that occur while making predictions
26
- @app.exception_handler(Exception)
27
- async def handle_exception(request: Request, exc: Exception):
28
- return JSONResponse(
29
- status_code=exc.status_code,
30
- content={"message": str(exc)},
31
- )
32
-
33
  # Start the FastAPI app
34
  if __name__ == "__main__":
35
  import uvicorn
 
2
  from transformers import pipeline
3
  import h5py
4
 
5
+ # Load the model from the h5 file
6
+ model = h5py.File("model_finetuned.h5", "r")["model"]
 
7
 
8
+ # Define a function to preprocess the image
9
+ def preprocess_image(image):
10
+ # Resize the image to a fixed size
11
+ image = image.resize((224, 224))
12
+ # Convert the image to a NumPy array
13
+ image = np.array(image)
14
+ # Normalize the image
15
+ image = image / 255.0
16
+ # Return the image
17
+ return image
18
 
19
  # Define an endpoint to predict the output
20
+ @app.post("/predict")
21
  async def predict_endpoint(image: fastapi.File):
22
  # Preprocess the image
23
  image = preprocess_image(image)
 
28
  # Return the prediction
29
  return {"prediction": prediction}
30
 
 
 
 
 
 
 
 
 
31
  # Start the FastAPI app
32
  if __name__ == "__main__":
33
  import uvicorn