# Hugging Face Space status: Running on T4
#!/usr/bin/env python3 | |
import os | |
import torch | |
from torch.utils.data import DataLoader | |
from PIL import Image | |
import numpy as np | |
from typing import cast | |
import asyncio | |
from colpali_engine.models import ColPali, ColPaliProcessor | |
from colpali_engine.utils.torch_utils import get_torch_device | |
from vespa.application import Vespa | |
from vespa.io import VespaQueryResponse | |
from dotenv import load_dotenv | |
from pathlib import Path | |
# Upper bound on how many per-token binary query vectors are sent (as the
# qtb tensor and the rq0..rq{N-1} nearestNeighbor tensors) per query.
MAX_QUERY_TERMS = 64
# Directory for optionally saved result images. Anchored at the directory
# that contains this script: Path(__file__) alone is the script *file*, so
# appending to it would yield "<script>.py/output/images".
SAVEDIR = Path(__file__).parent / "output" / "images"
# Load variables from a .env file (if one is found) into os.environ so that
# main() can read VESPA_APP_URL and VESPA_CLOUD_SECRET_TOKEN from it.
load_dotenv()
def process_queries(processor, queries, image):
    """Prepare a batch of query strings as model inputs.

    Every query is paired with the same ``image`` (repeated once per query)
    and the processor is asked for padded PyTorch tensors.

    Args:
        processor: Callable processor accepting ``images``, ``text``,
            ``return_tensors`` and ``padding`` keyword arguments.
        queries: Sized sequence of query strings.
        image: Image paired with every query.

    Returns:
        Whatever ``processor`` returns for these inputs.
    """
    paired_images = [image] * len(queries)
    return processor(
        images=paired_images,
        text=queries,
        return_tensors="pt",
        padding=True,
    )
def display_query_results(query, response, hits=5):
    """Print a human-readable summary of a Vespa query response.

    Args:
        query: The query text, echoed in the summary header.
        response: A pyvespa query response. ``response.json`` is read for the
            search time and total count; ``response.hits`` for the results.
        hits: Maximum number of hits to print (default 5).
    """
    # "searchtime" is only present when the query asked for
    # presentation.timing; -1 marks "not reported".
    query_time = round(response.json.get("timing", {}).get("searchtime", -1), 2)
    count = response.json.get("root", {}).get("fields", {}).get("totalCount", 0)
    parts = [
        f"Query text: '{query}', query time {query_time}s, count={count}, top results:\n"
    ]
    for rank, hit in enumerate(response.hits[:hits], start=1):
        fields = hit["fields"]
        # Only the fields actually printed are read, so a hit without the
        # (potentially large) "image" field no longer raises KeyError.
        page = fields["page_number"]
        parts.append(f"\nPDF Result {rank}\n")
        # page_number is shown 1-based; assumes the stored value is 0-based
        # — TODO confirm against the feeding code.
        parts.append(
            f"Title: {fields['title']}, page {page + 1} with score {hit['relevance']:.2f}\n"
        )
        parts.append(f"URL: {fields['url']}\n")
        parts.append(f"ID: {hit['id']}\n")
    print("".join(parts))
async def query_vespa_default(app, queries, qs):
    """Run each query with the ``default`` rank profile and print the results.

    Retrieval uses ``userInput(@userQuery)`` on the query text. The float
    query embedding is passed as the ``qt`` query tensor.

    Args:
        app: A pyvespa ``Vespa`` application handle.
        queries: Query strings, in the same order as ``qs``.
        qs: One embedding per query. Each is iterated row by row and every
            row converted with ``.tolist()`` (presumably a (tokens, dim)
            tensor — TODO confirm).

    Raises:
        RuntimeError: If Vespa reports an unsuccessful query.
    """
    async with app.asyncio(connections=1, total_timeout=120) as session:
        for idx, query in enumerate(queries):
            # Row index -> embedding vector, the JSON form sent as qt.
            query_embedding = {k: v.tolist() for k, v in enumerate(qs[idx])}
            response: VespaQueryResponse = await session.query(
                yql="select documentid,title,url,image,page_number from pdf_page where userInput(@userQuery)",
                ranking="default",
                userQuery=query,
                timeout=120,
                hits=3,
                body={"input.query(qt)": query_embedding, "presentation.timing": True},
            )
            # Explicit check instead of `assert` so it survives `python -O`,
            # with the response body included to make failures diagnosable.
            if not response.is_successful():
                raise RuntimeError(f"Vespa query failed for {query!r}: {response.json}")
            display_query_results(query, response)
async def query_vespa_nearest_neighbor(app, queries, qs):
    """Run each query with nearestNeighbor retrieval and print the results.

    For every query, each per-token embedding is binarized (value > 0 -> 1)
    and packed into int8 bytes. At most MAX_QUERY_TERMS binary vectors are
    kept. Each kept vector is sent as its own ``rq{i}`` query tensor and
    searched with a separate nearestNeighbor operator, and the operators are
    OR-ed together. The binary (``qtb``) and float (``qt``) tensors are also
    sent for the MaxSim calculations of the ``retrieval-and-rerank`` profile.

    Args:
        app: A pyvespa ``Vespa`` application handle.
        queries: Query strings, in the same order as ``qs``.
        qs: One embedding per query. Each is iterated row by row and every
            row converted with ``.tolist()`` (presumably a (tokens, dim)
            tensor — TODO confirm).

    Raises:
        ValueError: If a query yields no embedding rows (the YQL would have
            an empty ``where`` clause).
        RuntimeError: If Vespa reports an unsuccessful query.
    """
    # Hyperparameter: targetHits per query-token nearestNeighbor operator;
    # can be tuned for speed versus accuracy.
    target_hits_per_query_tensor = 20
    async with app.asyncio(connections=1, total_timeout=180) as session:
        for idx, query in enumerate(queries):
            float_query_embedding = {k: v.tolist() for k, v in enumerate(qs[idx])}
            binary_query_embeddings = {}
            for k, v in float_query_embedding.items():
                # Check *before* inserting (k runs 0, 1, 2, ...), so exactly
                # MAX_QUERY_TERMS vectors are kept and the warning fires only
                # when a term is actually dropped.
                if k >= MAX_QUERY_TERMS:
                    print(
                        f"Warning: Query has more than {MAX_QUERY_TERMS} terms. Truncating."
                    )
                    break
                # One bit per dimension (positive -> 1), packed 8 per byte.
                binary_query_embeddings[k] = (
                    np.packbits(np.asarray(v) > 0).astype(np.int8).tolist()
                )
            if not binary_query_embeddings:
                raise ValueError(f"Query {query!r} produced no embedding vectors")

            # The mixed tensors used in MaxSim calculations: both binary and
            # float representations.
            query_tensors = {
                "input.query(qtb)": binary_query_embeddings,
                "input.query(qt)": float_query_embedding,
            }
            # One query tensor plus one nearestNeighbor operator per kept term.
            nn_clauses = []
            for i, binary_vector in binary_query_embeddings.items():
                query_tensors[f"input.query(rq{i})"] = binary_vector
                nn_clauses.append(
                    f"({{targetHits:{target_hits_per_query_tensor}}}nearestNeighbor(embedding,rq{i}))"
                )
            # OR the operators so a page matching any query term is retrieved.
            nn = " OR ".join(nn_clauses)
            response: VespaQueryResponse = await session.query(
                body={
                    **query_tensors,
                    "presentation.timing": True,
                    "yql": f"select documentid, title, url, image, page_number from pdf_page where {nn}",
                    "ranking.profile": "retrieval-and-rerank",
                    "timeout": 120,
                    "hits": 3,
                },
            )
            # Explicit check instead of `assert` so it survives `python -O`.
            if not response.is_successful():
                raise RuntimeError(f"Vespa query failed for {query!r}: {response.json}")
            display_query_results(query, response)
def _load_colpali():
    """Load the ColPali model (in eval mode) and its processor.

    Prints the selected torch device. Uses bfloat16 weights when CUDA is
    available, float32 otherwise.

    Returns:
        A ``(model, processor)`` tuple.
    """
    device = get_torch_device("auto")
    print(f"Using device: {device}")
    weight_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    loaded = ColPali.from_pretrained(
        "vidore/colpali-v1.2",
        torch_dtype=weight_dtype,
        device_map=device,
    )
    model = cast(ColPali, loaded).eval()
    processor = cast(
        ColPaliProcessor,
        ColPaliProcessor.from_pretrained("google/paligemma-3b-mix-448"),
    )
    return model, processor


def _embed_queries(model, processor, queries):
    """Encode ``queries`` one at a time and return their embeddings on the CPU.

    Each query is paired with a blank white 448x448 image, because
    process_queries passes an image alongside the text.

    Returns:
        A list with one tensor per query, unbound along the batch dimension.
    """
    blank_page = Image.new("RGB", (448, 448), (255, 255, 255))
    loader = DataLoader(
        queries,
        batch_size=1,
        shuffle=False,
        collate_fn=lambda batch: process_queries(processor, batch, blank_page),
    )
    embeddings = []
    for encoded in loader:
        with torch.no_grad():
            on_device = {name: tensor.to(model.device) for name, tensor in encoded.items()}
            output = model(**on_device)
            embeddings.extend(torch.unbind(output.to("cpu")))
    return embeddings


def main():
    """Embed sample queries with ColPali and run them against Vespa.

    Reads VESPA_APP_URL and VESPA_CLOUD_SECRET_TOKEN from the environment and
    returns early (after printing a message) if the application status check
    does not return HTTP 200. Otherwise it runs the queries twice: first with
    the default rank profile, then with nearestNeighbor retrieval.

    Raises:
        ValueError: If either environment variable is unset or empty.
    """
    app_url = os.environ.get("VESPA_APP_URL")  # Ensure this is set to your Vespa app URL
    secret_token = os.environ.get("VESPA_CLOUD_SECRET_TOKEN")
    if not (app_url and secret_token):
        raise ValueError(
            "Please set the VESPA_APP_URL and VESPA_CLOUD_SECRET_TOKEN environment variables"
        )

    app = Vespa(url=app_url, vespa_cloud_secret_token=secret_token)
    if app.get_application_status().status_code != 200:
        print(f"Failed to connect to Vespa at {app_url}")
        return
    print(f"Connected to Vespa at {app_url}")

    model, processor = _load_colpali()

    queries = [
        "Percentage of non-fresh water as source?",
        "Policies related to nature risk?",
        "How much of produced water is recycled?",
    ]
    qs = _embed_queries(model, processor, queries)

    print("Performing queries using default rank profile:")
    asyncio.run(query_vespa_default(app, queries, qs))

    print("Performing queries using nearestNeighbor:")
    asyncio.run(query_vespa_nearest_neighbor(app, queries, qs))
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()