# Hugging Face Space "Find my Butterfly" (Space status when captured: Running)
import pickle

import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoModel, AutoFeatureExtractor
# Not referenced by the code in this script; kept so anything importing it
# keeps working.
seed = 42

# Hugging Face Hub checkpoint used for BOTH preprocessing and embedding. Kept in
# one place so the feature extractor and the model cannot drift apart.
MODEL_CHECKPOINT = "sasha/autotrain-butterfly-similarity-2490576840"

# Only runs once, when the script is first run.
# Nearest-neighbour index over the candidate images' embeddings (the "768" in
# the filename presumably is the embedding size — TODO confirm).
# NOTE(review): pickle.load can execute arbitrary code; only ever load a
# trusted local file here.
with open("index_768.pickle", "rb") as handle:
    index = pickle.load(handle)

# Load model for computing embeddings.
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
model = AutoModel.from_pretrained(MODEL_CHECKPOINT)

# Candidate images. `query` maps positions returned by `index` onto rows of
# this split, so presumably the index was built over this exact split in this
# order — TODO confirm.
dataset = load_dataset("sasha/butterflies_10k_names_multiple")
ds = dataset["train"]
def query(image, top_k=4):
    """Return the candidate butterflies most similar to *image*.

    Args:
        image: The query image as delivered by the ``gr.Image(type="pil")``
            input, or ``None`` when the form is submitted without an image.
        top_k: Number of nearest neighbours to request from ``index``.

    Returns:
        A list of ``(image, name)`` tuples taken from the ``"image"`` and
        ``"name"`` columns of ``ds``, in the order the index returned them.
        This is the shape the ``gr.Gallery`` output consumes. The list is
        empty when no image was supplied.
    """
    if image is None:
        # Gradio passes None when nothing was uploaded; the feature extractor
        # presumably cannot process it, so show an empty gallery instead.
        return []

    inputs = feature_extractor(image, return_tensors="pt")
    # Pure inference: don't build an autograd graph we would only discard.
    with torch.no_grad():
        model_output = model(**inputs)
    embedding = model_output.pooler_output

    results = index.query(embedding, k=top_k)
    # Presumably results is (indices, distances), each with one row per
    # query; take the neighbour indices of our single query — TODO confirm
    # against the index type stored in index_768.pickle.
    inx = results[0][0].tolist()

    # Build the row selection once and read both columns from it.
    matches = ds.select(inx)
    return list(zip(matches["image"], matches["name"]))
# Header text rendered at the top of the Gradio page.
title = "Find my Butterfly 🦋"
description = (
    "Use this Space to find your butterfly, based on the "
    "[iNaturalist butterfly dataset](https://huggingface.co/datasets/huggan/inat_butterflies_top10k)!"
)

# Example inputs, referenced by path relative to the working directory. Each
# example is wrapped in a one-element list because the interface has exactly
# one input component.
example_files = ["elton.jpg", "ken.jpg", "gaga.jpg", "taylor.jpg"]

# Wire `query` to a single image input and a two-column gallery output, then
# start the web server.
# NOTE(review): `Component.style()` was removed in Gradio 4; on newer Gradio
# the equivalent is `gr.Gallery(columns=2, height="auto")`.
gr.Interface(
    fn=query,
    inputs=[gr.Image(type="pil")],
    outputs=gr.Gallery().style(grid=[2], height="auto"),
    title=title,
    description=description,
    examples=[[path] for path in example_files],
).launch()