# age / app.py
# Author: Daniel Johnson
# Commit: dcf1fc9 ("Initial commit")
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor
# Load the pretrained ViT age classifier and the preprocessor that belongs to
# the same checkpoint once, at import time; both are module-level globals.
MODEL_ID = 'nateraw/vit-age-classifier'
model = ViTForImageClassification.from_pretrained(MODEL_ID)
image_processor = ViTImageProcessor.from_pretrained(MODEL_ID)
def _fetch_image(image_url, timeout=10.0):
    """Download the image at *image_url* and return it as an RGB PIL image.

    Args:
        image_url: HTTP(S) URL of the image to download.
        timeout: Seconds to wait on the server before giving up.

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-2xx HTTP status.
        PIL.UnidentifiedImageError: If the response body is not an image.
    """
    # NOTE(review): the URL is user-supplied via the web UI, so this request
    # can reach any host the server can (SSRF). Restrict hosts if deployed
    # publicly.
    response = requests.get(image_url, timeout=timeout)
    # Fail on 4xx/5xx instead of handing an HTML error page to PIL.
    response.raise_for_status()
    # The model takes 3-channel input, so normalize grayscale ("L"),
    # palette ("P") or RGBA images to RGB before preprocessing.
    return Image.open(BytesIO(response.content)).convert("RGB")


def classify_age(image_url):
    """Predict the age range of the person in the image at *image_url*.

    Args:
        image_url: URL of the image to classify, as typed into the UI.

    Returns:
        A string of the form ``"Predicted Age Range: <label>"``, where
        ``<label>`` is the checkpoint's own label for the top-scoring class.

    Raises:
        requests.RequestException: If the image cannot be downloaded.
        PIL.UnidentifiedImageError: If the URL does not point to an image.
    """
    image = _fetch_image(image_url)
    inputs = image_processor(images=image, return_tensors='pt')
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Softmax is monotonic, so the argmax over raw logits picks the same
    # class as the argmax over probabilities. There is one image per batch.
    predicted_index = logits.argmax(dim=-1).item()
    # Use the label mapping shipped with the checkpoint rather than a
    # hand-maintained list that can drift out of sync with the model.
    label = model.config.id2label[predicted_index]
    return f"Predicted Age Range: {label}"
# Web UI: one textbox for the image URL in, one textbox for the prediction out.
url_input = gr.Textbox(label="Image URL")
prediction_output = gr.Textbox(label="Predicted Age Range")

iface = gr.Interface(
    fn=classify_age,
    inputs=url_input,
    outputs=prediction_output,
    title="Age Classifier",
    description="Enter the URL of an image to predict the age range of the person in the image using ViT.",
)

# Start the Gradio server only when this file is executed directly.
if __name__ == "__main__":
    iface.launch()