"""Gradio demo that predicts a person's age range from an image URL.

The image at the given URL is downloaded, converted to RGB and classified
with the ``nateraw/vit-age-classifier`` ViT checkpoint.
"""
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

MODEL_NAME = "nateraw/vit-age-classifier"
# Seconds to wait for the remote server before giving up on the download,
# so a slow or unresponsive host cannot hang a worker forever.
REQUEST_TIMEOUT = 10

# Initialize model and image processor once, at import time, so every
# request reuses them.
model = ViTForImageClassification.from_pretrained(MODEL_NAME)
image_processor = ViTImageProcessor.from_pretrained(MODEL_NAME)


def _load_image(image_url):
    """Download *image_url* and decode it as an RGB PIL image.

    Raises:
        gr.Error: if the URL is empty, cannot be fetched, returns an HTTP
            error status, or does not contain a decodable image.
    """
    image_url = (image_url or "").strip()
    if not image_url:
        raise gr.Error("Please enter an image URL.")
    # NOTE(review): this fetches an arbitrary user-supplied URL from the
    # server; restrict allowed hosts if the app is deployed publicly.
    try:
        response = requests.get(image_url, timeout=REQUEST_TIMEOUT)
        # Without this, an HTML error page would be handed to PIL and fail
        # with a misleading "cannot identify image" error.
        response.raise_for_status()
    except requests.RequestException as err:
        raise gr.Error(f"Could not download image: {err}") from err
    try:
        # Image.open is lazy; convert() forces decoding, so corrupt data is
        # caught here. Converting to RGB normalises RGBA / grayscale /
        # palette images to the 3 channels the ViT processor expects.
        return Image.open(BytesIO(response.content)).convert("RGB")
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        raise gr.Error("The URL did not return a readable image.") from err


def classify_age(image_url):
    """Predict the age range of the person in the image at *image_url*.

    Args:
        image_url: HTTP(S) URL of the image to classify.

    Returns:
        A string of the form ``"Predicted Age Range: <label>"``, where the
        label comes from the model's own class mapping.

    Raises:
        gr.Error: if the image cannot be downloaded or decoded; Gradio
            displays the message to the user instead of a traceback.
    """
    image = _load_image(image_url)
    inputs = image_processor(images=image, return_tensors="pt")
    # Prediction only: skip autograd bookkeeping to save memory and time.
    with torch.inference_mode():
        logits = model(**inputs).logits
    # softmax is monotonic, so argmax over logits picks the same class as
    # argmax over the probabilities.
    pred_idx = logits.argmax(-1).item()
    # Use the label mapping shipped with the checkpoint instead of a
    # hand-written list that can drift out of sync with the model.
    label = model.config.id2label[pred_idx]
    return f"Predicted Age Range: {label}"


# Create a Gradio interface
iface = gr.Interface(
    fn=classify_age,
    inputs=gr.Textbox(label="Image URL"),
    outputs=gr.Textbox(label="Predicted Age Range"),
    title="Age Classifier",
    description="Enter the URL of an image to predict the age range of the person in the image using ViT.",
)

if __name__ == "__main__":
    iface.launch()