SANJAYV10 commited on
Commit
a5aed99
1 Parent(s): 7263161

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -19
app.py CHANGED
@@ -1,18 +1,21 @@
1
- import fastapi
2
- from transformers
3
- import pipeline
4
- import h5py
5
 
6
- # Create a context manager to load the model from the HDF5 file
7
- with h5py.File("model_finetuned.h5", "r") as f:
8
- model = f["model"]
9
 
10
- # Define a dependency to inject the model into the `predict_endpoint()` function
11
- def get_model():
12
- return model
 
 
 
 
 
 
 
13
 
14
  # Define an endpoint to predict the output
15
- @app.post("/predict", dependencies=[Depends(get_model)])
16
  async def predict_endpoint(image: fastapi.File):
17
  # Preprocess the image
18
  image = preprocess_image(image)
@@ -23,14 +26,6 @@ async def predict_endpoint(image: fastapi.File):
23
  # Return the prediction
24
  return {"prediction": prediction}
25
 
26
- # Define a FastAPI exception handler to handle any errors that occur while making predictions
27
- @app.exception_handler(Exception)
28
- async def handle_exception(request: Request, exc: Exception):
29
- return JSONResponse(
30
- status_code=exc.status_code,
31
- content={"message": str(exc)},
32
- )
33
-
34
  # Start the FastAPI app
35
  if __name__ == "__main__":
36
  import uvicorn
 
1
+ import fastapifrom transformers import pipelineimport h5py
 
 
 
2
 
3
+ # Load the model from the h5 file
4
+ model = h5py.File("model_finetuned.h5", "r")["model"]
 
5
 
6
+ # Define a function to preprocess the image
7
+ def preprocess_image(image):
8
+ # Resize the image to a fixed size
9
+ image = image.resize((224, 224))
10
+ # Convert the image to a NumPy array
11
+ image = np.array(image)
12
+ # Normalize the image
13
+ image = image / 255.0
14
+ # Return the image
15
+ return image
16
 
17
  # Define an endpoint to predict the output
18
+ @app.post("/predict")
19
  async def predict_endpoint(image: fastapi.File):
20
  # Preprocess the image
21
  image = preprocess_image(image)
 
26
  # Return the prediction
27
  return {"prediction": prediction}
28
 
 
 
 
 
 
 
 
 
29
  # Start the FastAPI app
30
  if __name__ == "__main__":
31
  import uvicorn