mateoluksenberg commited on
Commit
39591e8
1 Parent(s): 1bbdb1c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -43
app.py CHANGED
@@ -1,43 +1,20 @@
1
- from fastapi import FastAPI, HTTPException
2
- from transformers import pipeline
3
- from PIL import Image
4
- import io
5
-
6
- app = FastAPI()
7
-
8
- # Load the image classification pipeline
9
- pipe = pipeline("image-classification", model="mateoluksenberg/dit-base-Classifier_CM05")
10
-
11
- # Sample image path (for testing)
12
- image_path = 'cm5.jpg'
13
-
14
- # Async function to classify an image
15
- async def classify_image(image_path: str):
16
- try:
17
- image = Image.open(image_path).convert('RGB')
18
-
19
- image_bytes = io.BytesIO()
20
- image.save(image_bytes, format='JPEG')
21
- image_bytes = image_bytes.getvalue()
22
-
23
- # Perform image classification
24
- result = pipe(image_bytes)
25
-
26
- return result[0] # Return the top prediction
27
-
28
- except Exception as e:
29
- # Handle exceptions, for example: file not found, image format issues, etc.
30
- raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
31
-
32
- @app.get("/")
33
- async def home(image_path: str = image_path):
34
- try:
35
- result = await classify_image(image_path)
36
- return {"message": "Hello World", "classification_result": result}
37
-
38
- except HTTPException as e:
39
- raise e
40
-
41
- except Exception as e:
42
- raise HTTPException(status_code=500, detail=f"Error classifying image: {str(e)}")
43
-
 
1
+
2
+ from transformers import pipeline
3
+ from PIL import Image
4
+
5
+ # Cargar el pipeline para clasificación de imágenes
6
+ pipe = pipeline("image-classification", model="mateoluksenberg/dit-base-Classifier_CM05")
7
+
8
+ # Ruta a la imagen que deseas clasificar dentro del contenedor
9
+ image_path = 'cm5.jpg'
10
+
11
+ # Abrir la imagen desde la ruta local y convertirla a RGB
12
+ image = Image.open(image_path).convert('RGB')
13
+
14
+ # Realizar la clasificación de la imagen
15
+ result = pipe(image)
16
+
17
+ # Imprimir el resultado
18
+ print(result)
19
+
20
+