mateoluksenberg commited on
Commit
9309109
1 Parent(s): 8f9e0a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -18
app.py CHANGED
@@ -1,4 +1,5 @@
1
- from fastapi import FastAPI, HTTPException
 
2
  from transformers import pipeline
3
  from PIL import Image
4
  import io
@@ -8,14 +9,13 @@ app = FastAPI()
8
  # Load the image classification pipeline
9
  pipe = pipeline("image-classification", model="mateoluksenberg/dit-base-Classifier_CM05")
10
 
11
- # Sample image path (for testing)
12
- image_path = 'cm5.jpg'
13
-
14
- # Async function to classify an image
15
- async def classify_image(image_path: str):
16
  try:
17
- image = Image.open(image_path).convert('RGB')
 
18
 
 
19
  image_bytes = io.BytesIO()
20
  image.save(image_bytes, format='JPEG')
21
  image_bytes = image_bytes.getvalue()
@@ -23,21 +23,16 @@ async def classify_image(image_path: str):
23
  # Perform image classification
24
  result = pipe(image_bytes)
25
 
26
- return result[0] # Return the top prediction
27
 
28
  except Exception as e:
29
  # Handle exceptions, for example: file not found, image format issues, etc.
30
  raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
31
 
32
  @app.get("/")
33
- async def home(image_path: str = image_path):
34
- try:
35
- result = await classify_image(image_path)
36
- return {"message": "Hello World", "classification_result": result}
37
-
38
- except HTTPException as e:
39
- raise e
40
-
41
- except Exception as e:
42
- raise HTTPException(status_code=500, detail=f"Error classifying image: {str(e)}")
43
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.responses import JSONResponse
3
  from transformers import pipeline
4
  from PIL import Image
5
  import io
 
9
  # Load the image classification pipeline
10
  pipe = pipeline("image-classification", model="mateoluksenberg/dit-base-Classifier_CM05")
11
 
12
+ @app.post("/classify/")
13
+ async def classify_image(file: UploadFile = File(...)):
 
 
 
14
  try:
15
+ # Read the file contents into a PIL image
16
+ image = Image.open(file.file).convert('RGB')
17
 
18
+ # Convert the image to bytes
19
  image_bytes = io.BytesIO()
20
  image.save(image_bytes, format='JPEG')
21
  image_bytes = image_bytes.getvalue()
 
23
  # Perform image classification
24
  result = pipe(image_bytes)
25
 
26
+ return {"classification_result": result[0]} # Return the top prediction
27
 
28
  except Exception as e:
29
  # Handle exceptions, for example: file not found, image format issues, etc.
30
  raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
31
 
32
  @app.get("/")
33
+ async def home():
34
+ return {"message": "Hello World"}
 
 
 
 
 
 
 
 
35
 
36
+ # Sample usage:
37
+ # 1. Start the FastAPI server
38
+ # 2. Use a tool like Postman or curl to send a POST request to /classify/ with an image file