|
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification
|
|
|
|
|
# Checkpoint identifier handed to both ``from_pretrained`` calls below, so the
# preprocessor and the classifier always come from the same source.
_CHECKPOINT = 'nateraw/vit-age-classifier'

# Both objects are created once at import time and reused by every request.
image_processor = ViTImageProcessor.from_pretrained(_CHECKPOINT)
model = ViTForImageClassification.from_pretrained(_CHECKPOINT)
|
|
|
# Upper bound (seconds) on connecting to / reading from the image host, so a
# stalled server cannot block the UI worker indefinitely.
_REQUEST_TIMEOUT_SECONDS = 10

# Human-readable label for each classifier output index.
# NOTE(review): presumably this order matches the checkpoint's class order --
# verify against ``model.config.id2label``.
_AGE_RANGES = ("0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+")


def classify_age(image_url: str) -> str:
    """Predict the age range of the person in the image at *image_url*.

    The image is downloaded, decoded, normalised to RGB, run through the
    module-level ``image_processor`` and ``model``, and the highest-scoring
    class index is mapped to an age bracket.

    Args:
        image_url: URL of the image to classify.

    Returns:
        A string of the form ``"Predicted Age Range: <bracket>"``.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server responds with an HTTP error status.
        OSError: If the response body cannot be decoded as an image
            (Pillow's ``UnidentifiedImageError`` is an ``OSError``).
    """
    # NOTE(review): the URL comes straight from the UI; if this is deployed
    # publicly, consider restricting which hosts may be fetched.
    response = requests.get(image_url, timeout=_REQUEST_TIMEOUT_SECONDS)
    # Fail on 4xx/5xx here instead of trying to decode an error page as an image.
    response.raise_for_status()

    # Force three channels: RGBA, palette and grayscale images would otherwise
    # reach the preprocessor with an unexpected channel layout.
    image = Image.open(BytesIO(response.content)).convert("RGB")

    inputs = image_processor(images=image, return_tensors='pt')

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = model(**inputs)

    # Softmax is monotonic, so the arg-max of the raw logits is the same class
    # the original softmax-then-argmax selected.
    predicted_index = output.logits.argmax(dim=1).item()

    return f"Predicted Age Range: {_AGE_RANGES[predicted_index]}"
|
|
|
|
|
# UI components: one text field for the image URL, one for the prediction.
_url_box = gr.Textbox(label="Image URL")
_result_box = gr.Textbox(label="Predicted Age Range")

# Wire ``classify_age`` into a simple text-in / text-out Gradio interface.
iface = gr.Interface(
    fn=classify_age,
    inputs=_url_box,
    outputs=_result_box,
    title="Age Classifier",
    description="Enter the URL of an image to predict the age range of the person in the image using ViT.",
)


# Start the web UI only when the file is executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()
|
|