Spaces:
Sleeping
Sleeping
import io

import streamlit as st
import torch
from facenet_pytorch import MTCNN
from PIL import Image
from torchvision import transforms, utils
from torchvision.transforms.functional import to_pil_image
def load_model_and_mtcnn(model_path):
    """Load the serialized model and build the MTCNN face detector.

    Args:
        model_path: Path of the checkpoint handed to ``torch.load``.

    Returns:
        A ``(model, device, mtcnn)`` tuple. ``device`` is CUDA when
        ``torch.cuda.is_available()`` reports it, otherwise the CPU; both the
        model and the detector are placed on that device.
    """
    use_cuda = torch.cuda.is_available()
    target = torch.device('cuda' if use_cuda else 'cpu')

    # Deserialize straight onto the chosen device, then make sure the whole
    # module lives there.
    net = torch.load(model_path, map_location=target)
    net.to(target)

    # keep_all=True asks the detector to return every face it finds.
    detector = MTCNN(keep_all=True, device=target)
    return net, target, detector
def detect_and_process_skin(image_bytes, detector=None):
    """Crop the first detected face out of an encoded image.

    Args:
        image_bytes: Raw encoded image data (for example the bytes of an
            uploaded JPEG or PNG file).
        detector: Optional face detector exposing facenet_pytorch's
            ``MTCNN.detect`` interface, i.e. ``detect(img)`` returning
            ``(boxes, probs)`` with ``boxes`` as rows of ``[x1, y1, x2, y2]``
            or ``None`` when nothing is found. Defaults to the module-level
            ``mtcnn`` instance.

    Returns:
        A PIL image containing the first detected face, or the full image
        (converted to RGB) when no usable face box is available.
    """
    if detector is None:
        detector = mtcnn  # module-level detector created at app start-up

    # PIL already decodes to RGB channel order, so no BGR swap is needed.
    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    boxes, _ = detector.detect(img)
    if boxes is None or len(boxes) == 0:
        return img

    x1, y1, x2, y2 = (int(round(float(v))) for v in boxes[0])

    # Detectors may report coordinates slightly outside the frame; clamp so
    # the crop never pads with black or ends up empty.
    left, top = max(x1, 0), max(y1, 0)
    right, bottom = min(x2, img.width), min(y2, img.height)
    if right <= left or bottom <= top:
        return img

    return img.crop((left, top, right, bottom))
def preprocess_image(image, mtcnn, device):
    """Build the model input tensor, preferring a face crop over the full image.

    Args:
        image: PIL image to classify (the caller passes an RGB image).
        mtcnn: Face detector that, when called with the image, returns the
            cropped face tensor(s) or ``None``.
        device: Torch device the resulting tensor is moved to.

    Returns:
        ``(image_tensor, processed_image)`` where ``image_tensor`` has shape
        ``(1, 3, 224, 224)`` and ``processed_image`` is the PIL image that was
        actually fed to the transform (the face crop, or the original image
        when no face was found or detection failed).
    """
    processed_image = image  # fall back to the full image by default
    try:
        cropped_faces = mtcnn(image)
        if cropped_faces is not None and len(cropped_faces) > 0:
            # A keep_all detector returns a 4-D batch of faces; a single-face
            # detector returns one 3-D tensor. Use the first face either way.
            face = cropped_faces[0] if cropped_faces.dim() == 4 else cropped_faces
            face = face.detach().cpu().float()

            # Per the facenet_pytorch docs, post-processed crops are
            # standardized as (x - 127.5) / 128; undo that so the pixel values
            # are back in 0..255 before building a PIL image.
            if getattr(mtcnn, "post_process", True):
                face = face * 128.0 + 127.5
            face = face.clamp(0, 255).round().to(torch.uint8)

            # A uint8 (3, H, W) tensor converts to an RGB image; the previous
            # mode='BGR;16' is not valid for 3-channel input and always raised.
            processed_image = to_pil_image(face)
    except Exception as e:
        # Best effort: detection problems must not block classification.
        st.write(f"Exception in face detection: {e}")
        processed_image = image

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image_tensor = transform(processed_image).to(device)
    image_tensor = image_tensor.unsqueeze(0)  # Add a batch dimension
    return image_tensor, processed_image
def predict(image_tensor, model, device):
    """Run the model on a batch and return the predicted classes.

    Args:
        image_tensor: Batched input tensor for the model.
        model: Callable torch module. Its output may either expose a
            ``logits`` attribute or be the logits tensor itself.
        device: Torch device the input is moved to before the forward pass.

    Returns:
        ``(predicted_class, probabilities)``: the argmax class index per batch
        row and the softmax probabilities over dimension 1.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(image_tensor.to(device))
        # Some models wrap their result in an object with a `logits` field,
        # plain modules return the tensor directly; accept both.
        logits = outputs.logits if hasattr(outputs, "logits") else outputs
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1)
    return predicted_class, probabilities
# Streamlit UI setup: upload an image, classify it, and show the image that
# was actually fed to the model.
st.title("Face Detection and Classification with ViT")
st.write("Upload an image, and the model will detect faces and classify the image.")

model_path = "model_v1.0.pt"  # Adjust this path as necessary
model, device, mtcnn = load_model_and_mtcnn(model_path)

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Force RGB so the 3-channel normalization in preprocessing always applies.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Classify the preprocessed tensor; passing a PIL image to the model, as
    # the earlier code did, cannot work.
    image_tensor, final_image = preprocess_image(image, mtcnn, device)
    predicted_class, probabilities = predict(image_tensor, model, device)
    st.write(f"Predicted class: {predicted_class.item()}")

    # Display the final processed image (face crop, or the full image when no
    # face was detected).
    st.image(final_image, width=250, caption="Processed Image")