Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,76 +1,98 @@
|
|
1 |
-
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
|
2 |
-
from fastapi.responses import HTMLResponse
|
3 |
-
import gradio as gr
|
4 |
-
from deepface import DeepFace
|
5 |
import os
|
|
|
6 |
import base64
|
7 |
-
from gradio.routes import App as GradioApp
|
8 |
import logging
|
|
|
|
|
9 |
from fastapi.exceptions import RequestValidationError
|
10 |
-
from
|
|
|
|
|
|
|
|
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
13 |
|
|
|
|
|
|
|
|
|
14 |
# FastAPI instance
|
15 |
app = FastAPI()
|
16 |
|
17 |
-
#
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
20 |
|
|
|
21 |
@app.exception_handler(HTTPException)
|
22 |
async def http_exception_handler(request: Request, exc: HTTPException):
|
23 |
logger.error(f"HTTP error occurred: {exc.detail}")
|
24 |
-
return JSONResponse(
|
25 |
-
status_code=exc.status_code,
|
26 |
-
content={"detail": exc.detail},
|
27 |
-
)
|
28 |
|
29 |
@app.exception_handler(RequestValidationError)
|
30 |
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
31 |
body = await request.body()
|
32 |
-
logger.error(f"Validation error: {exc.errors()}")
|
33 |
try:
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
if 'multipart/form-data' in content_type:
|
48 |
-
# For binary data, log a placeholder message
|
49 |
-
logger.info(f"Incoming request: {request.method} {request.url}")
|
50 |
-
logger.info("Request body contains binary data and is not logged.")
|
51 |
-
else:
|
52 |
-
# For non-binary data, log the actual body
|
53 |
-
body = await request.body()
|
54 |
-
logger.info(f"Incoming request: {request.method} {request.url}")
|
55 |
-
logger.info(f"Request body: {body.decode('utf-8')}")
|
56 |
-
|
57 |
-
response = await call_next(request)
|
58 |
-
return response
|
59 |
-
|
60 |
-
# Gradio Interface Function
|
61 |
-
def face_verification_uii(img1, img2, dist="cosine", model="Facenet", detector="ssd"):
|
62 |
"""
|
63 |
-
|
|
|
64 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
try:
|
|
|
|
|
|
|
66 |
result = DeepFace.verify(
|
67 |
-
img1_path=
|
68 |
-
img2_path=
|
69 |
-
distance_metric=dist,
|
70 |
-
model_name=model,
|
71 |
-
detector_backend=detector,
|
72 |
enforce_detection=False,
|
73 |
)
|
|
|
|
|
74 |
return {
|
75 |
"verified": result["verified"],
|
76 |
"distance": result["distance"],
|
@@ -79,92 +101,49 @@ def face_verification_uii(img1, img2, dist="cosine", model="Facenet", detector="
|
|
79 |
"detector_backend": result["detector_backend"],
|
80 |
"similarity_metric": result["similarity_metric"],
|
81 |
}
|
82 |
-
except Exception as e:
|
83 |
-
return {"error": str(e)}
|
84 |
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
img2: UploadFile = File(None),
|
89 |
-
img1_base64: str = Form(None),
|
90 |
-
img2_base64: str = Form(None),
|
91 |
-
dist: str = Form("cosine"),
|
92 |
-
model: str = Form("Facenet"),
|
93 |
-
detector: str = Form("ssd")
|
94 |
-
):
|
95 |
-
if not img1 and not img2 and not img1_base64 and not img2_base64:
|
96 |
-
raise HTTPException(status_code=400, detail="Invalid input: At least one image input is required.")
|
97 |
|
|
|
|
|
98 |
try:
|
99 |
-
# Ensure uploads directory exists
|
100 |
-
if not os.path.exists("uploads"):
|
101 |
-
os.makedirs("uploads")
|
102 |
-
|
103 |
-
img1_path = None
|
104 |
-
img2_path = None
|
105 |
-
|
106 |
-
# Process img1
|
107 |
-
if img1:
|
108 |
-
img1_path = os.path.join("uploads", img1.filename)
|
109 |
-
with open(img1_path, "wb") as f:
|
110 |
-
f.write(await img1.read())
|
111 |
-
elif img1_base64:
|
112 |
-
img1_path = os.path.join("uploads", "img1_base64.png")
|
113 |
-
with open(img1_path, "wb") as f:
|
114 |
-
f.write(base64.b64decode(img1_base64))
|
115 |
-
|
116 |
-
# Process img2
|
117 |
-
if img2:
|
118 |
-
img2_path = os.path.join("uploads", img2.filename)
|
119 |
-
with open(img2_path, "wb") as f:
|
120 |
-
f.write(await img2.read())
|
121 |
-
elif img2_base64:
|
122 |
-
img2_path = os.path.join("uploads", "img2_base64.png")
|
123 |
-
with open(img2_path, "wb") as f:
|
124 |
-
f.write(base64.b64decode(img2_base64))
|
125 |
-
# Run DeepFace verification
|
126 |
result = DeepFace.verify(
|
127 |
-
img1_path=
|
128 |
-
img2_path=
|
129 |
distance_metric=dist,
|
130 |
model_name=model,
|
131 |
detector_backend=detector,
|
132 |
enforce_detection=False,
|
133 |
)
|
134 |
-
|
135 |
-
# Delete uploaded images after processing
|
136 |
-
os.remove(img1_path)
|
137 |
-
os.remove(img2_path)
|
138 |
-
|
139 |
-
# Return verification results
|
140 |
return {
|
141 |
-
"
|
142 |
-
"
|
143 |
-
"
|
144 |
-
"
|
145 |
-
"
|
146 |
-
"
|
147 |
}
|
148 |
-
|
149 |
except Exception as e:
|
150 |
-
|
151 |
-
raise HTTPException(status_code=500, detail=str(e))
|
152 |
-
|
153 |
|
154 |
-
#
|
155 |
with gr.Blocks() as demo:
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
|
|
|
|
|
|
167 |
if __name__ == "__main__":
|
168 |
import uvicorn
|
169 |
-
|
170 |
-
uvicorn.run(app, host="0.0.0.0", port=7860, reload=True)
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
import json
|
3 |
import base64
|
|
|
4 |
import logging
|
5 |
+
from fastapi import FastAPI, HTTPException, Request
|
6 |
+
from fastapi.responses import HTMLResponse, JSONResponse
|
7 |
from fastapi.exceptions import RequestValidationError
|
8 |
+
from pydantic import BaseModel
|
9 |
+
import gradio as gr
|
10 |
+
from deepface import DeepFace
|
11 |
+
from gradio.routes import App as GradioApp
|
12 |
+
from uuid import uuid4
|
13 |
|
14 |
+
# Constants
|
15 |
+
UPLOAD_DIR = "uploads"
|
16 |
+
DEFAULT_MODEL = "Facenet"
|
17 |
+
DEFAULT_DIST = "cosine"
|
18 |
+
DEFAULT_DETECTOR = "ssd"
|
19 |
+
|
20 |
+
# Ensure uploads directory exists
|
21 |
+
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
22 |
+
|
23 |
+
# Environment Configurations
|
24 |
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
25 |
|
26 |
+
# Logging Setup
|
27 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
28 |
+
logger = logging.getLogger(__name__)
|
29 |
+
|
30 |
# FastAPI instance
|
31 |
app = FastAPI()
|
32 |
|
33 |
+
# Request model
|
34 |
+
class ImageRequest(BaseModel):
|
35 |
+
img1_base64: str
|
36 |
+
img2_base64: str
|
37 |
+
dist: str = DEFAULT_DIST
|
38 |
+
model: str = DEFAULT_MODEL
|
39 |
+
detector: str = DEFAULT_DETECTOR
|
40 |
|
41 |
+
# Exception Handlers
|
42 |
@app.exception_handler(HTTPException)
|
43 |
async def http_exception_handler(request: Request, exc: HTTPException):
|
44 |
logger.error(f"HTTP error occurred: {exc.detail}")
|
45 |
+
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
|
|
|
|
|
|
46 |
|
47 |
@app.exception_handler(RequestValidationError)
|
48 |
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
49 |
body = await request.body()
|
|
|
50 |
try:
|
51 |
+
body_data = json.loads(body.decode("utf-8"))
|
52 |
+
# Log errors with truncated base64 fields
|
53 |
+
for key in ["img1_base64", "img2_base64"]:
|
54 |
+
if key in body_data:
|
55 |
+
body_data[key] = f"{body_data[key][:20]}... [truncated]"
|
56 |
+
logger.error(f"Validation error: {exc.errors()}")
|
57 |
+
logger.error(f"Request body: {json.dumps(body_data)}")
|
58 |
+
except Exception as e:
|
59 |
+
logger.warning(f"Error decoding request body: {e}")
|
60 |
+
return JSONResponse(status_code=400, content={"detail": exc.errors()})
|
61 |
+
|
62 |
+
# Utility Functions
|
63 |
+
def save_base64_image(base64_string: str, file_name: str) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
"""
|
65 |
+
Decodes a base64 string and saves it to a file.
|
66 |
+
Returns the file path.
|
67 |
"""
|
68 |
+
file_path = os.path.join(UPLOAD_DIR, file_name)
|
69 |
+
with open(file_path, "wb") as f:
|
70 |
+
f.write(base64.b64decode(base64_string))
|
71 |
+
return file_path
|
72 |
+
|
73 |
+
def cleanup_files(*file_paths):
|
74 |
+
"""Deletes files from the file system."""
|
75 |
+
for path in file_paths:
|
76 |
+
if os.path.exists(path):
|
77 |
+
os.remove(path)
|
78 |
+
|
79 |
+
# Face Verification Endpoint
|
80 |
+
@app.post("/face_verification")
|
81 |
+
async def face_verification(request: ImageRequest):
|
82 |
try:
|
83 |
+
img1_path = save_base64_image(request.img1_base64, f"{uuid4()}_img1.png")
|
84 |
+
img2_path = save_base64_image(request.img2_base64, f"{uuid4()}_img2.png")
|
85 |
+
|
86 |
result = DeepFace.verify(
|
87 |
+
img1_path=img1_path,
|
88 |
+
img2_path=img2_path,
|
89 |
+
distance_metric=request.dist,
|
90 |
+
model_name=request.model,
|
91 |
+
detector_backend=request.detector,
|
92 |
enforce_detection=False,
|
93 |
)
|
94 |
+
cleanup_files(img1_path, img2_path)
|
95 |
+
|
96 |
return {
|
97 |
"verified": result["verified"],
|
98 |
"distance": result["distance"],
|
|
|
101 |
"detector_backend": result["detector_backend"],
|
102 |
"similarity_metric": result["similarity_metric"],
|
103 |
}
|
|
|
|
|
104 |
|
105 |
+
except Exception as e:
|
106 |
+
logger.error(f"An error occurred during face verification: {str(e)}")
|
107 |
+
raise HTTPException(status_code=500, detail="Internal server error during face verification.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
+
# Gradio Interface
|
110 |
+
def face_verification_ui(img1, img2, dist=DEFAULT_DIST, model=DEFAULT_MODEL, detector=DEFAULT_DETECTOR):
|
111 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
result = DeepFace.verify(
|
113 |
+
img1_path=img1,
|
114 |
+
img2_path=img2,
|
115 |
distance_metric=dist,
|
116 |
model_name=model,
|
117 |
detector_backend=detector,
|
118 |
enforce_detection=False,
|
119 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
return {
|
121 |
+
"Verified": result["verified"],
|
122 |
+
"Distance": result["distance"],
|
123 |
+
"Threshold": result["threshold"],
|
124 |
+
"Model": result["model"],
|
125 |
+
"Detector Backend": result["detector_backend"],
|
126 |
+
"Similarity Metric": result["similarity_metric"],
|
127 |
}
|
|
|
128 |
except Exception as e:
|
129 |
+
return {"Error": str(e)}
|
|
|
|
|
130 |
|
131 |
+
# Gradio Blocks
|
132 |
with gr.Blocks() as demo:
|
133 |
+
with gr.Row():
|
134 |
+
img1 = gr.Image(label="Image 1", sources=["upload", "webcam", "clipboard"])
|
135 |
+
img2 = gr.Image(label="Image 2", sources=["upload", "webcam", "clipboard"])
|
136 |
+
with gr.Row():
|
137 |
+
dist = gr.Dropdown(choices=["cosine", "euclidean", "euclidean_l2"], label="Distance Metric", value=DEFAULT_DIST)
|
138 |
+
model = gr.Dropdown(choices=["VGG-Face", "Facenet", "Facenet512", "ArcFace"], label="Model", value=DEFAULT_MODEL)
|
139 |
+
detector = gr.Dropdown(choices=["opencv", "ssd", "mtcnn", "retinaface", "mediapipe"], label="Detector", value=DEFAULT_DETECTOR)
|
140 |
+
with gr.Row():
|
141 |
+
btn = gr.Button("Verify")
|
142 |
+
output = gr.Textbox(label="Output")
|
143 |
+
|
144 |
+
btn.click(face_verification_ui, inputs=[img1, img2, dist, model, detector], outputs=output)
|
145 |
+
|
146 |
+
# Running Server
|
147 |
if __name__ == "__main__":
|
148 |
import uvicorn
|
149 |
+
uvicorn.run(app, host="0.0.0.0", port=7860, reload=True)
|
|