# IsItABird / app.py
# Last commit by A19grey (23d6bd9): "Added ALL THE BIRDS classes to the list"
import os
import gradio as gr
import glob
import time
import random
import requests
import numpy as np
# Import necessary libraries
from torchvision import models, transforms
from PIL import Image
import torch
# Load the pre-trained ResNet-50 once at import time so every request reuses it.
# `pretrained=True` is deprecated since torchvision 0.13 in favour of the
# `weights=` enum. IMAGENET1K_V1 is exactly the checkpoint `pretrained=True`
# loaded, so predictions are unchanged. Older torchvision releases have no
# `ResNet50_Weights`, so fall back to the legacy flag there.
_resnet50_weights = getattr(models, "ResNet50_Weights", None)
if _resnet50_weights is not None:
    model = models.resnet50(weights=_resnet50_weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)
# Evaluation mode: batch-norm layers use their stored running statistics
# instead of per-batch statistics.
model.eval()
# Preprocessing expected by the ImageNet-trained ResNet:
# resize the shorter side to 256 px, take the central 224x224 crop,
# convert to a float tensor, then normalize each channel with the
# standard ImageNet statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
transform = transforms.Compose(_preprocessing_steps)
def download_imagenet_classes(
    url="https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
    dest="imagenet_classes.txt",
    timeout=30,
):
    """Download the ImageNet class-label list and save it to *dest*.

    Failures are reported with ``print`` rather than raised, so the caller
    decides what to do when the file is still missing.

    Args:
        url: Location of the label list (one label per line).
        dest: Local path the downloaded bytes are written to.
        timeout: Seconds to wait for the server. Without a timeout,
            ``requests`` can block forever on a stalled connection, which
            would hang the app at start-up.

    Returns:
        True if the file was downloaded and written, False otherwise.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        # Network-level failures (DNS, refused connection, timeout) are
        # reported the same way as bad HTTP statuses.
        print(f"Failed to download {dest} ({exc})")
        return False

    if response.status_code != 200:
        print(f"Failed to download {dest} (HTTP {response.status_code})")
        return False

    # Write raw bytes so the file content is exactly what the server sent.
    with open(dest, "wb") as f:
        f.write(response.content)
    print(f"{dest} downloaded successfully.")
    return True
# Fetch the label file on first run; later runs reuse the local copy.
if not os.path.exists("imagenet_classes.txt"):
    download_imagenet_classes()

# Load class labels. Line i of the file is used as the name of model output
# index i, so lines are kept positionally: filtering out blank lines would
# shift every later label. An explicit encoding avoids depending on the
# platform's locale default.
with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
    labels = [line.strip() for line in f]
# ImageNet-1k class indices that are birds, in the ordering used by
# imagenet_classes.txt and by the ResNet-50 output layer:
#   7-24     cock ... great grey owl
#   80-100   black grouse ... black swan
#   127-146  white stork ... albatross
# Matching on the predicted index rather than on substrings of the label text
# avoids false positives such as "cockroach", "cocker spaniel",
# "cocktail shaker", "hen-of-the-woods", "birdhouse" and the construction
# "crane" (whose label is identical to the bird "crane", index 134).
BIRD_CLASS_INDICES = frozenset(
    [*range(7, 25), *range(80, 101), *range(127, 147)]
)


def classify_image(image):
    """Classify *image* with ResNet-50 and report whether it shows a bird.

    Args:
        image: The picture from the Gradio "image" input, in a form accepted
            by ``PIL.Image.fromarray`` (presumably an H x W x C uint8 numpy
            array; confirm against the Gradio version in use). May be None
            when the user submits without uploading a picture.

    Returns:
        A sentence naming the top-1 ImageNet class, stating whether it is a
        bird, and giving the model's softmax confidence for that class.
    """
    if image is None:
        return "No image provided. Please upload a picture to classify."

    print("Classifying image...")

    # Preprocess: force 3-channel RGB (drops alpha / expands greyscale), apply
    # the ImageNet transform and add a leading batch dimension of 1.
    img = Image.fromarray(image).convert('RGB')
    batch_t = transform(img).unsqueeze(0)

    with torch.no_grad():
        output = model(batch_t)

    # Softmax is monotonic, so the arg-max of the probabilities is the same
    # class as the arg-max of the raw logits.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    confidence, predicted = torch.max(probabilities, dim=0)
    class_index = predicted.item()
    classification = labels[class_index]
    confidence_percentage = f"{confidence.item():.2%}"

    if class_index in BIRD_CLASS_INDICES:
        return f"This is a bird! Specifically, it looks like a {classification}. Model confidence: {confidence_percentage}"
    return f"This is not a bird. It appears to be a {classification}. Model confidence: {confidence_percentage}"
# Example images for the UI: every PNG in the local examples/ folder, in a
# stable sorted order. Gradio wants one inner list per example, holding one
# value per input component, so each path is wrapped in its own list.
example_files = sorted(glob.glob("examples/*.png"))
examples = list(map(lambda path: [path], example_files))
# Create the Gradio interface: one image input, one text output, and the
# example images collected from the examples/ folder.
demo = gr.Interface(
    fn=classify_image,  # called with the uploaded image; returns the verdict text
    inputs="image",
    outputs="text",
    examples=examples,
    title="Is this a picture of a bird?",
    description="Uses the latest in machine learning LLM Diffusion models to analyze every pixel (twice) and to determine conclusively if it is a picture of a bird",
)

# Start the web server only when this file is run as a script
# (e.g. `python app.py`), so importing the module does not launch a server.
if __name__ == "__main__":
    demo.launch()