# abstracts-index / app.py
# Last commit by colonelwatch (303669c):
#   "Minimize ZeroGPU utilization time by cutting out SentenceTransformer overhead"
# app.py
# Loads all completed shards and finds the most similar vector to a given query vector.
from dataclasses import dataclass
from itertools import chain
import json
import os
from math import log10
from pathlib import Path
from sys import stderr
from typing import TypedDict, TypeVar, Any, Callable
from datasets import Dataset
from datasets.search import FaissIndex
import faiss
from huggingface_hub import snapshot_download
import numpy as np
import numpy.typing as npt
import gradio as gr
import requests
from sentence_transformers import SentenceTransformer
import torch
try:
import spaces
except ImportError:
spaces = None
# Generic placeholders for get_env_var: T is the converted value's type,
# U is the type of the fallback default.
T, U = TypeVar("T"), TypeVar("U")
# One benchmarked faiss search configuration (an entry of params.json's
# "optimal_params" list).
IndexParameters = TypedDict(
    "IndexParameters",
    {
        "recall": float,  # in this case 10-recall@10
        "exec_time": float,  # seconds (raw faiss measure is in milliseconds)
        "param_string": str,  # pass directly to faiss index
    },
)
# Schema of the params.json file stored alongside the index.
Params = TypedDict(
    "Params",
    {
        "dimensions": int | None,  # passed to SentenceTransformer as truncate_dim
        "normalize": bool,  # whether query embeddings are normalized before search
        "optimal_params": list[IndexParameters],  # candidate faiss settings
    },
)
@dataclass
class Work:
title: str | None
abstract: str | None # recovered from abstract_inverted_index
authors: list[str] # takes raw_author_name field from Authorship objects
journal_name: str | None # takes the display_name field of the first location
year: int
citations: int
doi: str | None
def __post_init__(self):
self._check_type(self.title, str, nullable=True)
self._check_type(self.abstract, str, nullable=True)
self._check_type(self.authors, list)
for author in self.authors:
self._check_type(author, str)
self._check_type(self.journal_name, str, nullable=True)
self._check_type(self.year, int)
self._check_type(self.citations, int)
self._check_type(self.doi, str, nullable=True)
@classmethod
def from_dict(cls, d: dict) -> "Work":
inverted_index: None | dict[str, list[int]] = d["abstract_inverted_index"]
abstract = cls._recover_abstract(inverted_index) if inverted_index else None
try:
journal_name = d["primary_location"]["source"]["display_name"]
except (TypeError, KeyError): # key didn't exist or a value was null
journal_name = None
return cls(
title=d["title"],
abstract=abstract,
authors=[authorship["raw_author_name"] for authorship in d["authorships"]],
journal_name=journal_name,
year=d["publication_year"],
citations=d["cited_by_count"],
doi=d["doi"],
)
@staticmethod
def get_raw_fields() -> list[str]:
return [
"title",
"abstract_inverted_index",
"authorships",
"primary_location",
"publication_year",
"cited_by_count",
"doi"
]
@staticmethod
def _check_type(v: Any, t: type, nullable: bool = False):
if not ((nullable and v is None) or isinstance(v, t)):
v_type_name = f"{type(v)}" if v is not None else "None"
t_name = f"{t}"
if nullable:
t_name += " | None"
raise ValueError(f"expected {t_name}, got {v_type_name}")
@staticmethod
def _recover_abstract(inverted_index: dict[str, list[int]]) -> str:
abstract_size = max(max(locs) for locs in inverted_index.values())+1
abstract_words: list[str | None] = [None] * abstract_size
for word, locs in inverted_index.items():
for loc in locs:
abstract_words[loc] = word
return " ".join(word for word in abstract_words if word is not None)
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "t", "yes", "y", "on"})
_FALSY_ENV_VALUES = frozenset({"0", "false", "f", "no", "n", "off", ""})


def _parse_env_bool(key: str, raw: str) -> bool:
    """Interpret ``raw``, the string value of environment variable ``key``, as a flag."""
    normalized = raw.strip().lower()
    if normalized in _TRUTHY_ENV_VALUES:
        return True
    if normalized in _FALSY_ENV_VALUES:
        return False
    raise ValueError(f"environment variable {key}={raw!r} is not a recognized boolean")


def get_env_var(key: str, type_: Callable[[str], T] = str, default: U = None) -> T | U:
    """Read environment variable ``key``, converted with ``type_``.

    Returns ``default`` when the variable is unset. For ``type_=bool`` the
    value is parsed as a flag ("1"/"true"/"yes"/"on" vs "0"/"false"/"no"/
    "off"/"", case-insensitive). Plain ``bool(s)`` is True for every non-empty
    string, so it would turn e.g. ``FP16=false`` into True.

    Raises:
        ValueError: an unrecognized boolean flag, or whatever ``type_`` raises
            for an unconvertible value.
    """
    raw = os.getenv(key)
    if raw is None:
        return default
    if type_ is bool:
        return _parse_env_bool(key, raw)  # type: ignore[return-value]
    return type_(raw)
def get_model(
    model_name: str, params_dir: Path, trust_remote_code: bool
) -> tuple[bool, SentenceTransformer]:
    """Load the query-embedding model and report whether to normalize its output.

    Reads ``params_dir / "params.json"`` and returns ``(normalize, model)``:
    the file's "normalize" entry, and a SentenceTransformer created with
    ``truncate_dim`` set to the file's "dimensions" entry.
    """
    # TODO: params["normalize"] for models like all-MiniLM-v6, which already normalize?
    params_path = params_dir / "params.json"
    params: Params = json.loads(params_path.read_text())
    should_normalize = params["normalize"]
    model = SentenceTransformer(
        model_name,
        trust_remote_code=trust_remote_code,
        truncate_dim=params["dimensions"],
    )
    return should_normalize, model
def open_ondisk(dir: Path) -> faiss.Index:
    """Read ``dir / "index.faiss"``, resolving its on-disk data relative to ``dir``."""
    index_path = str(dir / "index.faiss")
    # without IO_FLAG_ONDISK_SAME_DIR, read_index looks for the on-disk index
    # data in the working directory instead of next to index.faiss
    io_flags = faiss.IO_FLAG_ONDISK_SAME_DIR
    return faiss.read_index(index_path, io_flags)
def get_index(dir: Path, search_time_s: float) -> Dataset:
    """Load the searchable index stored in ``dir`` and tune it for a time budget.

    ``dir`` must hold ``ids.parquet``, ``index.faiss`` and ``params.json``.
    Among ``params["optimal_params"]``, the configuration with the highest
    recall whose ``exec_time`` is strictly below ``search_time_s`` is applied
    to the faiss index. If no configuration is fast enough, the fastest one is
    applied instead and a warning is printed to stderr.

    Returns:
        The ids dataset with the faiss index attached under the name "embedding".

    Raises:
        ValueError: ``params.json`` lists no candidate configurations.
    """
    # read the (cheap) parameters first so a bad params.json fails fast
    with open(dir / "params.json", "r") as f:
        params: Params = json.load(f)
    candidates = params["optimal_params"]
    if not candidates:
        raise ValueError(f'no "optimal_params" listed in {dir / "params.json"}')

    under = [p for p in candidates if p["exec_time"] < search_time_s]
    if under:
        optimal = max(under, key=lambda p: p["recall"])
    else:
        # nothing meets the budget: degrade to the fastest setting instead of crashing
        optimal = min(candidates, key=lambda p: p["exec_time"])
        print(
            f"warning: no index parameters run under {search_time_s} s, "
            f'using the fastest ({optimal["exec_time"]} s)...',
            file=stderr,
        )

    # NOTE: use a private attr to load the index with IO_FLAG_ONDISK_SAME_DIR!
    index: Dataset = Dataset.from_parquet(str(dir / "ids.parquet"))  # type: ignore
    faiss_index = open_ondisk(dir)
    index._indexes["embedding"] = FaissIndex(None, None, None, faiss_index)

    ps = faiss.ParameterSpace()
    ps.initialize(faiss_index)
    ps.set_index_parameters(faiss_index, optimal["param_string"])
    return index
def execute_request(
    ids: list[str], mailto: str | None, timeout: float = 30.0
) -> list[Work]:
    """Fetch works from the OpenAlex ``/works`` endpoint, in the order of ``ids``.

    Args:
        ids: OpenAlex work IDs to look up; at most 100 (one page of results).
        mailto: contact address sent as OpenAlex's ``mailto`` parameter, or None.
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        One Work per entry of ``ids``, in the same order.

    Raises:
        ValueError: more than 100 IDs were given.
        requests.HTTPError / requests.Timeout: the HTTP request failed.
        KeyError: the response did not contain one of the requested IDs.
    """
    if len(ids) > 100:
        raise ValueError("querying /works endpoint with more than 100 works")
    if not ids:
        return []  # nothing to fetch; skip the network round-trip

    # query with the /works endpoint with a specific list of IDs and fields
    search_filter = f'openalex_id:{"|".join(ids)}'
    search_select = ",".join(["id"] + Work.get_raw_fields())
    params: dict[str, str | int] = {
        "filter": search_filter,
        "select": search_select,
        "per-page": 100,
    }
    if mailto is not None:
        params["mailto"] = mailto

    # without a timeout a stalled connection would block this request forever
    response = requests.get(
        "https://api.openalex.org/works", params=params, timeout=timeout
    )
    response.raise_for_status()

    # the response is not necessarily ordered, so order them
    works_by_id = {d["id"]: Work.from_dict(d) for d in response.json()["results"]}
    missing = [id_ for id_ in ids if id_ not in works_by_id]
    if missing:
        raise KeyError(f"OpenAlex returned no result for: {', '.join(missing)}")
    return [works_by_id[id_] for id_ in ids]
def collapse_newlines(x: str) -> str:
    """Replace every line break in ``x`` with a single space.

    CRLF pairs are replaced first so each becomes one space rather than two;
    lone LF and CR characters are replaced afterwards.
    """
    for line_break in ("\r\n", "\n", "\r"):
        x = x.replace(line_break, " ")
    return x
def format_response(
    neighbors: list[Work], distances: list[float], calculate_similarity: bool = False
) -> str:
    """Render search hits as Markdown, one section per work.

    Each section has:
    - a heading: the title, linked to the DOI when one exists;
    - a bold byline: authors, truncated to three with ", ..." only when there
      are more than three, then the year and the journal;
    - the abstract, with newlines collapsed and cut at 2000 characters;
    - an italic metadata line.

    Args:
        neighbors: works to render.
        distances: faiss distance for each work, paired with ``neighbors`` by
            position; faiss reports squared distances.
        calculate_similarity: show cosine similarity derived from the distance
            (only meaningful for unit vectors) instead of the raw distance.

    Returns:
        The concatenated Markdown ("" when ``neighbors`` is empty).
    """
    max_authors = 3
    max_abstract_len = 2000
    meta_separator = "&nbsp;" * 4

    entries: list[str] = []
    for work, distance in zip(neighbors, distances):
        title = collapse_newlines(work.title) if work.title else None
        if title and work.doi:
            heading = f"[{title}]({work.doi})"
        elif title:
            heading = title
        elif work.doi:
            heading = f"[No title]({work.doi})"
        else:
            heading = "No title"

        # only mark truncation when authors were actually dropped
        if len(work.authors) > max_authors:
            byline = ", ".join(work.authors[:max_authors]) + ", ..."
        elif work.authors:
            byline = ", ".join(work.authors)
        else:
            byline = "No author"
        byline += f", {work.year}"
        if work.journal_name:
            byline += " - " + work.journal_name

        if work.abstract:
            abstract = collapse_newlines(work.abstract)
            if len(abstract) > max_abstract_len:
                abstract = abstract[:max_abstract_len] + "..."
        else:
            abstract = "No abstract"

        meta: list[tuple[str, str]] = []
        if work.citations:  # don't tack "Cited-by count: 0" on someone's work
            meta.append(("Cited-by count", str(work.citations)))
        if work.doi:
            meta.append(("DOI", work.doi.replace("https://doi.org/", "")))
        if calculate_similarity:
            # for unit vectors cosine sim = 1 - dist^2 / 2, and faiss already gives dist^2
            meta.append(("Similarity", f"{1 - distance / 2:.2f}"))
        else:
            meta.append(("Distance", f"{distance:.2f}"))
        meta_line = meta_separator.join(": ".join(pair) for pair in meta)

        entries.append(f"## {heading}\n\n**{byline}**\n\n{abstract}\n\n*{meta_line}*\n")
    return "".join(entries)
def _resolve_prompt(model: Any, prompt_name: str | None) -> str | None:
    """Return the prompt text to prepend to queries, or None when there is none.

    ``model`` is the loaded SentenceTransformer. Follows model.encode logic:
    an explicit ``prompt_name`` wins, otherwise ``model.default_prompt_name``
    is used, and having neither simply means "no prompt".
    """
    if prompt_name is None:
        prompt_name = model.default_prompt_name
    if prompt_name is None:
        return None
    if not isinstance(prompt_name, str):
        raise TypeError("invalid prompt name type")
    try:
        return model.prompts[prompt_name]
    except KeyError:
        raise ValueError(
            f"prompt {prompt_name!r} not found; available prompts: {sorted(model.prompts)}"
        ) from None


def _describe_quantity(n_entries: int) -> str:
    """Describe a positive count in words, e.g. 12_345 -> "12.3 thousand"."""
    n_digits = int(log10(n_entries))
    divisor, postfix = {
        0: (1, ""),
        1: (1000, " thousand"),
        2: (1000000, " million"),
        3: (1000000000, " billion"),
    }[n_digits // 3]
    significand = n_entries / divisor
    significand = round(significand, 1 if (n_digits % 3 == 1) else None)
    return str(significand) + postfix


def _model_link(model_name: str) -> tuple[str, str]:
    """Split a Hugging Face model name into (display name, model page URL).

    Bare names without a publisher (like the default "all-MiniLM-L6-v2") are
    assumed to live under "sentence-transformers", where SentenceTransformer
    presumably resolves them -- verify against the installed version.
    """
    publisher, _, human_name = model_name.rpartition("/")
    if not publisher:
        publisher = "sentence-transformers"
    return human_name, f"https://huggingface.co/{publisher}/{human_name}"


def main():
    """Load the model and index, then serve the Gradio search UI.

    Configured through the environment variables MODEL_NAME, PROMPT_NAME,
    TRUST_REMOTE_CODE, FP16, DIR, REPO, SEARCH_TIME_S, K and MAILTO. When DIR
    is unset, the index is downloaded from the Hugging Face dataset REPO.
    """
    # TODO: figure out some better defaults?
    model_name = get_env_var("MODEL_NAME", default="all-MiniLM-L6-v2")
    prompt_name = get_env_var("PROMPT_NAME")
    trust_remote_code = get_env_var("TRUST_REMOTE_CODE", bool, default=False)
    fp16 = get_env_var("FP16", bool, default=False)
    dir = get_env_var("DIR", Path)
    repo = get_env_var("REPO", str)
    search_time_s = get_env_var("SEARCH_TIME_S", float, default=1)
    k = get_env_var("K", int, default=20)  # TODO: can't go higher than 20 yet
    mailto = get_env_var("MAILTO", str, None)

    if dir is None:  # acquire the index if it's not local
        if repo is None:
            repo = "colonelwatch/abstracts-faiss"
        dir = Path(snapshot_download(repo, repo_type="dataset")) / "index"
    elif repo is not None:
        print('warning: used "REPO" and also "DIR", ignoring "REPO"...', file=stderr)

    normalize, model = get_model(model_name, dir, trust_remote_code)
    index = get_index(dir, search_time_s)

    prompt = _resolve_prompt(model, prompt_name)

    # follow model.encode logic for setting extra_features
    extra_features: dict[str, Any] = {}
    if prompt is not None:
        tokenized = model.tokenize([prompt])
        if "input_ids" in tokenized:
            extra_features["prompt_length"] = tokenized["input_ids"].shape[-1] - 1

    model.eval()
    if torch.cuda.is_available():
        model = model.half().cuda() if fp16 else model.bfloat16().cuda()
        # TODO: if huggingface datasets exposes an fp16 gpu option, use it here
    elif fp16:
        print('warning: used "FP16" on CPU-only system, ignoring...', file=stderr)
    model.compile(mode="reduce-overhead")

    def encode_tokens(features: dict[str, Any]) -> npt.NDArray[np.float32]:
        # non-blocking transfer of the tokenized batch to the model's device
        features = {
            name: tensor.to(model.device, non_blocking=True)
            for name, tensor in features.items()
        } | extra_features
        with torch.no_grad():
            out_features = model.forward(features)
            embeddings = out_features["sentence_embedding"][0]  # single-query batch
            if model.truncate_dim:
                embeddings = embeddings[:model.truncate_dim]
            if normalize:
                embeddings = torch.nn.functional.normalize(embeddings, dim=0)
        return embeddings.cpu().float().numpy()  # faiss expects CPU float32 numpy arr

    if spaces:
        encode_tokens = spaces.GPU(encode_tokens)

    def encode_string(query: str) -> npt.NDArray[np.float32]:
        if prompt:
            query = prompt + query
        tokens = model.tokenize([query])
        return encode_tokens(tokens)

    def search(query: str) -> str:
        query_embedding = encode_string(query)
        distances, faiss_ids = index.search("embedding", query_embedding, k)
        # faiss reports id -1 for result slots it could not fill; looking those up
        # would (via negative indexing) presumably return an unrelated row
        faiss_ids = np.asarray(faiss_ids)
        found = faiss_ids >= 0
        distances, faiss_ids = np.asarray(distances)[found], faiss_ids[found]
        if len(faiss_ids) == 0:
            return "No results found."
        openalex_ids = index[faiss_ids]["id"]
        works = execute_request(openalex_ids, mailto)
        return format_response(works, distances, calculate_similarity=normalize)

    # figure out the words to describe the quantity, and the model's page
    quantity = _describe_quantity(len(index))
    model_human_name, model_link = _model_link(model_name)

    with gr.Blocks() as demo:
        gr.Markdown("# abstracts-index")
        gr.Markdown(
            f"Explore {quantity} academic publications selected from the "
            "[OpenAlex](https://openalex.org) dataset (as of January 1st, 2025) with "
            "semantic search, not keyword search. This project is an index of the "
            "embeddings generated from their titles and abstracts. The embeddings were "
            f"generated using the [{model_human_name}]({model_link}) model, and the "
            "index was built using the "
            "[faiss](https://github.com/facebookresearch/faiss) module. The build "
            "scripts and more information available at the main repo "
            "[abstracts-search](https://github.com/colonelwatch/abstracts-search) on "
            "Github."
        )
        query = gr.Textbox(
            lines=1, placeholder="Enter your query here", show_label=False
        )
        btn = gr.Button("Search")
        results = gr.Markdown(
            latex_delimiters=[
                {"left": "$$", "right": "$$", "display": False},
                {"left": "$", "right": "$", "display": False},
            ],
            container=True,
        )
        # NOTE: ZeroGPU doesn't seem to support batching
        query.submit(search, inputs=[query], outputs=[results])
        btn.click(search, inputs=[query], outputs=[results])

    demo.queue()
    demo.launch()
if __name__ == "__main__":
    # start the app only when run as a script, not when this module is imported
    main()