Spaces:
Runtime error
Runtime error
File size: 1,885 Bytes
af14fe8 d1b8523 af14fe8 d1b8523 af14fe8 d1b8523 af14fe8 d1b8523 af14fe8 d1b8523 9aa5628 af14fe8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import streamlit as st
import pandas as pd
from plip_support import embed_text
import numpy as np
from PIL import Image
import requests
import transformers
import tokenizers
from io import BytesIO
import streamlit as st
from transformers import CLIPModel
import clip
import torch
from transformers import (
VisionTextDualEncoderModel,
AutoFeatureExtractor,
AutoTokenizer
)
from transformers import AutoProcessor
def embed_texts(model, texts, processor):
    """Encode a batch of strings into text-feature vectors.

    Args:
        model: object exposing ``get_text_features(input_ids=..., attention_mask=...)``
            (the PLIP ``CLIPModel`` in this app).
        texts: list of query strings to encode.
        processor: callable accepting ``text=`` and ``padding=`` and returning a
            mapping with ``"input_ids"`` and ``"attention_mask"`` entries.

    Returns:
        Whatever ``model.get_text_features`` returns, computed with autograd
        disabled.
    """
    encoded = processor(text=texts, padding="longest")
    # The processor is called without ``return_tensors``, so the padded token
    # ids / masks are converted to tensors here (one tensor per key).
    model_kwargs = {
        key: torch.tensor(encoded[key]) for key in ("input_ids", "attention_mask")
    }
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        return model.get_text_features(**model_kwargs)
def _data_cache():
    """Return Streamlit's data-caching decorator.

    ``st.cache`` is deprecated in favour of ``st.cache_data``. Use the new API
    when the installed Streamlit provides it, and fall back to ``st.cache`` on
    older releases so the app still runs there.
    """
    cache_data = getattr(st, "cache_data", None)
    return cache_data if cache_data is not None else st.cache


@_data_cache()
def load_embeddings(embeddings_path):
    """Load the precomputed image-embedding matrix from a ``.npy`` file.

    The result is cached by Streamlit, so the file is read only on a cache miss
    rather than on every script rerun. The ``print`` therefore only fires when
    the file is actually (re)loaded.

    Args:
        embeddings_path: path to a NumPy ``.npy`` file.

    Returns:
        The array returned by ``np.load`` (row layout assumed to align with the
        retrieval TSV rows — TODO confirm).
    """
    print("loading embeddings")
    return np.load(embeddings_path)
def _resource_cache():
    """Return the decorator used to cache the PLIP model and processor.

    ``st.cache_resource`` is Streamlit's intended API for unhashable singletons
    such as ML models: it neither hashes nor copies the returned objects. On
    older Streamlit releases without it, fall back to the deprecated
    ``st.cache``. That path needs ``hash_funcs`` so it skips hashing torch
    parameters and tokenizer objects, which its hasher cannot handle.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource
    return st.cache(
        hash_funcs={
            torch.nn.parameter.Parameter: lambda _: None,
            tokenizers.Tokenizer: lambda _: None,
            tokenizers.AddedToken: lambda _: None,
        }
    )


@_resource_cache()
def load_path_clip():
    """Load the PLIP model and its processor from the ``"vinid/plip"`` checkpoint.

    Cached by Streamlit so the (large) model is loaded once per server process
    instead of on every rerun.

    Returns:
        ``(model, processor)``: a ``CLIPModel`` and the matching processor
        obtained via ``AutoProcessor``.
    """
    model = CLIPModel.from_pretrained("vinid/plip")
    processor = AutoProcessor.from_pretrained("vinid/plip")
    return model, processor
# Seconds to wait for the image host before giving up. Without a timeout,
# ``requests.get`` can block the Streamlit script indefinitely.
IMAGE_FETCH_TIMEOUT = 10

st.title('PLIP Image Search')

# Retrieval table; must contain an "imageURL" column. Its row order is assumed
# to match the rows of the embedding matrix — TODO confirm.
plip_dataset = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
model, processor = load_path_clip()
image_embedding = load_embeddings("tweet_eval_embeddings.npy")

query = st.text_input('Search Query', '')

if query:
    text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()
    # Scaling the query by a positive constant never changes the argmax below.
    # Normalising only matters if the similarity scores themselves are shown
    # or thresholded.
    text_embedding = text_embedding / np.linalg.norm(text_embedding)
    best_id = np.argmax(text_embedding.dot(image_embedding.T))
    url = plip_dataset.iloc[best_id]["imageURL"]

    try:
        response = requests.get(url, timeout=IMAGE_FETCH_TIMEOUT)
        # Don't hand an HTML error page to PIL; surface the HTTP failure instead.
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        # Image.open is lazy; force decoding here so a truncated/corrupt file
        # fails inside this try rather than later in st.image.
        img.load()
    except requests.RequestException as err:
        # Must precede OSError: RequestException is itself an OSError subclass.
        st.error(f"Could not download image from {url}: {err}")
    except OSError as err:
        st.error(f"Downloaded file from {url} is not a readable image: {err}")
    else:
        st.image(img)
|