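"""Sparse SPLADE vector store.

Generates sparse lexical embeddings with naver/splade-v3, persists them as a
scipy CSR matrix plus pickled ids/metadata, and answers cosine-similarity
queries over them.
"""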
import os
import pickle
from collections import defaultdict
from typing import List, Tuple

import numpy as np
import scipy.linalg
import scipy.sparse
import scipy.sparse.linalg
import torch
import tqdm
from loguru import logger
from transformers import AutoModelForMaskedLM, AutoTokenizer

from app.config.models.configs import Config, Document
from app.utils import torch_device, split
class SpladeSparseVectorDB:
    def __init__(
        self,
        config: Config,
    ) -> None:
        self._config = config

        # cuda or mps or cpu
        self._device = torch_device()
        logger.info(f"Setting device to {self._device}")

        # Tokenizers are not device-bound, so no device argument is passed
        # here; only the model is moved to the target device.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "naver/splade-v3", use_fast=True
        )
        self.model = AutoModelForMaskedLM.from_pretrained("naver/splade-v3")
        self.model.to(self._device)

        # Populated lazily by load() / generate_embeddings().
        self._embeddings = None
        self._ids = None
        self._metadatas = None
        self._l2_norm_matrix = None
        self._labels_to_ind = defaultdict(list)
        self._chunk_size_to_ind = defaultdict(list)

        self.n_batch = config.embeddings.splade_config.n_batch
    def _get_batch_embeddings(self, docs: List[str]) -> np.ndarray:
        tokens = self.tokenizer(
            docs, return_tensors="pt", padding=True, truncation=True
        ).to(self._device)

        # Inference only; no_grad avoids building the autograd graph.
        with torch.no_grad():
            output = self.model(**tokens)

        # SPLADE pooling: log-saturated ReLU over the MLM logits, masked so
        # that padding positions contribute nothing, then max-pooled over the
        # sequence dimension. Yields one vocabulary-sized vector per document.
        vecs = (
            torch.max(
                torch.log(1 + torch.relu(output.logits))
                * tokens.attention_mask.unsqueeze(-1),
                dim=1,
            )[0]
            .squeeze()
            .cpu()
            .numpy()
        )

        del output
        del tokens
        return vecs
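
    # The pooling above is the standard SPLADE document representation: for
    # vocabulary term j, w_j = max_i log(1 + relu(logit_ij)), with i ranging
    # over non-padding token positions. Most weights end up zero, which is
    # why the vectors are stored as a scipy CSR matrix below.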
    def _get_embedding_fnames(self):
        folder_name = os.path.join(self._config.embeddings.embeddings_path, "splade")
        fn_embeddings = os.path.join(folder_name, "splade_embeddings.npz")
        fn_ids = os.path.join(folder_name, "splade_ids.pickle")
        fn_metadatas = os.path.join(folder_name, "splade_metadatas.pickle")
        return folder_name, fn_embeddings, fn_ids, fn_metadatas
    def load(self) -> None:
        _, fn_embeddings, fn_ids, fn_metadatas = self._get_embedding_fnames()
        try:
            self._embeddings = scipy.sparse.load_npz(fn_embeddings)
            with open(fn_ids, "rb") as fp:
                self._ids = np.array(pickle.load(fp))
            with open(fn_metadatas, "rb") as fm:
                self._metadatas = np.array(pickle.load(fm))
            self._l2_norm_matrix = scipy.sparse.linalg.norm(self._embeddings, axis=1)

            # Index rows by label and by chunk size for filtered queries.
            for ind, m in enumerate(self._metadatas):
                if m.get("label"):
                    self._labels_to_ind[m["label"]].append(ind)
                self._chunk_size_to_ind[m["chunk_size"]].append(ind)
            logger.info(f"SPLADE: Got {len(self._labels_to_ind)} labels.")
        except FileNotFoundError:
            raise FileNotFoundError(
                f"Embeddings don't exist at {fn_embeddings}. "
                "Generate them with generate_embeddings() first."
            )
        logger.info(f"Loaded sparse embeddings from {fn_embeddings}")
    def generate_embeddings(
        self, docs: List[Document], persist: bool = True
    ) -> Tuple[np.ndarray, List[str], List[dict]]:
        chunk_size = self.n_batch

        # Drop empty documents up front so that ids and metadatas stay
        # aligned with the embedding rows.
        docs = [d for d in docs if d.page_content]
        ids = [d.metadata["document_id"] for d in docs]
        metadatas = [d.metadata for d in docs]

        vecs = []
        for chunk in tqdm.tqdm(
            split(docs, chunk_size=chunk_size),
            total=-(-len(docs) // chunk_size),  # ceiling division
        ):
            texts = [d.page_content for d in chunk]
            vecs.append(self._get_batch_embeddings(texts))

        embeddings = np.vstack(vecs)
        if persist:
            self.persist_embeddings(embeddings, metadatas, ids)
        return embeddings, ids, metadatas
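
    # NOTE: `split` (imported from app.utils) is assumed, from its usage
    # above, to yield successive batches of `docs` of length `chunk_size`;
    # a hypothetical equivalent would be:
    #
    #     def split(items, chunk_size):
    #         for i in range(0, len(items), chunk_size):
    #             yield items[i : i + chunk_size]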
    def persist_embeddings(self, embeddings, metadatas, ids):
        folder_name, fn_embeddings, fn_ids, fn_metadatas = self._get_embedding_fnames()
        csr_embeddings = scipy.sparse.csr_matrix(embeddings)
        os.makedirs(folder_name, exist_ok=True)
        scipy.sparse.save_npz(fn_embeddings, csr_embeddings)
        self.save_list(ids, fn_ids)
        self.save_list(metadatas, fn_metadatas)
        logger.info(f"Saved embeddings to {fn_embeddings}")
    def query(
        self, search: str, chunk_size: int, n: int = 50, label: str = ""
    ) -> Tuple[np.ndarray, np.ndarray]:
        if self._embeddings is None or self._ids is None:
            logger.info("Loading embeddings...")
            self.load()

        # Restrict the candidate rows to the requested chunk size, and to the
        # given label when one is supplied and known. load() either populates
        # the index or raises, so no further None checks are needed here.
        if label and label in self._labels_to_ind:
            indices = sorted(
                set(self._labels_to_ind[label]).intersection(
                    self._chunk_size_to_ind[chunk_size]
                )
            )
        else:
            indices = sorted(set(self._chunk_size_to_ind[chunk_size]))

        embeddings = self._embeddings[indices]
        ids = self._ids[indices]
        l2_norm_matrix = scipy.sparse.linalg.norm(embeddings, axis=1)

        embed_query = self._get_batch_embeddings(docs=[search])
        l2_norm_query = scipy.linalg.norm(embed_query)

        # Cosine similarity between the sparse candidate matrix and the dense
        # query vector; return the n best-scoring ids with their scores.
        cosine_similarity = embeddings.dot(embed_query) / (
            l2_norm_matrix * l2_norm_query
        )
        most_similar = np.argsort(cosine_similarity)
        top_similar_indices = most_similar[-n:][::-1]
        return (
            ids[top_similar_indices],
            cosine_similarity[top_similar_indices],
        )
    def save_list(self, list_: list, fname: str) -> None:
        with open(fname, "wb") as fp:
            pickle.dump(list_, fp)
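

# Minimal usage sketch (illustration only, not executed as part of the module).
# Assumptions: `Config` exposes `embeddings.embeddings_path` and
# `embeddings.splade_config.n_batch`, and `Document` can be constructed with
# `page_content` and a `metadata` dict carrying "document_id", "label" and
# "chunk_size" keys, exactly the fields the methods above read.
#
#     config = ...  # construct the app Config as appropriate
#     db = SpladeSparseVectorDB(config)
#
#     docs = [
#         Document(
#             page_content="SPLADE builds sparse lexical vectors.",
#             metadata={"document_id": "doc-1", "label": "demo", "chunk_size": 512},
#         ),
#     ]
#     db.generate_embeddings(docs, persist=True)
#
#     ids, scores = db.query("sparse lexical retrieval", chunk_size=512, n=5, label="demo")
#     for doc_id, score in zip(ids, scores):
#         print(doc_id, score)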