|
import os |
|
import gradio as gr |
|
import glob |
|
import time |
|
import random |
|
import requests |
|
import numpy as np |
|
|
|
|
|
from torchvision import models, transforms |
|
from PIL import Image |
|
import torch |
|
|
|
|
|
# ResNet-50 with ImageNet-1k weights. The `weights=` argument replaces the
# `pretrained=True` flag deprecated in torchvision 0.13; IMAGENET1K_V1 is the
# checkpoint `pretrained=True` used to load, so predictions are unchanged.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Inference mode: BatchNorm layers use their stored running statistics.
model.eval()

# ImageNet-style preprocessing: resize the shorter side to 256 px, take the
# central 224x224 crop, convert to a float tensor in [0, 1], then normalize
# each RGB channel with the ImageNet mean/std the weights were trained with.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
|
|
|
|
def download_imagenet_classes(
    url="https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
    path="imagenet_classes.txt",
    timeout=30,
):
    """Download the ImageNet class-label list and save it to *path*.

    Args:
        url: Location of the plain-text label file (one label per line).
        path: Destination file; overwritten if it already exists.
        timeout: Seconds to wait for the server. Without a timeout a stalled
            connection would block start-up indefinitely.

    Returns:
        True if the file was written, False if the server answered with a
        non-200 status. Network errors raised by ``requests`` (including
        timeouts) propagate unchanged.
    """
    response = requests.get(url, timeout=timeout)

    if response.status_code != 200:
        print(f"Failed to download {path} (HTTP {response.status_code})")
        return False

    # Write to a temporary sibling and atomically rename it into place, so an
    # interrupted write never leaves a truncated label file behind; start-up
    # code only checks that the file exists before trusting it.
    tmp_path = f"{path}.part"
    with open(tmp_path, "wb") as f:
        f.write(response.content)
    os.replace(tmp_path, path)

    print(f"{path} downloaded successfully.")
    return True
|
|
|
|
|
# Fetch the label list on first run; later runs reuse the cached file.
if not os.path.exists("imagenet_classes.txt"):
    download_imagenet_classes()

# labels[i] is the human-readable name for model output index i, so the
# file's line order must match the model's class order.
try:
    with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
        labels = [line.strip() for line in f]
except FileNotFoundError as err:
    # download_imagenet_classes() reports failure by printing rather than
    # raising, so surface an actionable error here instead of a bare one.
    raise FileNotFoundError(
        "imagenet_classes.txt is missing and could not be downloaded; "
        "place the ImageNet class list next to this script."
    ) from err
|
|
|
# ImageNet-1k output indices that are birds:
#   7-24    cock .. great grey owl
#   80-100  black grouse .. black swan
#   127-146 white stork .. albatross
# Matching on indices rather than on substrings of the label text avoids false
# positives such as "European fire salamander" (index 25, an amphibian),
# "birdhouse", "cockroach", "cocker spaniel", "hen-of-the-woods" and the
# construction "crane", whose label text is identical to the bird "crane".
BIRD_CLASS_INDICES = frozenset({*range(7, 25), *range(80, 101), *range(127, 147)})


def classify_image(image):
    """Classify *image* with ResNet-50 and report whether it shows a bird.

    Args:
        image: The value from the Gradio ``"image"`` input. This is presumably
            an H x W x C uint8 NumPy array (Gradio's default); confirm this if
            the input component changes. It is ``None`` when the user submits
            without uploading an image.

    Returns:
        A sentence naming the top-1 ImageNet class, saying whether it is a
        bird, and giving the softmax probability of that class as a
        percentage.
    """
    if image is None:
        return "Please upload an image."

    print("Classifying image...")

    # convert('RGB') normalizes grayscale / RGBA input to 3 channels.
    img = Image.fromarray(image).convert("RGB")
    batch_t = transform(img).unsqueeze(0)  # add a batch dimension of size 1

    with torch.no_grad():
        output = model(batch_t)

    # Softmax is monotonic, so its argmax is the same as the logits' argmax;
    # one pass yields both the predicted class and its probability.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    confidence_score, predicted = torch.max(probabilities, dim=0)
    class_index = predicted.item()
    classification = labels[class_index]
    confidence_percentage = f"{confidence_score.item():.2%}"

    if class_index in BIRD_CLASS_INDICES:
        return f"This is a bird! Specifically, it looks like a {classification}. Model confidence: {confidence_percentage}"
    return f"This is not a bird. It appears to be a {classification}. Model confidence: {confidence_percentage}"
|
|
|
|
|
# Every PNG under ./examples becomes a clickable example; sorted so the order
# is stable across runs.
example_files = sorted(glob.glob("examples/*.png"))
examples = [[file] for file in example_files]

demo = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs="text",
    examples=examples,
    title="Is this a picture of a bird?",
    description="Uses the latest in machine learning LLM Diffusion models to analyze every pixel (twice) and to determine conclusively if it is a picture of a bird",
)

# Only start the server when run as a script, so importing this module (e.g.
# for tests, or Gradio's reload mode, which looks up `demo`) has no side effect
# beyond building the interface.
if __name__ == "__main__":
    demo.launch()