import io

import numpy as np
import streamlit as st
import torch
from PIL import Image
from facenet_pytorch import MTCNN
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image


# Load the ViT classifier and the MTCNN face detector
def load_model_and_mtcnn(model_path):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Expects a full pickled model (saved with torch.save(model, path)), not a state_dict
    model = torch.load(model_path, map_location=device)
    model.to(device)
    mtcnn = MTCNN(keep_all=True, device=device)
    return model, device, mtcnn


def detect_and_process_skin(image_bytes, mtcnn):
    """Detect faces in an image and return the first face crop as a PIL image."""
    # PIL decodes the bytes to an RGB image, which is what MTCNN expects
    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    img_np = np.array(img)

    # facenet_pytorch's MTCNN exposes bounding boxes via detect(): (boxes, probabilities)
    boxes, _ = mtcnn.detect(img)

    if boxes is not None and len(boxes) > 0:
        # Boxes are [x1, y1, x2, y2] floats; clamp to the image and crop the first face
        x1, y1, x2, y2 = [int(max(v, 0)) for v in boxes[0]]
        face_img_np = img_np[y1:y2, x1:x2]
        return Image.fromarray(face_img_np)

    # Return the original image if no face was detected
    return img


# Preprocess the image and return both the tensor and the final PIL image for display
def preprocess_image(image, mtcnn, device):
    processed_image = image  # Fall back to the original image if detection fails
    try:
        # Calling mtcnn directly returns cropped, standardized face tensors (keep_all=True)
        cropped_faces = mtcnn(image)
        if cropped_faces is not None and len(cropped_faces) > 0:
            # Undo MTCNN's fixed standardization ((x - 127.5) / 128) before converting to PIL
            face = cropped_faces[0].cpu()
            face = (face * 128.0 + 127.5).clamp(0, 255) / 255.0
            processed_image = to_pil_image(face)
    except Exception as e:
        st.write(f"Exception in face detection: {e}")
        processed_image = image

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image_tensor = transform(processed_image).to(device)
    image_tensor = image_tensor.unsqueeze(0)  # Add a batch dimension
    return image_tensor, processed_image


# Run inference on a preprocessed image tensor
def predict(image_tensor, model, device):
    model.eval()
    with torch.no_grad():
        outputs = model(image_tensor)
        # Hugging Face-style models return an object with a 'logits' attribute;
        # plain torchvision/timm models return the logits tensor directly
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1)
    return predicted_class, probabilities


# Streamlit UI setup
st.title("Face Detection and Classification with ViT")
st.write("Upload an image, and the model will detect faces and classify the image.")

model_path = "model_v1.0.pt"  # Adjust this path as necessary
model, device, mtcnn = load_model_and_mtcnn(model_path)

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Crop the face from the raw upload bytes for display
    cropped_image = detect_and_process_skin(uploaded_file.getvalue(), mtcnn)

    # Classify the face detected in the uploaded image
    image_tensor, final_image = preprocess_image(image, mtcnn, device)
    predicted_class, probabilities = predict(image_tensor, model, device)
    st.write(f"Predicted class: {predicted_class.item()}")

    # Display the cropped face
    # st.image(final_image, caption='Processed Image', use_column_width=True)
    img_bytes = io.BytesIO()
    cropped_image.save(img_bytes, format='JPEG')
    st.image(img_bytes.getvalue(), width=250, caption="Processed Image")
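
# Usage note: assuming this script is saved as app.py with model_v1.0.pt alongside it,
# the app is launched through Streamlit's CLI rather than `python app.py`:
#
#   streamlit run app.py
#
# torch.load() above expects a complete pickled model object
# (saved via torch.save(model, "model_v1.0.pt")), not a bare state_dict.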