# | |
# Loads all completed shards and finds the most similar vector to a given query vector. | |
from dataclasses import dataclass
from itertools import chain
import json
import os
from math import log10
from pathlib import Path
from sys import stderr
from typing import TypedDict, TypeVar, Any, Callable

from datasets import Dataset
from datasets.search import FaissIndex
import faiss
from huggingface_hub import snapshot_download
import numpy as np
import numpy.typing as npt
import gradio as gr
import requests
from sentence_transformers import SentenceTransformer
import torch

try:
    import spaces
except ImportError:
    spaces = None
# Type parameters for get_env_var: T is what the converter produces, U is the
# type of the fallback default.
T, U = TypeVar("T"), TypeVar("U")
# One benchmarked faiss parameter setting, as listed under "optimal_params" in
# params.json.
IndexParameters = TypedDict(
    "IndexParameters",
    {
        "recall": float,  # in this case 10-recall@10
        "exec_time": float,  # seconds (raw faiss measure is in milliseconds)
        "param_string": str,  # pass directly to faiss index
    },
)
# Schema of the index directory's params.json.
Params = TypedDict(
    "Params",
    {
        "dimensions": int | None,  # passed to SentenceTransformer as truncate_dim
        "normalize": bool,  # whether query embeddings get normalized before search
        "optimal_params": list[IndexParameters],  # candidate faiss settings
    },
)
class Work: | |
title: str | None | |
abstract: str | None # recovered from abstract_inverted_index | |
authors: list[str] # takes raw_author_name field from Authorship objects | |
journal_name: str | None # takes the display_name field of the first location | |
year: int | |
citations: int | |
doi: str | None | |
def __post_init__(self): | |
self._check_type(self.title, str, nullable=True) | |
self._check_type(self.abstract, str, nullable=True) | |
self._check_type(self.authors, list) | |
for author in self.authors: | |
self._check_type(author, str) | |
self._check_type(self.journal_name, str, nullable=True) | |
self._check_type(self.year, int) | |
self._check_type(self.citations, int) | |
self._check_type(self.doi, str, nullable=True) | |
def from_dict(cls, d: dict) -> "Work": | |
inverted_index: None | dict[str, list[int]] = d["abstract_inverted_index"] | |
abstract = cls._recover_abstract(inverted_index) if inverted_index else None | |
try: | |
journal_name = d["primary_location"]["source"]["display_name"] | |
except (TypeError, KeyError): # key didn't exist or a value was null | |
journal_name = None | |
return cls( | |
title=d["title"], | |
abstract=abstract, | |
authors=[authorship["raw_author_name"] for authorship in d["authorships"]], | |
journal_name=journal_name, | |
year=d["publication_year"], | |
citations=d["cited_by_count"], | |
doi=d["doi"], | |
) | |
def get_raw_fields() -> list[str]: | |
return [ | |
"title", | |
"abstract_inverted_index", | |
"authorships", | |
"primary_location", | |
"publication_year", | |
"cited_by_count", | |
"doi" | |
] | |
def _check_type(v: Any, t: type, nullable: bool = False): | |
if not ((nullable and v is None) or isinstance(v, t)): | |
v_type_name = f"{type(v)}" if v is not None else "None" | |
t_name = f"{t}" | |
if nullable: | |
t_name += " | None" | |
raise ValueError(f"expected {t_name}, got {v_type_name}") | |
def _recover_abstract(inverted_index: dict[str, list[int]]) -> str: | |
abstract_size = max(max(locs) for locs in inverted_index.values())+1 | |
abstract_words: list[str | None] = [None] * abstract_size | |
for word, locs in inverted_index.items(): | |
for loc in locs: | |
abstract_words[loc] = word | |
return " ".join(word for word in abstract_words if word is not None) | |
def get_env_var(key: str, type_: Callable[[str], T] = str, default: U = None) -> T | U:
    """Read environment variable ``key`` and convert it with ``type_``.

    Args:
        key: name of the environment variable.
        type_: converter applied to the raw text when the variable is set.
            ``bool`` is special-cased, because ``bool("False")`` and ``bool("0")``
            are both True. The text is instead parsed as a flag:
            "1"/"true"/"yes"/"on" -> True and "0"/"false"/"no"/"off"/"" -> False,
            case-insensitive and ignoring surrounding whitespace.
        default: returned unchanged when the variable is not set.

    Returns:
        The converted value, or ``default`` if ``key`` is unset.

    Raises:
        ValueError: if ``type_`` is ``bool`` and the text is not a recognized
            flag; otherwise whatever ``type_`` raises on malformed text.
    """
    raw = os.getenv(key)
    if raw is None:
        return default
    if type_ is bool:
        flag = raw.strip().lower()
        if flag in ("1", "true", "yes", "on"):
            return True
        if flag in ("0", "false", "no", "off", ""):
            return False
        raise ValueError(f"{key}: expected a boolean flag, got {raw!r}")
    return type_(raw)
def get_model(
    model_name: str, params_dir: Path, trust_remote_code: bool
) -> tuple[bool, SentenceTransformer]:
    """Load the embedding model using the settings in ``params_dir/params.json``.

    Returns:
        ``(normalize, model)``: the params.json "normalize" flag, and the
        SentenceTransformer created with ``truncate_dim`` set to the
        params.json "dimensions" value.
    """
    # TODO: params["normalize"] for models like all-MiniLM-v6, which already normalize?
    params_path = params_dir / "params.json"
    with open(params_path, "r") as f:
        params: Params = json.load(f)
    # read both keys before constructing the model, same order as before
    normalize = params["normalize"]
    dimensions = params["dimensions"]
    model = SentenceTransformer(
        model_name,
        trust_remote_code=trust_remote_code,
        truncate_dim=dimensions,
    )
    return normalize, model
def open_ondisk(dir: Path) -> faiss.Index:
    """Read ``dir/index.faiss``, resolving its on-disk parts relative to ``dir``."""
    index_file = str(dir / "index.faiss")
    # Without IO_FLAG_ONDISK_SAME_DIR, read_index looks for the on-disk
    # indices in the working directory instead of next to index.faiss.
    io_flags = faiss.IO_FLAG_ONDISK_SAME_DIR
    return faiss.read_index(index_file, io_flags)
def get_index(dir: Path, search_time_s: float) -> Dataset:
    """Load the ID table and faiss index from ``dir`` and apply tuned search params.

    ``dir`` must contain ``ids.parquet``, ``index.faiss`` and ``params.json``.
    Among the benchmarked settings in params.json, the one with the best recall
    whose execution time is under ``search_time_s`` is applied. If none fits
    the budget, the fastest setting is applied and a warning is printed.

    Returns:
        The ID Dataset, with the faiss index attached under the name "embedding".

    Raises:
        ValueError: if params.json lists no parameter settings at all.
    """
    # NOTE: use a private attr to load the index with IO_FLAG_ONDISK_SAME_DIR!
    index: Dataset = Dataset.from_parquet(str(dir / "ids.parquet"))  # type: ignore
    faiss_index = open_ondisk(dir)
    index._indexes["embedding"] = FaissIndex(None, None, None, faiss_index)

    params_path = dir / "params.json"
    with open(params_path, "r") as f:
        params: Params = json.load(f)
    candidates = params["optimal_params"]
    if not candidates:
        raise ValueError(f"no index parameters listed in {params_path}")

    # most accurate setting that still meets the search-time budget
    under = [p for p in candidates if p["exec_time"] < search_time_s]
    if under:
        optimal = max(under, key=lambda p: p["recall"])
    else:
        # nothing meets the budget: degrade to the fastest setting rather than
        # failing on max() of an empty list
        optimal = min(candidates, key=lambda p: p["exec_time"])
        print(
            f"warning: no index parameters run under {search_time_s} s, "
            f"using the fastest ({optimal['exec_time']} s)...",
            file=stderr,
        )

    ps = faiss.ParameterSpace()
    ps.initialize(faiss_index)
    ps.set_index_parameters(faiss_index, optimal["param_string"])
    return index
def execute_request(
    ids: list[str], mailto: str | None, timeout: float = 30.0
) -> list[Work]:
    """Fetch the given works from the OpenAlex /works endpoint, in ``ids`` order.

    Args:
        ids: OpenAlex work IDs. At most 100 are allowed, because the request
            asks for a single page of up to 100 results.
        mailto: optional contact address sent as the ``mailto`` query parameter.
        timeout: seconds to wait for the server before giving up.

    Returns:
        One Work per entry of ``ids``, in the same order (empty for empty ``ids``).

    Raises:
        ValueError: if more than 100 IDs are given.
        requests.HTTPError: if the server answers with an error status.
        KeyError: if a requested ID is missing from the results.
    """
    if len(ids) > 100:
        raise ValueError("querying /works endpoint with more than 100 works")
    if not ids:
        return []  # nothing to fetch; skip the network round trip

    # query with the /works endpoint with a specific list of IDs and fields
    params = {
        "filter": f'openalex_id:{"|".join(ids)}',
        "select": ",".join(["id"] + Work.get_raw_fields()),
        "per-page": 100,
    }
    if mailto is not None:
        params["mailto"] = mailto
    response = requests.get(
        "https://api.openalex.org/works", params=params, timeout=timeout
    )
    response.raise_for_status()

    # the results are not necessarily in request order, so index them by ID
    works_by_id = {d["id"]: Work.from_dict(d) for d in response.json()["results"]}
    return [works_by_id[id_] for id_ in ids]
def collapse_newlines(x: str) -> str:
    """Return ``x`` with every line break (CRLF, LF or lone CR) turned into a space."""
    # fold CRLF first so each pair becomes one space, not two
    without_crlf = x.replace("\r\n", " ")
    return without_crlf.translate({ord("\n"): " ", ord("\r"): " "})
def _format_entry(work: Work, distance: float, calculate_similarity: bool) -> str:
    """Render one work as a Markdown section: title, byline, abstract, metadata."""
    # heading: the (possibly placeholder) title links to the DOI when there is one
    title = collapse_newlines(work.title) if work.title else "No title"
    heading = f"[{title}]({work.doi})" if work.doi else title

    if len(work.authors) > 3:  # truncate to 3, marking that more were dropped
        byline = ", ".join(work.authors[:3]) + ", ..."
    elif work.authors:
        byline = ", ".join(work.authors)
    else:
        byline = "No author"
    byline += f", {work.year}"
    if work.journal_name:
        byline += " - " + work.journal_name

    if work.abstract:
        abstract = collapse_newlines(work.abstract)
        if len(abstract) > 2000:
            abstract = abstract[:2000] + "..."
    else:
        abstract = "No abstract"

    meta: list[tuple[str, str]] = []
    if work.citations:  # don't tack "Cited-by count: 0" on someone's work
        meta.append(("Cited-by count", str(work.citations)))
    if work.doi:
        # show the bare DOI rather than the full resolver URL
        meta.append(("DOI", work.doi.removeprefix("https://doi.org/")))
    if calculate_similarity:
        # if query and result are unit vectors, the cosine sim is 1 - dist^2 / 2,
        # and faiss already gives dist^2
        meta.append(("Similarity", f"{1 - distance / 2:.2f}"))
    else:
        meta.append(("Distance", f"{distance:.2f}"))
    meta_line = (" " * 4).join(": ".join(pair) for pair in meta)

    return f"## {heading}\n\n**{byline}**\n\n{abstract}\n\n*{meta_line}*\n"


def format_response(
    neighbors: list[Work], distances: list[float], calculate_similarity: bool = False
) -> str:
    """Render search hits as Markdown, one "##" section per work.

    Args:
        neighbors: works to show, in display order.
        distances: faiss distance for each work, in the same order. Pairs are
            formed with ``zip``, so surplus entries on either side are ignored.
        calculate_similarity: show a cosine similarity derived from the distance
            instead of the raw distance (only meaningful for unit vectors).

    Returns:
        The concatenated Markdown for all entries ("" if there are none).
    """
    return "".join(
        _format_entry(work, distance, calculate_similarity)
        for work, distance in zip(neighbors, distances)
    )
def _describe_quantity(n_entries: int) -> str:
    """Render a count as a short phrase such as "95.3 million".

    Keeps the established rounding: one decimal when the leading group has two
    digits, otherwise a whole number. Counts below 1 are printed as-is, because
    log10 is undefined there. Counts past the billions stay in billions.
    """
    if n_entries < 1:
        return str(n_entries)
    n_digits = int(log10(n_entries))
    scales = [(1, ""), (1000, " thousand"), (1000000, " million"), (1000000000, " billion")]
    divisor, postfix = scales[min(n_digits // 3, len(scales) - 1)]
    significand = n_entries / divisor
    significand = round(significand, 1 if (n_digits % 3 == 1) else None)
    return str(significand) + postfix


def main():
    """Configure the model and index from environment variables, then serve the Gradio UI."""
    # TODO: figure out some better defaults?
    model_name = get_env_var("MODEL_NAME", default="all-MiniLM-L6-v2")
    prompt_name = get_env_var("PROMPT_NAME")
    trust_remote_code = get_env_var("TRUST_REMOTE_CODE", bool, default=False)
    fp16 = get_env_var("FP16", bool, default=False)
    index_dir = get_env_var("DIR", Path)
    repo = get_env_var("REPO", str)
    search_time_s = get_env_var("SEARCH_TIME_S", float, default=1)
    k = get_env_var("K", int, default=20)  # TODO: can't go higher than 20 yet
    mailto = get_env_var("MAILTO", str, None)

    if index_dir is None:  # acquire the index if it's not local
        if repo is None:
            repo = "colonelwatch/abstracts-faiss"
        index_dir = Path(snapshot_download(repo, repo_type="dataset")) / "index"
    elif repo is not None:
        print('warning: used "REPO" and also "DIR", ignoring "REPO"...', file=stderr)

    normalize, model = get_model(model_name, index_dir, trust_remote_code)
    index = get_index(index_dir, search_time_s)

    # follow model.encode logic for acquiring the prompt; having no prompt at
    # all is valid, so only a non-None name is type-checked
    if prompt_name is None and model.default_prompt_name is not None:
        prompt_name = model.default_prompt_name
    prompt: str | None = None
    if prompt_name is not None:
        if not isinstance(prompt_name, str):
            raise TypeError("invalid prompt name type")
        prompt = model.prompts[prompt_name]

    # follow model.encode logic for setting extra_features
    extra_features: dict[str, Any] = {}
    if prompt is not None:
        tokenized = model.tokenize([prompt])
        if "input_ids" in tokenized:
            extra_features["prompt_length"] = tokenized["input_ids"].shape[-1] - 1

    model.eval()
    if torch.cuda.is_available():
        model = model.half().cuda() if fp16 else model.bfloat16().cuda()
        # TODO: if huggingface datasets exposes an fp16 gpu option, use it here
    elif fp16:
        print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
    model.compile(mode="reduce-overhead")

    def encode_tokens(features: dict[str, Any]) -> npt.NDArray[np.float32]:
        # move the tokenized batch to the model's device without blocking
        features = {
            key: value.to(model.device, non_blocking=True)
            for key, value in features.items()
        } | extra_features
        with torch.no_grad():
            out_features = model.forward(features)
            embeddings = out_features["sentence_embedding"]
            embeddings = embeddings[0]  # the batch holds a single query
            if model.truncate_dim:
                embeddings = embeddings[:model.truncate_dim]
            if normalize:
                embeddings = torch.nn.functional.normalize(embeddings, dim=0)
            return embeddings.cpu().float().numpy()  # faiss expects CPU float32 numpy

    if spaces:
        encode_tokens = spaces.GPU(encode_tokens)

    def encode_string(query: str) -> npt.NDArray[np.float32]:
        if prompt:
            query = prompt + query
        tokens = model.tokenize([query])
        return encode_tokens(tokens)

    def search(query: str) -> str:
        query_embedding = encode_string(query)
        distances, faiss_ids = index.search("embedding", query_embedding, k)
        # faiss pads with id -1 when it finds fewer than k neighbors; drop those
        # before they are used as dataset row indices
        faiss_ids = np.asarray(faiss_ids)
        found = faiss_ids >= 0
        hit_ids = faiss_ids[found].tolist()
        hit_distances = np.asarray(distances)[found].tolist()
        if not hit_ids:
            return "No results"
        openalex_ids = index[hit_ids]["id"]
        works = execute_request(openalex_ids, mailto)
        return format_response(works, hit_distances, calculate_similarity=normalize)

    with gr.Blocks() as demo:
        # figure out the words to describe the quantity
        quantity = _describe_quantity(len(index))

        # display only the last path component, but link the full model name;
        # a bare name like "all-MiniLM-L6-v2" has no publisher to split off
        model_human_name = model_name.rsplit("/", 1)[-1]
        model_link = f"https://huggingface.co/{model_name}"

        gr.Markdown("# abstracts-index")
        gr.Markdown(
            f"Explore {quantity} academic publications selected from the "
            "[OpenAlex](https://openalex.org) dataset (as of January 1st, 2025) with "
            "semantic search, not keyword search. This project is an index of the "
            "embeddings generated from their titles and abstracts. The embeddings were "
            f"generated using the [{model_human_name}]({model_link}) model, and the "
            "index was built using the "
            "[faiss](https://github.com/facebookresearch/faiss) module. The build "
            "scripts and more information available at the main repo "
            "[abstracts-search](https://github.com/colonelwatch/abstracts-search) on "
            "Github."
        )
        query = gr.Textbox(
            lines=1, placeholder="Enter your query here", show_label=False
        )
        btn = gr.Button("Search")
        results = gr.Markdown(
            latex_delimiters=[
                {"left": "$$", "right": "$$", "display": False},
                {"left": "$", "right": "$", "display": False},
            ],
            container=True,
        )
        # NOTE: ZeroGPU doesn't seem to support batching
        query.submit(search, inputs=[query], outputs=[results])
        btn.click(search, inputs=[query], outputs=[results])

    demo.queue()
    demo.launch()
# Launch the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()