AnishKumbhar commited on
Commit
4499fe6
1 Parent(s): cc9c350

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -21
app.py CHANGED
@@ -1,39 +1,38 @@
1
  import fastapi
2
- from transformers import pipeline
3
- import pickle
4
 
5
- # Load the model from the pickle file
6
- with open("model_finetuned.pkl", "rb") as f:
7
- model = pickle.load(f)
8
 
9
  # Define a function to preprocess the image
10
  def preprocess_image(image):
11
- # Resize the image to a fixed size
12
- image = image.resize((224, 224))
13
 
14
- # Convert the image to a NumPy array
15
- image = np.array(image)
16
 
17
- # Normalize the image
18
- image = image / 255.0
19
 
20
- # Return the image
21
- return image
22
 
23
  # Define an endpoint to predict the output
24
  @app.post("/predict")
25
  async def predict_endpoint(image: fastapi.File):
26
- # Preprocess the image
27
- image = preprocess_image(image)
28
 
29
- # Make a prediction
30
- prediction = model(image)
31
 
32
- # Return the prediction
33
- return {"prediction": prediction}
34
 
35
  # Start the FastAPI app
36
  if _name_ == "_main_":
37
- import uvicorn
38
 
39
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  import fastapi
2
+ import torch
3
+ from transformers import AutoModelForImageClassification
4
 
5
+ # Load the model from the local file
6
+ model = AutoModelForImageClassification.from_pretrained("./model.ckpt")
 
7
 
8
  # Define a function to preprocess the image
9
  def preprocess_image(image):
10
+   # Resize the image to a fixed size
11
+   image = image.resize((224, 224))
12
 
13
+   # Convert the image to a NumPy array
14
+   image = np.array(image)
15
 
16
+   # Normalize the image
17
+   image = image / 255.0
18
 
19
+   # Return the image
20
+   return image
21
 
22
  # Define an endpoint to predict the output
23
  @app.post("/predict")
24
  async def predict_endpoint(image: fastapi.File):
25
+   # Preprocess the image
26
+   image = preprocess_image(image)
27
 
28
+   # Make a prediction
29
+   prediction = model(image)
30
 
31
+   # Return the prediction
32
+   return {"prediction": prediction}
33
 
34
  # Start the FastAPI app
35
  if _name_ == "_main_":
36
+   import uvicorn
37
 
38
+   uvicorn.run(app, host="0.0.0.0", port=8000)