Spaces:
Build error
Build error
File size: 3,895 Bytes
357b0b8 f9d31ee 357b0b8 96ac3ab 357b0b8 a78bf29 357b0b8 f9d31ee a78bf29 357b0b8 2a06c48 c0c0d12 357b0b8 9cde513 357b0b8 9cde513 a78bf29 357b0b8 a78bf29 357b0b8 a78bf29 357b0b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import matplotlib.pyplot as plt
import nmslib
import numpy as np
import os
import streamlit as st
from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel
import utils
# ---- Model settings -------------------------------------------------------
# Model id of the stock OpenAI CLIP checkpoint; passed to utils.load_model
# together with MODEL_PATH (how utils uses it is not visible here).
BASELINE_MODEL = "openai/clip-vit-base-patch32"
# Fine-tuned RSICD checkpoint (model card linked from the page text).
# Local development copy that was used before:
#   /home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1
MODEL_PATH = "flax-community/clip-rsicd-v2"

# ---- Data locations -------------------------------------------------------
# Precomputed image-vector file handed to utils.load_index.
# Shared-storage variants used before:
#   /home/shared/data/vectors/test-baseline.tsv
#   /home/shared/data/vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
# Directory containing the image files whose names the index returns.
# Shared-storage variant used before: /home/shared/data/rsicd_images
IMAGES_DIR = "./images"
@st.cache(allow_output_mutation=True)
def load_example_images():
    """Group the file names found in IMAGES_DIR by their class prefix.

    A file's class is the part of its name before the first underscore;
    names that contain no underscore are skipped.

    Returns:
        dict mapping class name -> list of file names with that prefix.
        The directory listing is sorted first, so both the class order and
        the order of names inside each class are deterministic.
        ``os.listdir`` on its own yields entries in arbitrary,
        filesystem-dependent order, and the page picks examples from the
        first classes of this dict.
    """
    example_images = {}
    for image_name in sorted(os.listdir(IMAGES_DIR)):
        image_class, separator, _ = image_name.partition("_")
        if not separator:
            continue  # no underscore, hence no class prefix
        example_images.setdefault(image_class, []).append(image_name)
    return example_images
# Markdown shown above the query box; ``{:s}`` is filled with the example
# file names. Kept flush-left at module level so the literal's text is
# exactly what st.markdown receives.
_INTRO_MARKDOWN = """
The CLIP model from OpenAI is trained in a self-supervised manner using
contrastive learning to project images and caption text onto a common
embedding space. We have fine-tuned the model (see [Model card](https://huggingface.co/flax-community/clip-rsicd-v2))
using the RSICD dataset (10k images and ~50k captions from the remote
sensing domain). Click here for [more information about our project](https://github.com/arampacha/CLIP-rsicd).
This demo shows the image to image retrieval capabilities of this model, i.e.,
given an image file name as a query, we use our fine-tuned CLIP model
to project the query image to the image/caption embedding space and search
for nearby images (by cosine similarity) in this space.
Our fine-tuned CLIP model was previously used to generate image vectors for
our demo, and NMSLib was used for fast vector access.
Here are some randomly generated image files from our corpus. You can
copy paste one of these below or use one from the results of a text to
image search -- {:s}
"""


def _pick_example_names(example_images, max_examples=10):
    """Return up to ``max_examples`` example file names, sorted.

    One name is drawn at random from every class (in the dict's iteration
    order); the list is then cut to its first ``max_examples`` entries.
    """
    picks = [names[np.random.randint(0, len(names))]
             for names in example_images.values()]
    return sorted(picks[:max_examples])


def _resolve_query_image(image_name):
    """Map a user-typed file name to ``(file_name, path)`` inside IMAGES_DIR.

    The text box is free-form user input. Only the final path component is
    kept, because ``os.path.join`` would otherwise honour absolute paths and
    ``..`` and let a query read files outside IMAGES_DIR.

    Returns:
        ``(file_name, path)``, or ``None`` when the name is empty or does not
        name an existing file in IMAGES_DIR.
    """
    file_name = os.path.basename(image_name.strip())
    if not file_name:
        return None
    path = os.path.join(IMAGES_DIR, file_name)
    if not os.path.isfile(path):
        return None
    return file_name, path


def _show_image_grid(images, captions, per_row=3):
    """Render ``images`` with ``captions``, ``per_row`` images per st.image call."""
    for start in range(0, len(images), per_row):
        stop = start + per_row
        st.image(images[start:stop], caption=captions[start:stop])


def app():
    """Render the image-to-image retrieval page.

    Loads the vector index and the model through ``utils`` and lists a few
    example file names. On submit, it embeds the named image with the model's
    image encoder, queries the NMSLib index and shows up to ten neighbouring
    images (excluding the query itself) with their scores.
    """
    filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
    model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
    example_image_list = _pick_example_names(load_example_images())

    st.title("Image to Image Retrieval")
    st.markdown(_INTRO_MARKDOWN.format(
        ", ".join("`{:s}`".format(example) for example in example_image_list)))

    image_name = st.text_input("Provide an Image File Name")
    submit_button = st.button("Find Similar")
    if not submit_button:
        return

    resolved = _resolve_query_image(image_name)
    if resolved is None:
        st.error("No image named {!r} was found; please enter one of the "
                 "file names shown above.".format(image_name))
        return
    query_name, query_path = resolved

    # Load through PIL and force RGB. plt.imread returns float arrays for
    # PNG files, which Image.fromarray cannot turn back into an RGB image.
    with Image.open(query_path) as img:
        image = img.convert("RGB")
    inputs = processor(images=image, return_tensors="jax", padding=True)
    query_vec = np.asarray(model.get_image_features(**inputs))

    # One extra neighbour is requested so that ten remain after the query
    # image itself (if it is in the index) is filtered out below.
    ids, distances = index.knnQuery(query_vec, k=11)
    neighbours = [(filenames[idx], dist)
                  for idx, dist in zip(ids, distances)
                  if filenames[idx] != query_name][:10]

    # Only the kept neighbours are read from disk. The score shown is
    # 1 - distance; presumably the index uses a cosine-distance space
    # (built in utils.load_index; confirm there).
    images = [plt.imread(os.path.join(IMAGES_DIR, name))
              for name, _ in neighbours]
    captions = ["{:s} (score: {:.3f})".format(name, 1.0 - dist)
                for name, dist in neighbours]
    _show_image_grid(images, captions)
|