import io
import csv
import sys
import pickle
from collections import Counter

import numpy as np
import gradio as gr
import gdown
import torchvision
from torchvision.datasets import ImageFolder
from PIL import Image

from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from ExtractEmbedding import QueryToEmbedding
from CHMCorr import chm_classify_and_visualize
from visualization import plot_from_reranker_corrmap

csv.field_size_limit(sys.maxsize)


def concat(x):
    """Concatenate a sequence of arrays along axis 0."""
    return np.concatenate(x, axis=0)


# Number of training neighbours retrieved per query ("N" in the UI text).
N_NEIGHBORS = 50
# Number of top-ranked neighbours that vote on the label ("k" in the UI text).
TOP_K = 20
# Number of same-class kNN gallery images handed to CHM-Corr.
N_GALLERY = 5
# Pixels trimmed from the (left, top, right, bottom) edges of the rendered
# CHM-Corr figure. Presumably tuned to the layout produced by
# plot_from_reranker_corrmap; adjust together with that function.
VIZ_CROP_MARGINS = (310, 60, 248, 80)

# ---------------------------------------------------------------------------
# Fetch assets
# ---------------------------------------------------------------------------

# Precomputed training-set embeddings (md5-checked via gdown.cached_download)
gdown.cached_download(
    url="https://static.taesiri.com/chm-corr/embeddings.pickle",
    path="./embeddings.pickle",
    quiet=False,
    md5="002b2a7f5c80d910b9cc740c2265f058",
)

# Training-set labels. NOTE(review): unlike the other assets this call has no
# md5 check and is not cached, so it re-downloads on every start. The output
# file is presumably named labels.pickle (it is read below) — confirm.
gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")

# CUB training set archive
gdown.cached_download(
    url="https://static.taesiri.com/chm-corr/CUB_train.zip",
    path="./CUB_train.zip",
    quiet=False,
    md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
)

# Extract the training images into data/ (the archive itself is kept)
torchvision.datasets.utils.extract_archive(
    from_path="CUB_train.zip",
    to_path="data/",
    remove_finished=False,
)

# CHM weights
gdown.cached_download(
    url="https://static.taesiri.com/chm-corr/pas_psi.pt",
    path="pas_psi.pt",
    quiet=False,
    md5="6b7b4d7bad7f89600fac340d6aa7708b",
)

# ---------------------------------------------------------------------------
# Load training embeddings/labels and build the search index
# ---------------------------------------------------------------------------

# SECURITY NOTE: pickle.load can execute arbitrary code. Both files come from
# remote hosts (only embeddings.pickle is checksum-verified above), so only
# run this against sources you trust.
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)

# FIXME: re-run the code to get the embeddings in the right format
with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)

searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()

# Class index -> readable bird name. ImageFolder numbers classes by its
# `classes` list (the sub-folder names), so index it directly rather than
# re-parsing every image path with a hard-coded "/" separator.
training_folder = ImageFolder(root="./data/train/")
id_to_bird_name = {
    idx: name.replace(".", " ") for idx, name in enumerate(training_folder.classes)
}


def _figure_to_cropped_image(fig):
    """Render ``fig`` to an in-memory JPEG and trim VIZ_CROP_MARGINS from its edges.

    NOTE(review): the figure is never closed here. If plot_from_reranker_corrmap
    creates it through pyplot, each request leaves a live figure behind;
    consider closing it once it has been rendered.
    """
    img_buf = io.BytesIO()
    fig.savefig(img_buf, format="jpg")
    img_buf.seek(0)  # rewind so PIL reads the image from the start
    image = Image.open(img_buf)
    width, height = image.size
    left, top, right, bottom = VIZ_CROP_MARGINS
    return image.crop((left, top, width - right, height - bottom))


def search(query_image, searcher=searcher):
    """Classify a bird image by kNN retrieval followed by CHM-Corr re-ranking.

    Args:
        query_image: The query image as passed by Gradio (the input component
            uses ``type="filepath"``).
        searcher: Index used for nearest-neighbour retrieval; defaults to the
            module-level training-set index.

    Returns:
        ``(viz_image, label_scores)``:
        - ``viz_image`` is the cropped PIL rendering of the CHM-Corr visualization.
        - ``label_scores`` maps each bird name among CHM-Corr's first ``TOP_K``
          nearest neighbours to its vote fraction (count / ``TOP_K``).
    """
    query_embedding = QueryToEmbedding(query_image)
    _scores, indices, nn_labels = searcher.search(query_embedding, k=N_NEIGHBORS)

    # Majority vote over the top-k retrieved neighbours; the winner seeds CHM-Corr.
    top_k_labels = nn_labels[0][:TOP_K]
    top_k_indices = indices[0][:TOP_K]
    top1_label, top1_votes = Counter(top_k_labels).most_common(1)[0]

    # Gallery: the first N_GALLERY top-k neighbours belonging to the winning class.
    same_class_indices = [
        idx for lbl, idx in zip(top_k_labels, top_k_indices) if lbl == top1_label
    ]
    gallery_images = [
        training_folder.imgs[int(idx)][0] for idx in same_class_indices[:N_GALLERY]
    ]

    # CHM prediction over all N_NEIGHBORS retrieved images
    kNN_results = (top1_label, top1_votes, gallery_images)
    support_files = [training_folder.imgs[int(idx)][0] for idx in indices[0]]
    support_labels = [training_folder.imgs[int(idx)][1] for idx in indices[0]]
    support = [support_files, support_labels]

    chm_output = chm_classify_and_visualize(
        query_image, kNN_results, support, training_folder
    )

    fig, _ = plot_from_reranker_corrmap(chm_output)
    viz_image = _figure_to_cropped_image(fig)

    # The class name is the parent folder of each neighbour's image path.
    chm_output_labels = Counter(
        path.split("/")[-2].replace(".", " ").replace("_", " ")
        for path in chm_output["chm-nearest-neighbors-all"][:TOP_K]
    )
    return viz_image, {name: count / TOP_K for name, count in chm_output_labels.items()}


blocks = gr.Blocks()

with blocks:
    gr.Markdown("# CHM-Corr DEMO")
    gr.Markdown(
        f"### Parameters: N={N_NEIGHBORS}, k={TOP_K}\n"
        "- Using ``ImageNet Pretrained ResNet50`` features"
    )

    input_image = gr.Image(type="filepath")
    run_btn = gr.Button("Classify")

    gr.Markdown("### CHM-Corr Output Visualization")
    viz_plot = gr.Image(type="pil", label="Visualization")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### CHM-Corr Prediction")
            labels = gr.Label(label="Prediction")
        with gr.Column():
            gr.Markdown("### Examples")
            examples = gr.Examples(
                examples=[
                    ["./examples/bird.jpg"],
                    ["./examples/Red_Winged_Blackbird_0012_6015.jpg"],
                    ["./examples/Red_Winged_Blackbird_0025_5342.jpg"],
                    ["./examples/sample1.jpeg"],
                    ["./examples/sample2.jpeg"],
                    ["./examples/Yellow_Headed_Blackbird_0020_8549.jpg"],
                    ["./examples/Yellow_Headed_Blackbird_0026_8545.jpg"],
                ],
                inputs=[input_image],
                outputs=[viz_plot, labels],
                fn=search,
                cache_examples=False,
            )

    run_btn.click(
        search,
        inputs=[input_image],
        outputs=[viz_plot, labels],
    )


if __name__ == "__main__":
    blocks.launch(
        debug=True,
        enable_queue=True,
    )