Spaces:
Running
Running
import time | |
import re | |
import pandas as pd | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from transformers import AutoTokenizer, AutoModel | |
from tokenizers import Tokenizer, AddedToken | |
import streamlit as st | |
from st_click_detector import click_detector | |
# Torch device that the models and the tokenized inputs are moved to.
DEVICE = "cpu"

# Names of the checkpoints under the "sentence-transformers/" namespace.
# Order matters: index 0 is served by compute_embeddings(), index 1 by
# compute_embeddings2().
MODEL_OPTIONS = ["msmarco-distilbert-base-tas-b", "all-mpnet-base-v2"]

# Markdown rendered in the sidebar. The emoji had been mis-decoded into
# "π€" (UTF-8 bytes of U+1F917 read with a Greek code page); it is restored.
DESCRIPTION = """
# Semantic search
**Enter your query and hit enter**
Built with 🤗 Hugging Face's [transformers](https://huggingface.co/transformers/) library, [SentenceBert](https://www.sbert.net/) models, [Streamlit](https://streamlit.io/) and 44k movie descriptions from the Kaggle [Movies Dataset](https://www.kaggle.com/rounakbanik/the-movies-dataset)
"""
def load():
    """Load everything the app needs: tokenizers, models, embeddings, movies.

    Returns:
        A ``(tokenizers, models, embeddings, df)`` tuple where

        * ``tokenizers`` and ``models`` hold one entry per name in
          ``MODEL_OPTIONS``, in the same order, with each model moved to
          ``DEVICE``;
        * ``embeddings`` is ``[embeddings.npy, embeddings2.npy]`` loaded as
          NumPy arrays (presumably the movie embeddings precomputed with
          model 0 and model 1 respectively -- confirm against how the files
          were generated);
        * ``df`` is the content of ``movies.csv`` as a DataFrame.
    """
    tokenizers, models = [], []
    for name in MODEL_OPTIONS:
        checkpoint = f"sentence-transformers/{name}"
        tokenizers.append(AutoTokenizer.from_pretrained(checkpoint))
        models.append(AutoModel.from_pretrained(checkpoint).to(DEVICE))

    embeddings = [np.load(path) for path in ("embeddings.npy", "embeddings2.npy")]
    movies = pd.read_csv("movies.csv")
    return tokenizers, models, embeddings, movies


# Module-level resources shared by the embedding and search functions below.
tokenizers, models, embeddings, df = load()
def pooling(model_output):
    """Return the hidden state at position 0 of every sequence.

    Args:
        model_output: Object exposing ``last_hidden_state``, indexed here as
            ``[batch, position, ...]``.

    Returns:
        ``last_hidden_state[:, 0]`` -- the vector at the first position of
        each sequence (presumably the [CLS] token; depends on the tokenizer).
    """
    hidden_states = model_output.last_hidden_state
    first_position = hidden_states[:, 0]
    return first_position
def compute_embeddings(texts):
    """Embed ``texts`` with the first model (``MODEL_OPTIONS[0]``).

    Args:
        texts: Strings to embed; passed straight to the tokenizer with
            padding and truncation enabled.

    Returns:
        NumPy array with one row per input text, obtained by applying
        ``pooling`` to the model output. The vectors are not normalized.
    """
    tokenizer, encoder = tokenizers[0], models[0]
    batch = tokenizer(
        texts, padding=True, truncation=True, return_tensors="pt"
    ).to(DEVICE)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = encoder(**batch, return_dict=True)
        pooled = pooling(output)
    return pooled.cpu().numpy()
def pooling2(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring masked tokens.

    Args:
        model_output: Indexable model output whose element 0 holds the token
            embeddings.
        attention_mask: Mask with one value per token; positions equal to 0
            do not contribute to the average.

    Returns:
        Tensor with one mean-pooled vector per sequence.
    """
    token_embeddings = model_output[0]
    # Broadcast the mask over the embedding dimension so it can weight tokens.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    weighted_sum = torch.sum(token_embeddings * mask, 1)
    # Clamp so that a fully masked sequence does not divide by zero.
    token_count = torch.clamp(mask.sum(1), min=1e-9)
    return weighted_sum / token_count
def compute_embeddings2(list_of_strings):
    """Embed strings with the second model (``MODEL_OPTIONS[1]``).

    Args:
        list_of_strings: Strings to embed; passed straight to the tokenizer
            with padding and truncation enabled.

    Returns:
        NumPy array with one L2-normalized row per input string, obtained by
        applying ``pooling2`` (masked mean) to the model output.
    """
    tokenizer, encoder = tokenizers[1], models[1]
    batch = tokenizer(
        list_of_strings, padding=True, truncation=True, return_tensors="pt"
    ).to(DEVICE)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = encoder(**batch)
        pooled = pooling2(output, batch["attention_mask"])
        unit_vectors = F.normalize(pooled, p=2, dim=1)
    return unit_vectors.cpu().numpy()
def semantic_search(query, model_id):
    """Return an HTML list of the ten movies most similar to ``query``.

    Args:
        query: Either free text, which is embedded with the selected model,
            or a string starting with ``[Similar:<row>]``, in which case the
            stored embedding of movie ``<row>`` is used as the query vector.
        model_id: Index into the module-level ``embeddings`` list; 0 embeds
            free text with ``compute_embeddings``, any other value with
            ``compute_embeddings2``.

    Returns:
        HTML made of a "Computation time" paragraph followed by an ``<ol>``
        of results ordered by decreasing dot-product score. Every item ends
        with a link whose ``id`` is the movie's row number. Returns ``""``
        when the query is blank, when a ``[Similar:`` marker is malformed or
        not at the start of the query, or when the row number is out of range.
    """
    # Local import so this block does not depend on a file-level import.
    import html

    # perf_counter is monotonic, so the delay cannot be skewed by clock changes.
    started = time.perf_counter()
    if not query.strip():
        return ""

    if "[Similar:" in query:
        match = re.match(r"\[Similar:(\d{1,5}).*", query)
        if not match:
            return ""
        idx = int(match.group(1))
        # Slicing (rather than indexing) keeps the 2-D shape and yields an
        # empty array instead of raising when idx is out of range.
        query_embedding = embeddings[model_id][idx : idx + 1, :]
        if query_embedding.shape[0] == 0:
            return ""
    elif model_id == 0:
        query_embedding = compute_embeddings([query])
    else:
        query_embedding = compute_embeddings2([query])

    scores = embeddings[model_id] @ np.transpose(query_embedding)[:, 0]
    # argsort is ascending: walk backwards from the end to take the top ten.
    indices = np.argsort(scores)[-1:-11:-1]
    if len(indices) == 0:
        return ""

    items = []
    for i in indices:
        row = df.iloc[i]
        # The CSV text is data, not markup: escape it so that "&" or "<" in a
        # title or overview cannot break or inject HTML.
        title = html.escape(str(row["title"]))
        release_date = html.escape(str(row["release_date"]))
        overview = html.escape(str(row["overview"]))
        items.append(
            f"<li style='padding-top: 10px'><b>{title}</b> ({release_date}). {overview} "
            f"<a id='{i}' href='#'>Similar movies</a></li>"
        )
    result = "<ol>" + "".join(items)

    delay = f"{time.perf_counter() - started:.3f}"
    return f"<p><i>Computation time: {delay} seconds</i></p>{result}</ol>"
# --- Sidebar: description and model selection -------------------------------
st.sidebar.markdown(DESCRIPTION)
model_choice = st.sidebar.selectbox("Similarity model", options=MODEL_OPTIONS)
# 0 for the first option, 1 for anything else.
model_id = int(model_choice != MODEL_OPTIONS[0])

# --- Query box: pre-filled with the last programmatic query, if any ---------
initial_query = (
    st.session_state["query"] if "query" in st.session_state else "time travel"
)
query = st.text_input("", value=initial_query)

# --- Results: render them and react to clicks on "Similar movies" -----------
clicked = click_detector(semantic_search(query, model_id))

# Only act on a click that differs from the one already recorded; otherwise
# the same click would presumably be handled again on every rerun.
is_new_click = clicked != "" and (
    "last_clicked" not in st.session_state
    or clicked != st.session_state["last_clicked"]
)
if is_new_click:
    st.session_state["last_clicked"] = clicked
    # Turn the click into a "[Similar:<row>] <title>" query and rerun so the
    # text box and the results reflect it.
    st.session_state["query"] = f"[Similar:{clicked}] {df.iloc[int(clicked)].title}"
    st.experimental_rerun()