Spaces:
Runtime error
Runtime error
"""
Gradio demo of image classification with out-of-distribution (OOD) detection.
If the input image is probably OOD, the detectors flag it so that the model's
prediction can be withheld (abstention).
"""
import json | |
import logging | |
import pickle | |
from glob import glob | |
import gradio as gr | |
import numpy as np | |
import timm | |
import torch | |
import torch.nn.functional as F | |
from gradio.components import JSON, Image, Label | |
from timm.data import resolve_data_config | |
from timm.data.transforms_factory import create_transform | |
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names | |
# Module-wide logger, compute device, and how many top classes are reported.
_logger = logging.getLogger(__name__)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
TOPK = 3  # number of classes returned by predict() and shown by the Label output
# Load timm's pretrained "resnet50.tv2_in1k" checkpoint, move it to the
# selected device and switch it to eval (inference) mode. Module.to() and
# Module.eval() both return the module itself, so the calls can be chained.
print("Loading model...")
model = timm.create_model("resnet50.tv2_in1k", pretrained=True).to(device).eval()
# Dataset labels: map each integer class index to its human-readable name.
# JSON object keys are always strings, so they are converted to int to match
# the indices produced by torch.topk in predict().
# A context manager closes the file handle (the original leaked it), and the
# label dump goes to the debug log instead of printing ~1000 entries twice.
with open("ilsvrc2012.json", encoding="utf-8") as label_file:
    idx2label = {int(k): v for k, v in json.load(label_file).items()}
_logger.debug("Loaded %d labels: %s", len(idx2label), idx2label)
# Preprocessing pipeline built from the model's pretrained data config;
# is_training=False asks timm for the evaluation-time transforms.
config = resolve_data_config({}, model=model)
config.update(is_training=False)
transform = create_transform(**config)
# Feature extractor that returns, from a single forward pass, a dict keyed by
# graph node name: the pooled penultimate features and the final "fc" logits.
penultimate_features_key = "global_pool.flatten"
logits_key = "fc"
features_names = [penultimate_features_key, logits_key]
feature_extractor = create_feature_extractor(model, return_nodes=features_names)
# Logit-space centroids used by the Igeood detector (presumably one row per
# class — TODO confirm against the script that produced the .pt file).
# map_location puts the tensor on the same device as the model outputs.
# Without it, a checkpoint saved on GPU fails to load on a CPU-only host, and
# a CPU checkpoint hits a device mismatch in igeoodlogits_vec when the model
# runs on CUDA.
centroids = torch.load("centroids_resnet50.tv2_in1k_igeood_logits.pt", map_location=device)
# OOD decision thresholds: predict() reports an input as OOD for a detector
# when that detector's score is strictly below its threshold.
msp_threshold, energy_threshold, igeood_threshold = 0.3796, 8, 2.4984
def mahalanobis_penult(features):
    """Return the negated L2 norm of a single feature vector as a score.

    NOTE(review): despite its name this is not a Mahalanobis distance. No
    class means or covariance are involved; the result is simply
    ``-||features||_2``. It is not used by ``predict`` in this file.

    Args:
        features: 2-D tensor holding exactly one sample (``.item()`` requires
            a single element), e.g. shape ``(1, D)``.

    Returns:
        float: minus the Euclidean norm of the sample.
    """
    # torch.norm's keyword is ``keepdim`` (singular); the original
    # ``keepdims`` raised TypeError on every call.
    scores = torch.norm(features, dim=1, keepdim=True)
    s = torch.min(scores, dim=1)[0]
    return -s.item()
def msp(logits):
    """Maximum softmax probability (MSP) score of a single-sample logits tensor.

    Higher values mean the classifier is more confident. The batch must hold
    exactly one sample because the result is read with ``.item()``.
    """
    probabilities = torch.softmax(logits, dim=1)
    top_probability = probabilities.max(-1).values
    return top_probability.item()
def energy(logits):
    """Energy-based score: log-sum-exp of a single sample's logits.

    This is the negated free energy at temperature 1. predict() treats
    smaller values as more likely out-of-distribution.
    """
    log_partition = torch.logsumexp(logits, dim=1)
    return log_partition.item()
def igeoodlogits_vec(logits, temperature, centroids, epsilon=1e-12):
    """Igeood score: mean Fisher-Rao distance from the sample to the centroids.

    Both the sample logits and every centroid row are turned into
    temperature-scaled softmax distributions. The distance between two
    distributions p, q is ``2 * acos(sum(sqrt(p * q)))``.

    Args:
        logits: logits of a single sample (``.item()`` needs a 1-row result).
        temperature: softmax temperature applied to logits and centroids.
        centroids: 2-D tensor of reference logits, one centroid per row.
        epsilon: margin keeping the ``acos`` argument inside (-1, 1).

    Returns:
        float: the distance averaged over all centroids.
    """
    sqrt_p = F.softmax(logits / temperature, dim=1).sqrt()
    sqrt_q = F.softmax(centroids / temperature, dim=1).sqrt()
    # Bhattacharyya coefficient between the sample and every centroid.
    coefficients = torch.matmul(sqrt_p, sqrt_q.T)
    coefficients = coefficients.clamp(min=-1 + epsilon, max=1 - epsilon)
    distances = 2 * torch.acos(coefficients)
    return distances.mean(dim=1).item()
def predict(image):
    """Classify ``image`` and score it with three OOD detectors.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when no image has been provided.

    Returns:
        tuple: ``(result, ood_scores)``. ``result`` maps the top-``TOPK``
        class names to their softmax probabilities. ``ood_scores`` holds each
        detector's score rounded to 4 decimals and a boolean flag that is True
        when the score is strictly below that detector's threshold. Both are
        ``None`` when no image is given.

    NOTE(review): the prediction is always returned, even when a detector
    flags the input as OOD; no abstention happens in this function.
    """
    if image is None:
        # Submitting with an empty image input; transform(None) would crash.
        return None, None

    # forward pass on a batch of one image
    inputs = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = feature_extractor(inputs)
    logits = features[logits_key]

    # top-k predictions
    probabilities = torch.softmax(logits, dim=-1)
    softmax, class_idxs = torch.topk(probabilities, TOPK)
    _logger.info(softmax)
    _logger.info(class_idxs)
    # Take the single batch row instead of squeeze(): with TOPK == 1,
    # squeeze() would produce 0-d tensors and zip() would fail.
    result = {idx2label[i.item()]: v.item() for i, v in zip(class_idxs[0], softmax[0])}

    # OOD scores: each detector flags the input when its score is below threshold
    msp_score = round(msp(logits), 4)
    energy_score = round(energy(logits), 4)
    igeood_score = round(igeoodlogits_vec(logits, 1, centroids), 4)
    ood_scores = {
        "MSP": msp_score,
        "MSP, is the input OOD?": msp_score < msp_threshold,
        "Energy": energy_score,
        "Energy, is the input OOD?": energy_score < energy_threshold,
        "Igeood": igeood_score,
        "Igeood, is the input OOD?": igeood_score < igeood_threshold,
    }
    _logger.info(ood_scores)
    return result, ood_scores
def main():
    """Build the Gradio interface for the OOD demo and launch it on port 7860."""
    # Example images: in-distribution ImageNet samples first, then OOD ones.
    # glob() order is filesystem-dependent, so each group is sorted to give a
    # stable example gallery (the old "shuffled" comment was wrong: the
    # shuffle was commented out and the RNG seed was dead code).
    examples = sorted(glob("images/imagenet/*")) + sorted(glob("images/ood/*"))

    interface = gr.Interface(
        fn=predict,
        inputs=Image(type="pil"),
        outputs=[
            Label(num_top_classes=TOPK, label="Model prediction"),
            JSON(label="OOD scores"),
        ],
        # Pass None instead of an empty list when no example files exist, and
        # keep the page size positive.
        examples=examples or None,
        examples_per_page=max(len(examples), 1),
        allow_flagging="never",
        theme="default",
        title="OOD Detection 🧐",
        description=(
            "Out-of-distribution (OOD) detection is an essential safety measure for machine learning models. "
            "The objective of an OOD detector is to determine whether the input sample comes from the distribution known by the AI model. "
            "For instance, an input that does not belong to any of the known classes or is from a different domain should be flagged by the detector.\n"
            "In this demo we will display the decision of three OOD detectors on a ResNet-50 model trained to classify on the ImageNet-1K dataset (top-1 accuracy 80%). "
            "This model can classify among 1000 classes from several categories, including `animals`, `vehicles`, `clothing`, `instruments`, `plants`, etc. "
            "For the complete hierarchy of classes, please check the website https://observablehq.com/@mbostock/imagenet-hierarchy. "
            "\n\n"
            "## Instructions:\n"
            "1. Upload an image of your choice or select one from the examples bar.\n"
            f"2. The model will predict the top {TOPK} most likely classes for the image.\n"
            "3. The OOD detectors will output their scores and decision on the image. The smaller the score, the less confident the detector is that the sample is in-distribution.\n"
            "4. If the image is OOD, the model will abstain from the prediction and flag it to the practitioner.\n"
            "\n\n\nEnjoy the demo!"
        ),
        cache_examples=True,
    )
    interface.launch(server_port=7860)
    interface.close()
if __name__ == "__main__":
    # WARNING level (the documented name; WARN is only a legacy alias with
    # the same value) keeps the _logger.info calls in predict() silent.
    logging.basicConfig(level=logging.WARNING)
    # Close any previously launched Gradio apps before starting a new one.
    gr.close_all()
    main()