# clip-rsicd-demo / utils.py
# Author: Sujit Pal
# Last commit: f9d31ee ("fix: removed commented code"), 1.05 kB
import matplotlib.pyplot as plt
import nmslib
import numpy as np
import os
import streamlit as st
from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel
@st.cache(allow_output_mutation=True)
def load_index(image_vector_file):
    """Load image vectors from a tab-separated file and build an NMSLIB index.

    Each non-blank line of ``image_vector_file`` must contain a filename,
    a tab, and a comma-separated list of floats, e.g.
    ``img001.jpg<TAB>0.12,-0.5,0.33``. Any further tab-separated columns
    are ignored.

    Args:
        image_vector_file: Path to the tab-separated vector file.

    Returns:
        A ``(filenames, index)`` tuple. ``filenames[i]`` is the filename of
        data point ``i`` in ``index``, an NMSLIB HNSW index built in the
        ``cosinesimil`` space.

    Raises:
        ValueError: if the file contains no vectors.
    """
    filenames, image_vecs = [], []
    # `with` guarantees the handle is closed; the file was previously leaked.
    with open(image_vector_file, "r", encoding="utf-8") as fvec:
        for line in fvec:
            line = line.strip()
            if not line:
                # Skip blank lines (e.g. a trailing newline at EOF) rather
                # than failing with an IndexError on cols[1].
                continue
            cols = line.split('\t')
            filename = cols[0]
            image_vec = np.array([float(x) for x in cols[1].split(',')])
            filenames.append(filename)
            image_vecs.append(image_vec)
    if not image_vecs:
        # Fail with a clear message instead of an opaque nmslib error on
        # an empty data matrix.
        raise ValueError(f"no image vectors found in {image_vector_file}")
    V = np.array(image_vecs)
    index = nmslib.init(method='hnsw', space='cosinesimil')
    index.addDataPointBatch(V)
    index.createIndex({'post': 2}, print_progress=True)
    return filenames, index
@st.cache(allow_output_mutation=True)
def load_model(model_path, baseline_model):
    """Load a Flax CLIP model and its processor from ``model_path``.

    The model and the processor are both loaded from the same
    ``model_path`` via ``from_pretrained``.

    Args:
        model_path: Path or model identifier passed to ``from_pretrained``
            for both the model and the processor.
        baseline_model: Currently unused. It is kept so that existing
            callers passing it keep working; an earlier revision loaded
            the processor from it.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = FlaxCLIPModel.from_pretrained(model_path)
    processor = CLIPProcessor.from_pretrained(model_path)
    return model, processor