"""Streamlit app: detect a face with MTCNN and classify it with a ViT model."""

import streamlit as st
from PIL import Image
import torch
from torchvision import transforms, utils
from facenet_pytorch import MTCNN
from torchvision.transforms.functional import to_pil_image

# Streamlit re-executes this script on every widget interaction, so the model
# should be cached across reruns. ``st.cache_resource`` only exists in newer
# Streamlit releases; fall back to a no-op decorator when it is missing.
_cache_resource = getattr(st, "cache_resource", lambda func: func)

# Resize to 224x224 and normalize with the commonly used ImageNet channel
# statistics (presumably what the model was trained with -- confirm).
_MODEL_INPUT_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Function to load the ViT model and MTCNN
@_cache_resource
def load_model_and_mtcnn(model_path):
    """Load the pickled model and build an MTCNN face detector.

    Args:
        model_path: Path to a model saved with ``torch.save(model, ...)``.

    Returns:
        Tuple ``(model, device, mtcnn)``; ``device`` is CUDA when available,
        otherwise CPU, and both the model and the detector live on it.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(security): torch.load unpickles arbitrary Python objects. Only load
    # checkpoints from a trusted source.
    # NOTE(review): recent PyTorch releases may default to weights_only=True,
    # which rejects fully pickled models -- confirm against the installed
    # version and pass weights_only=False explicitly if required.
    model = torch.load(model_path, map_location=device)
    model.to(device)
    mtcnn = MTCNN(keep_all=True, device=device)
    return model, device, mtcnn


def _face_tensor_to_pil(face, standardized):
    """Convert one MTCNN face crop (float tensor, C x H x W) to a PIL image.

    Args:
        face: Face crop returned by MTCNN.
        standardized: True when MTCNN applied its ``(x - 127.5) / 128``
            post-processing, which has to be undone before display.
    """
    face = face.detach().cpu().float()
    if standardized:
        face = face * 128.0 + 127.5
    # to_pil_image expects float tensors in [0, 1]; values outside that range
    # wrap around when cast to uint8, so rescale and clamp first.
    return to_pil_image((face / 255.0).clamp(0.0, 1.0))


# Function to preprocess the image and return both the tensor and the final PIL image for display
def preprocess_image(image, mtcnn, device):
    """Crop the first detected face and turn it into a model input batch.

    When no face is detected (or detection raises), the full image is used
    as the model input instead so the caller always receives a tensor.

    Args:
        image: RGB PIL image.
        mtcnn: MTCNN detector used to find and crop faces.
        device: Device the returned tensor is moved to.

    Returns:
        Tuple ``(batch, cropped_image)`` where ``batch`` is a normalized
        tensor with a leading batch dimension of 1 and ``cropped_image`` is
        the PIL face crop, or ``None`` when no face was found.
    """
    cropped_image = None
    try:
        # Directly call mtcnn with the image to get cropped faces
        cropped_faces = mtcnn(image)
        if cropped_faces is not None and len(cropped_faces) > 0:
            # Convert the first detected face tensor back to a PIL image
            cropped_image = _face_tensor_to_pil(
                cropped_faces[0],
                standardized=getattr(mtcnn, "post_process", True),
            )
    except Exception as e:
        # Best effort: report the problem and fall back to the full image.
        st.write(f"Exception in face detection: {e}")
        cropped_image = None

    model_source = cropped_image if cropped_image is not None else image
    # Add a batch dimension and move to the inference device
    processed_image = _MODEL_INPUT_TRANSFORM(model_source).unsqueeze(0).to(device)
    return processed_image, cropped_image


# Function for inference
def predict(image_tensor, model, device):
    """Run the model on a preprocessed batch.

    Args:
        image_tensor: Batched input tensor.
        model: Classification model; its output may either expose a
            ``logits`` attribute or be the logits tensor itself.
        device: Device the input is moved to before the forward pass.

    Returns:
        Tuple ``(predicted_class, probabilities)``: the argmax class index
        per batch item and the softmax probabilities over ``dim=1``.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(image_tensor.to(device))
    # Support both outputs with a ``logits`` attribute and plain tensors.
    logits = getattr(outputs, "logits", outputs)
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1)
    return predicted_class, probabilities


# Streamlit UI setup
st.title("Face Detection and Classification with ViT")
st.write("Upload an image, and the model will detect faces and classify the image.")

model_path = "model_v1.0.pt"  # Adjust this path as necessary
model, device, mtcnn = load_model_and_mtcnn(model_path)

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption='Uploaded Image', use_column_width=True)

    image_tensor, final_image = preprocess_image(image, mtcnn, device)
    predicted_class, probabilities = predict(image_tensor, model, device)
    st.write(f"Predicted class: {predicted_class.item()}")

    # Display the final processed image (only available when a face was found)
    if final_image is not None:
        st.image(final_image, caption='Processed Image', use_column_width=True)
    else:
        st.write("No face detected; the full image was classified instead.")