import os
from collections.abc import Iterator

import numpy as np
import pypdfium2 as pdfium
import torch
import tqdm
from model import encode_images, encode_queries
from PIL import Image
from sqlitedict import SqliteDict
from voyager import Index, Space


def iter_batch(
    X: list, batch_size: int, tqdm_bar: bool = True, desc: str = ""
) -> Iterator[list]:
    """Yield successive batches of ``X``.

    Parameters
    ----------
    X
        Elements to batch (any sequence — ``add_documents`` passes PIL
        images, not only strings).
    batch_size
        Number of elements per batch; the final batch may be smaller.
    tqdm_bar
        Whether to wrap the iteration in a tqdm progress bar.
    desc
        Label displayed next to the progress bar.
    """
    batchs = [X[pos : pos + batch_size] for pos in range(0, len(X), batch_size)]

    if tqdm_bar:
        # Fix: the previous total (1 + len(X) // batch_size) over-counted by
        # one whenever len(X) was an exact multiple of batch_size.
        for batch in tqdm.tqdm(
            iterable=batchs,
            position=0,
            total=len(batchs),
            desc=desc,
        ):
            yield batch
    else:
        yield from batchs


class Voyager:
    """Voyager index.

    The Voyager index is a fast and efficient index for approximate nearest
    neighbor search over page-image embeddings.

    Parameters
    ----------
    index_folder
        Folder in which the index file and its metadata database are stored.
    index_name
        The name of the collection.
    override
        Whether to override the collection if it already exists.
    embedding_size
        The number of dimensions of the embeddings.
    M
        The number of subquantizers.
    ef_construction
        The number of candidates to evaluate during the construction of the
        index.
    ef_search
        The number of candidates to evaluate during the search.
    """

    def __init__(
        self,
        index_folder: str = "indexes",
        index_name: str = "base_collection",
        override: bool = False,
        embedding_size: int = 128,
        M: int = 64,
        ef_construction: int = 200,
        ef_search: int = 200,
    ) -> None:
        self.ef_search = ef_search

        if not os.path.exists(path=index_folder):
            os.makedirs(name=index_folder)

        self.index_path = os.path.join(index_folder, f"{index_name}.voyager")
        self.page_ids_to_data_path = os.path.join(
            index_folder, f"{index_name}_page_ids_to_data.sqlite"
        )

        self.index = self._create_collection(
            index_path=self.index_path,
            embedding_size=embedding_size,
            M=M,
            ef_construction=ef_construction,
            override=override,
        )

    def _load_page_ids_to_data(self) -> SqliteDict:
        """Open the SQLite database mapping embedding IDs to page data."""
        return SqliteDict(self.page_ids_to_data_path, outer_stack=False)

    def _create_collection(
        self,
        index_path: str,
        embedding_size: int,
        M: int,
        ef_construction: int,
        override: bool,
    ) -> Index:
        """Load an existing Voyager collection or create a new one.

        Parameters
        ----------
        index_path
            The path to the index.
        embedding_size
            The size of the embeddings.
        M
            The number of subquantizers.
        ef_construction
            The number of candidates to evaluate during the construction of
            the index.
        override
            Whether to override the collection if it already exists.

        Returns
        -------
        Index
            The loaded or freshly created Voyager index.
            (Fix: this method was annotated ``-> None`` although it returns
            the index used by ``__init__``.)
        """
        # Reuse the on-disk index unless the caller asked to override it.
        if os.path.exists(path=index_path) and not override:
            return Index.load(index_path)

        if os.path.exists(path=index_path):
            os.remove(index_path)

        # Create the Voyager index.
        index = Index(
            Space.Cosine,
            num_dimensions=embedding_size,
            M=M,
            ef_construction=ef_construction,
        )
        index.save(index_path)

        if override and os.path.exists(path=self.page_ids_to_data_path):
            os.remove(path=self.page_ids_to_data_path)

        # Create (and immediately close) the metadata database so that it
        # exists on disk alongside the index file.
        page_ids_to_data = self._load_page_ids_to_data()
        page_ids_to_data.close()

        return index

    def _load_pages(self, path: str) -> tuple[list, int]:
        """Render every page of ``path`` to a PIL image.

        PDF files are rendered page by page with pypdfium2; any other file
        is opened directly as a single image. Returns the list of images
        and the page count.
        """
        if path.lower().endswith(".pdf"):
            images = []
            pdf = pdfium.PdfDocument(path)
            try:
                for page_number in range(len(pdf)):
                    page = pdf.get_page(page_number)
                    images.append(page.render(scale=1, rotation=0).to_pil())
            finally:
                # Fix: close the document even if rendering a page fails.
                pdf.close()
            return images, len(images)
        return [Image.open(path)], 1

    def add_documents(
        self,
        paths: str | list[str],
        batch_size: int = 1,
    ) -> "Voyager":
        """Add documents to the index.

        Note that ``batch_size`` means the number of pages to encode at
        once, not the number of documents.

        Parameters
        ----------
        paths
            A path or list of paths to PDF or image files.
        batch_size
            Number of page images passed to the encoder per call.

        Returns
        -------
        Voyager
            ``self``, to allow call chaining.
            (Fix: this method was annotated ``-> None`` although it
            returns ``self``.)
        """
        if isinstance(paths, str):
            paths = [paths]

        page_ids_to_data = self._load_page_ids_to_data()

        # Render every page of every document up front; num_pages keeps the
        # per-document page count so IDs can be mapped back afterwards.
        images: list = []
        num_pages: list[int] = []
        for path in paths:
            page_images, n_pages = self._load_pages(path)
            images.extend(page_images)
            num_pages.append(n_pages)

        embeddings = []
        for batch in iter_batch(
            X=images, batch_size=batch_size, desc=f"Encoding pages (bs={batch_size})"
        ):
            embeddings.extend(encode_images(batch))

        embeddings_ids = self.index.add_items(embeddings)

        # Map every embedding ID back to its source path, image and page.
        current_index = 0
        for path, n_pages in zip(paths, num_pages):
            for page_number in range(n_pages):
                page_ids_to_data[embeddings_ids[current_index]] = {
                    "path": path,
                    "image": images[current_index],
                    "page_number": page_number,
                }
                current_index += 1

        page_ids_to_data.commit()
        # Fix: the database handle was previously left open after commit.
        page_ids_to_data.close()
        self.index.save(self.index_path)
        return self

    def __call__(
        self,
        queries: np.ndarray | torch.Tensor,
        k: int = 10,
    ) -> dict:
        """Query the index for the nearest neighbors of the queries.

        Parameters
        ----------
        queries
            The queries to search for; they are embedded via
            ``encode_queries`` before querying the index.
            (Fix: the docstring previously documented a non-existent
            ``queries_embeddings`` parameter.)
        k
            The number of nearest neighbors to return.

        Returns
        -------
        dict
            ``"documents"``: per-query lists of stored page records
            (path, image, page_number); ``"distances"``: the raw distances
            reshaped to ``(n_queries, -1, k)``.

        Raises
        ------
        ValueError
            If the index contains no documents.
        """
        queries_embeddings = encode_queries(queries)
        page_ids_to_data = self._load_page_ids_to_data()

        # Never request more neighbours than there are indexed pages.
        k = min(k, len(page_ids_to_data))
        n_queries = len(queries_embeddings)

        indices, distances = self.index.query(
            queries_embeddings, k, query_ef=self.ef_search
        )

        if len(indices) == 0:
            # Fix: close the database handle before raising.
            page_ids_to_data.close()
            raise ValueError("Index is empty, add documents before querying.")

        documents = [
            # NOTE(review): IDs were written with integer keys but are read
            # back with str(); this relies on SQLite TEXT-affinity key
            # coercion in sqlitedict — confirm before changing key handling.
            [page_ids_to_data[str(indice)] for indice in query_indices]
            for query_indices in indices
        ]

        page_ids_to_data.close()

        return {
            "documents": documents,
            "distances": distances.reshape(n_queries, -1, k),
        }