import streamlit as st
from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from facenet_pytorch import MTCNN
# Function to load the ViT model and MTCNN
def load_model_and_mtcnn(model_path):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Assumes the checkpoint was saved as a full model object (torch.save(model));
    # a plain state_dict would require instantiating the architecture first.
    model = torch.load(model_path, map_location=device)
    model.to(device)
    # post_process=False keeps the detected face crops as raw 0-255 pixel values,
    # so they can be converted back to a PIL image and normalized by the ViT transform.
    mtcnn = MTCNN(keep_all=True, post_process=False, device=device)
    return model, device, mtcnn
# Function to preprocess the image and return both the tensor for the model
# and the cropped face (if any) as a PIL image for display
def preprocess_image(image, mtcnn, device):
    cropped_image = None
    try:
        # Calling mtcnn directly on the image returns the cropped face tensors (or None)
        cropped_faces = mtcnn(image)
        if cropped_faces is not None and len(cropped_faces) > 0:
            # Convert the first detected face back to a PIL image for display
            cropped_image = to_pil_image(cropped_faces[0].cpu().clamp(0, 255).byte())
    except Exception as e:
        st.write(f"Exception in face detection: {e}")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Standard ImageNet mean/std used by ViT backbones
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Use the detected face if available, otherwise fall back to the full image
    source_image = cropped_image if cropped_image is not None else image
    processed_image = transform(source_image).to(device)

    # Add a batch dimension
    processed_image = processed_image.unsqueeze(0)
    return processed_image, cropped_image
# Function for inference
def predict(image_tensor, model, device):
    model.eval()
    image_tensor = image_tensor.to(device)
    with torch.no_grad():
        outputs = model(image_tensor)
        # Handle both Hugging Face-style outputs (with a .logits attribute) and plain tensors
        logits = outputs.logits if hasattr(outputs, "logits") else outputs
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1)
    return predicted_class, probabilities
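# Standalone usage sketch for the helpers above (kept as a comment so it does not run
# inside the Streamlit app); "example.jpg" is an assumed local file name:
#   model, device, mtcnn = load_model_and_mtcnn("model_v1.0.pt")
#   tensor, face = preprocess_image(Image.open("example.jpg").convert("RGB"), mtcnn, device)
#   predicted_class, probabilities = predict(tensor, model, device)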
# Streamlit UI setup
st.title("Face Detection and Classification with ViT")
st.write("Upload an image, and the model will detect faces and classify the image.")

model_path = "model_v1.0.pt"  # Adjust this path as necessary
model, device, mtcnn = load_model_and_mtcnn(model_path)

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption='Uploaded Image', use_column_width=True)

    image_tensor, final_image = preprocess_image(image, mtcnn, device)
    predicted_class, probabilities = predict(image_tensor, model, device)
    st.write(f"Predicted class: {predicted_class.item()}")

    # Display the detected face crop if one was found; otherwise note the fallback
    if final_image is not None:
        st.image(final_image, caption='Detected Face', use_column_width=True)
    else:
        st.write("No face was detected, so the full image was classified.")
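    # Show the softmax confidence of the predicted class as a percentage
    st.write(f"Confidence: {probabilities.max().item():.2%}")

# To run locally (a sketch assuming this script is saved as app.py and model_v1.0.pt
# sits alongside it): install streamlit, torch, torchvision, facenet-pytorch and Pillow,
# then launch with:
#     streamlit run app.py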