# SpaceVector_v0 / text_to_image.py
# Author: LayBraid
# Last commit: ae92333 ":construction: update app" (file size 2.44 kB)
import functools
import json
import os

import nmslib
import numpy as np
import streamlit as st
from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel
def load_index(image_vector_file):
    """Build an nmslib HNSW index (cosinesimil space) from a file of image vectors.

    Each non-blank line holds a filename followed by whitespace and a
    comma-separated list of floats, e.g. ``img_001.jpg 0.12,-0.53,...``.

    Args:
        image_vector_file: Path to the vector file.

    Returns:
        ``(filenames, index)`` where ``filenames[i]`` is the image whose vector
        was added to ``index`` at position ``i`` (so knn ids map back into
        ``filenames``).

    Raises:
        ValueError: if the file contains no vectors.
    """
    filenames, image_vecs = [], []
    with open(image_vector_file, "r", encoding="utf-8") as fvec:
        for line in fvec:
            line = line.strip()
            if not line:
                # Skip blank lines (e.g. a trailing newline at EOF).
                continue
            # split() with no argument accepts a single space, repeated spaces
            # or a tab between the filename and the vector column.
            cols = line.split()
            filenames.append(cols[0])
            image_vecs.append(np.array([float(x) for x in cols[1].split(',')]))
    if not image_vecs:
        raise ValueError(
            "no image vectors found in {!r}".format(image_vector_file))
    index = nmslib.init(method='hnsw', space='cosinesimil')
    index.addDataPointBatch(np.array(image_vecs))
    index.createIndex({'post': 2}, print_progress=True)
    return filenames, index
def load_captions(caption_file):
    """Load per-image captions from a JSON-lines file.

    Each non-blank line must be a JSON object with a ``"filename"`` key and a
    ``"captions"`` key. If a filename appears more than once, the last line wins.

    Args:
        caption_file: Path to the JSON-lines caption file.

    Returns:
        dict mapping each filename to its ``"captions"`` value.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks ``"filename"`` or ``"captions"``.
    """
    image2caption = {}
    # Explicit UTF-8 so captions decode identically on every platform.
    with open(caption_file, "r", encoding="utf-8") as fcap:
        for line in fcap:
            line = line.strip()
            if not line:
                # Tolerate blank lines instead of failing in json.loads("").
                continue
            data = json.loads(line)
            image2caption[data["filename"]] = data["captions"]
    return image2caption
@functools.lru_cache(maxsize=1)
def _load_search_resources():
    """Load the CLIP model/processor, the image index and the captions once.

    Cached so repeated searches in the same process reuse them instead of
    reloading the model and rebuilding the index on every button press.
    """
    model = FlaxCLIPModel.from_pretrained("flax-community/clip-rsicd-v2")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    filenames, index = load_index("./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv")
    image2caption = load_captions("./images/test-captions.json")
    return model, processor, filenames, index, image2caption


def get_image(text, k=10):
    """Search the image index with a text query and render the top matches.

    The query is embedded with the CLIP text encoder and looked up in the
    nmslib index built by ``load_index``. For each hit a Streamlit row is
    rendered with the rank, the image (captioned with its filename and score),
    and the image's reference captions as a bullet list, followed by a rule.

    Args:
        text: Free-text query typed by the user.
        k: Number of nearest images to display (default 10).

    Returns:
        None; all output is written to the Streamlit page.
    """
    if not text or not text.strip():
        st.warning("Please enter some text to search for.")
        return
    model, processor, filenames, index, image2caption = _load_search_resources()
    # Text-only input: CLIPProcessor's image argument is ``images``, so it is
    # simply omitted rather than passed under a wrong keyword.
    inputs = processor(text=[text], return_tensors="jax", padding=True)
    # A single query -> flatten the (1, dim) feature batch to one vector.
    vector = np.asarray(model.get_text_features(**inputs)).reshape(-1)
    ids, distances = index.knnQuery(vector, k=k)
    for rank, (image_id, distance) in enumerate(zip(ids, distances)):
        result_filename = filenames[image_id]
        # Score shown is 1 - distance (cosine similarity in nmslib's
        # 'cosinesimil' space).
        caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - distance)
        col1, col2, col3 = st.columns([2, 10, 10])
        col1.markdown("{:d}.".format(rank + 1))
        # Close the image file once Streamlit has consumed it.
        with Image.open(os.path.join("./images", result_filename)) as img:
            col2.image(img, caption=caption)
        # Images without a caption entry get an empty list instead of KeyError.
        captions = image2caption.get(result_filename, [])
        # Newline-separated so Markdown renders one bullet per caption.
        col3.markdown("\n".join("* {:s}".format(c) for c in captions))
        st.markdown("---")
# NOTE(review): `suggest_idx` is not read anywhere in this file; it looks like a
# leftover from an earlier version of the page. Confirm no other module uses it
# before removing.
suggest_idx = -1
def app():
    """Render the Space Vector text-to-image search page.

    Shows a title, a one-line description and a text box. Pressing the
    "Search" button passes the entered text to ``get_image``, which renders
    the matching images below.
    """
    st.title("Welcome to Space Vector")
    # Help text rewritten: the previous "You want search an image with given
    # text." was ungrammatical.
    st.text("Search for images that match a text description.")
    text = st.text_input("Enter text: ")
    if st.button("Search"):
        get_image(text)