import streamlit as st
import pandas as pd
import numpy as np
import requests
import tokenizers
import torch
from io import BytesIO
from PIL import Image
from transformers import AutoProcessor, CLIPModel


def embed_texts(model, texts, processor):
    # Tokenize the texts; return_tensors="pt" yields torch tensors directly,
    # so no manual torch.tensor(...) conversion is needed. truncation=True
    # guards against queries longer than CLIP's 77-token limit.
    inputs = processor(
        text=texts, padding="longest", truncation=True, return_tensors="pt"
    )

    # Inference only: disable gradient tracking.
    with torch.no_grad():
        embeddings = model.get_text_features(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    return embeddings

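# Cache the precomputed image embeddings so the .npy file is read only once.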
@st.cache
def load_embeddings(embeddings_path):
    print("loading embeddings")
    return np.load(embeddings_path)

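# st.cache cannot hash torch parameters or tokenizers objects, so hash_funcs
# excludes them from the cache key and the model is loaded only once.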
@st.cache(
    hash_funcs={
        torch.nn.parameter.Parameter: lambda _: None,
        tokenizers.Tokenizer: lambda _: None,
        tokenizers.AddedToken: lambda _: None
    }
)
def load_path_clip():
    model = CLIPModel.from_pretrained("vinid/plip")
    processor = AutoProcessor.from_pretrained("vinid/plip")
    return model, processor

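# App setup: page title, retrieval corpus metadata, model, and image embeddings.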
st.title('PLIP Image Search')

plip_dataset = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")

model, processor = load_path_clip()

image_embedding = load_embeddings("tweet_eval_embeddings.npy")

query = st.text_input('Search Query', '')


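# Run retrieval only after the user has entered a query.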
if query:

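    # Embed the query with the CLIP text encoder and move it to a NumPy vector.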
    text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()

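    # L2-normalize the text embedding. Assuming the precomputed image embeddings
    # are unit-norm as well, the dot product below is the cosine similarity.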
    text_embedding = text_embedding / np.linalg.norm(text_embedding)

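    # Score every corpus image against the query and take the best match.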
    best_id = np.argmax(text_embedding.dot(image_embedding.T))
    url = plip_dataset.iloc[best_id]["imageURL"]

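    # Download the matched image and render it in the app.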
    response = requests.get(url)
    response.raise_for_status()  # fail fast if the image URL is unreachable
    img = Image.open(BytesIO(response.content))
    st.image(img)