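"""Streamlit demo: retrieval-augmented QA over arXiv PDFs with ColPali + Gemini.

Each PDF page is embedded as a set of patch-level vectors with ColPali (a
PaliGemma backbone with a ColBERT-style retrieval adapter). Queries are scored
against pages with MaxSim late interaction, and the top pages are sent as
images to Gemini 1.5 to answer the query.

Run with `streamlit run <this_file>.py`; requires poppler (for pdf2image) and
`hf_access_token` / `google_genai_api_key` in Streamlit secrets.
"""
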
import os
import tempfile

import google.generativeai as genai
import numpy as np
import pdf2image
import requests
import streamlit as st
import torch
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.utils.colpali_processing_utils import (
    process_images,
    process_queries,
)
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoProcessor

# silence the Hugging Face tokenizers parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# short alias for Streamlit's session state
SS = st.session_state

def initialize_session_state():
    # session-state keys used by the app; None marks "not yet loaded"
    # (processor, page_embeddings, and query_embeddings are included here
    # because the UI reads them before the corresponding buttons may have run)
    keys = [
        "colpali_model",
        "processor",
        "page_images",
        "page_embeddings",
        "query_embeddings",
        "retrieved_page_images",
        "response",
    ]
    for key in keys:
        if key not in SS:
            SS[key] = None
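
# pick the best available accelerator: CUDA, then Apple MPS, then CPU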
def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    return device
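
# bfloat16 keeps memory low on CUDA; MPS and CPU fall back to float32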
def get_dtype(device: torch.device):
    if device == torch.device("cuda"):
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    return dtype
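
# load the PaliGemma backbone (gated on Hugging Face, hence the access token)
# and attach the ColPali retrieval adapter weights on top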
def load_colpali_model():
    paligemma_model_name = "google/paligemma-3b-mix-448"
    colpali_model_name = "vidore/colpali"
    device = get_device()
    dtype = get_dtype(device)
    model = ColPali.from_pretrained(
        paligemma_model_name,
        torch_dtype=dtype,
        token=st.secrets["hf_access_token"],
    ).eval()
    model.load_adapter(colpali_model_name)
    model.to(device)
    processor = AutoProcessor.from_pretrained(colpali_model_name)
    return model, processor
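
# embed pages: each page image becomes a matrix of patch-level vectors
# (multi-vector embeddings for late-interaction retrieval)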
def embed_page_images(model, processor, page_images, batch_size=2):
    dataloader = DataLoader(
        page_images,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda x: process_images(processor, x),
    )
    page_embeddings = []
    for batch in dataloader:
        with torch.no_grad():
            batch = {k: v.to(model.device) for k, v in batch.items()}
            embeddings = model(**batch)
        # cast to float32 before the numpy conversion (numpy has no bfloat16)
        page_embeddings.extend(list(torch.unbind(embeddings.float().cpu())))
    return np.array(page_embeddings)
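
# embed query text into per-token vectors; the processor expects an image
# input, so a blank placeholder image is passed alongside each query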
def embed_query_texts(model, processor, query_texts, batch_size=1):
    # 448x448 matches the resolution of the PaliGemma checkpoint loaded above
    dummy_image = Image.new("RGB", (448, 448), (255, 255, 255))
    dataloader = DataLoader(
        query_texts,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda x: process_queries(processor, x, dummy_image),
    )
    query_embeddings = []
    for batch in dataloader:
        with torch.no_grad():
            batch = {k: v.to(model.device) for k, v in batch.items()}
            embeddings = model(**batch)
        query_embeddings.extend(list(torch.unbind(embeddings.float().cpu())))
    # the app passes a single query, so return just its embedding matrix
    return np.array(query_embeddings)[0]
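
# rasterize PDF pages to PIL images with pdf2image (requires poppler)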
def get_pdf_page_images_from_bytes(
    pdf_bytes: bytes,
    use_tmp_dir=False,
):
    if use_tmp_dir:
        with tempfile.TemporaryDirectory() as tmp_path:
            page_images = pdf2image.convert_from_bytes(pdf_bytes, output_folder=tmp_path)
            # force PIL to read the pixel data before the temporary
            # directory (and the image files in it) is deleted
            for page_image in page_images:
                page_image.load()
    else:
        page_images = pdf2image.convert_from_bytes(pdf_bytes)
    return page_images
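
# download a PDF; returns None (and logs the failure) on a non-200 response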
def get_pdf_bytes_from_url(url: str) -> bytes | None:
    response = requests.get(url, timeout=30)
    if response.status_code == 200:
        return response.content
    else:
        print(f"failed to fetch {url}")
        print(response)
        return None
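
# lay out page images in a grid with a user-adjustable number of columns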
def display_pages(page_images, key):
    n_cols = st.slider("ncol", min_value=1, max_value=8, value=4, step=1, key=key)
    cols = st.columns(n_cols)
    for ii_page, page_image in enumerate(page_images):
        ii_col = ii_page % n_cols
        with cols[ii_col]:
            st.image(page_image)
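
# ---- app flow: load model, fetch PDF, embed pages, retrieve, answer ----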
initialize_session_state()

# load the (large) model once and keep it in session state across reruns
if SS["colpali_model"] is None:
    SS["colpali_model"], SS["processor"] = load_colpali_model()

with st.sidebar:
    url = st.text_input("arxiv url", "https://arxiv.org/pdf/2112.01488.pdf")
    if st.button("load paper"):
        pdf_bytes = get_pdf_bytes_from_url(url)
        # skip conversion if the download failed
        if pdf_bytes is not None:
            SS["page_images"] = get_pdf_page_images_from_bytes(pdf_bytes)
    if st.button("embed pages"):
        SS["page_embeddings"] = embed_page_images(
            SS["colpali_model"],
            SS["processor"],
            SS["page_images"],
        )

with st.container(border=True):
    query = st.text_area("query")
    top_k = st.slider("num pages to retrieve", min_value=1, max_value=8, value=3, step=1)
    if st.button("answer query"):
        # guard against querying before pages have been loaded and embedded
        if SS["page_embeddings"] is None:
            st.warning("load a paper and embed its pages first")
            st.stop()
        SS["query_embeddings"] = embed_query_texts(
            SS["colpali_model"],
            SS["processor"],
            [query],
        )
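        # rank pages with ColBERT-style late interaction (MaxSim):
        # score(page, query) = sum over query tokens of the max dot product
        # between that token and any page patch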
        page_query_scores = []
        for ipage in range(len(SS["page_embeddings"])):
            # (n_patches, n_query_tokens) similarity matrix for this page
            patch_query_scores = np.dot(
                SS["page_embeddings"][ipage],
                SS["query_embeddings"].T,
            )
            # max over patches (axis 0) for each query token, summed over tokens
            max_sim_score = patch_query_scores.max(axis=0).sum()
            page_query_scores.append(max_sim_score)
        page_query_scores = np.array(page_query_scores)
        i_ranked_pages = np.argsort(-page_query_scores)
        # slicing also guards against top_k exceeding the page count
        page_images = [SS["page_images"][i] for i in i_ranked_pages[:top_k]]
        SS["retrieved_page_images"] = page_images
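        # multimodal prompt: query text plus instructions, followed by the
        # retrieved page images (Gemini accepts interleaved text and PIL images)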
        prompt = [
            query +
            " Think through your answer step by step. "
            "Support your answer with descriptions of the images. "
            "Do not infer information that is not in the images.",
        ] + page_images
        genai.configure(api_key=st.secrets["google_genai_api_key"])
        # genai_model_name = "gemini-1.5-flash"
        genai_model_name = "gemini-1.5-pro"
        gen_model = genai.GenerativeModel(
            model_name=genai_model_name,
            generation_config=genai.GenerationConfig(
                temperature=0.1,
            ),
        )
        response = gen_model.generate_content(prompt)
        text = response.candidates[0].content.parts[0].text
        SS["response"] = text

if SS["response"] is not None:
    st.write(SS["response"])
    st.header("Retrieved Pages")
    display_pages(SS["retrieved_page_images"], "retrieved_pages")

if SS["page_images"] is not None:
    st.header("All PDF Pages")
    display_pages(SS["page_images"], "all_pages")