Spaces:
Runtime error
Runtime error
File size: 2,173 Bytes
bf3fe47 e336559 eb6e722 bf3fe47 2a88d86 bf3fe47 4b863c6 bf3fe47 f0c8df9 9c8a7bc cefc98c f0c8df9 cefc98c f0c8df9 2a88d86 cefc98c 2a88d86 cefc98c 2a88d86 cefc98c 2a88d86 cefc98c 4b863c6 cefc98c bf3fe47 5834f42 bf3fe47 eb6e722 bf3fe47 cefc98c f0c8df9 5834f42 f0c8df9 5834f42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import pickle
import numpy as np
import pandas as pd
import streamlit as st
import torch
from sentence_transformers.util import semantic_search
from transformers import AutoModel, AutoProcessor
# Page header: the app title followed by a one-line description.
st.title("My CLIP Model Test")
# The description must be a single logical string; a plain "..." literal
# cannot span physical lines. The Korean text was restored from mis-decoded
# UTF-8 and reads: "Searches the Unsplash Lite dataset for the images most
# similar to the input text."
st.markdown(
    "[Unsplash Lite dataset](https://unsplash.com/data)에서 "
    "입력 텍스트와 가장 유사한 이미지를 검색합니다."
)
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_dual_encoder_model(name: str):
    """Load the checkpoint ``name`` together with its processor.

    The script calls this loader for the
    ``"Bingsu/vitB32_bert_ko_small_clip"`` checkpoint. A "Loading model..."
    spinner is displayed while the weights and the processor are fetched.

    Args:
        name: Identifier passed unchanged to ``from_pretrained``.

    Returns:
        A ``(model, processor)`` tuple; ``.eval()`` has already been called
        on the model.
    """
    # The result is memoised by ``st.cache``; Streamlit's built-in cache
    # spinner is turned off because the function shows its own.
    with st.spinner("Loading model..."):
        encoder = AutoModel.from_pretrained(name).eval()
        preprocessor = AutoProcessor.from_pretrained(name)
    return encoder, preprocessor
@st.cache(allow_output_mutation=True, show_spinner=False)
def get_clip_model(name: str):
    """Return the pretrained model ``name`` and its processor as a pair.

    The script uses this loader for every selectable checkpoint other than
    ``"Bingsu/vitB32_bert_ko_small_clip"``. While loading, a
    "Loading model..." spinner is shown.

    Args:
        name: Identifier handed to ``AutoModel.from_pretrained`` and
            ``AutoProcessor.from_pretrained``.

    Returns:
        ``(model, processor)``, with the model already switched to eval mode
        through ``.eval()``.
    """
    with st.spinner("Loading model..."):
        clip = AutoModel.from_pretrained(name).eval()
        clip_processor = AutoProcessor.from_pretrained(name)
        loaded = (clip, clip_processor)
    return loaded
# Checkpoints offered by the radio widget below.
model_list = [
    "Bingsu/clip-vit-base-patch32-ko",
    "openai/clip-vit-base-patch32",
    "Bingsu/vitB32_bert_ko_small_clip",
]
model_type = st.radio("Select model", model_list)

# One checkpoint goes through the dual-encoder loader; every other choice
# goes through the CLIP loader. Both return a (model, processor) pair.
model, processor = (
    get_dual_encoder_model(model_type)
    if model_type == "Bingsu/vitB32_bert_ko_small_clip"
    else get_clip_model(model_type)
)
# Photo metadata table; get_url() reads its "photo_id" and "photo_image_url"
# columns.
info = pd.read_csv("info.csv")

# img_id is indexed by the "corpus_id" of a search hit to obtain a photo id.
# NOTE(review): pickle.load can execute arbitrary code; this is only safe
# while img_id.pkl is a trusted file shipped alongside the app.
with open("img_id.pkl", "rb") as f:
    img_id = pickle.load(f)

# Image embeddings searched against the text query.
# NOTE(review): the same img_emb.npy is used whichever model is selected;
# confirm that all listed checkpoints share one image-embedding space.
img_emb = np.load("img_emb.npy")

# Default query restored from mis-decoded UTF-8; "검은 고양이" means
# "black cat".
text = st.text_input("Input Text", value="검은 고양이")
tokens = processor(text=text, return_tensors="pt")

# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    text_emb = model.get_text_features(**tokens)

# Only one query is submitted, so only the first hit list is kept. More hits
# (16) are requested than the six grid cells need, because get_url() skips
# hits without a row in info.csv.
result = semantic_search(text_emb, img_emb, top_k=16)[0]

# Shared iterator consumed by successive get_url() calls.
_result = iter(result)
def get_url() -> str:
    """Return the image URL of the next search hit that has metadata.

    Hits are pulled from the module-level ``_result`` iterator, so every
    call continues where the previous one stopped.

    Returns:
        The first ``photo_image_url`` value of the ``info`` row whose
        ``photo_id`` matches the hit.

    Raises:
        StopIteration: If ``_result`` runs out before a hit with a matching
            ``info`` row is found.
    """
    # Some images have no row in the info.csv data, so keep pulling hits
    # until one does.
    while True:
        hit = next(_result)
        wanted_id = img_id[hit["corpus_id"]]
        matches = info.loc[info["photo_id"] == wanted_id, "photo_image_url"]
        if not matches.empty:
            return matches.iloc[0]
# Two rows of three columns, flattened into one sequence of six cells.
columns = st.columns(3) + st.columns(3)
for col in columns:
    try:
        img_url = get_url()
    except StopIteration:
        # get_url() skips hits without metadata, so the hits can run out
        # before every cell is filled. Leave the remaining cells empty
        # rather than crashing the page with a traceback.
        break
    col.image(img_url, use_column_width=True)
|