---
language:
- en
license: mit
library_name: Tevatron
tags:
- vidore
datasets:
- Tevatron/docmatix-ir
- HuggingFaceM4/Docmatix
- Tevatron/msmarco-passage-aug
- vidore/colpali_train_set
- Tevatron/wiki-ss-nq
---

# DSE-Phi35-Vidore-ft

DSE-Phi35-Vidore-ft is a bi-encoder model that encodes document screenshots into dense vectors for document retrieval. The Document Screenshot Embedding ([DSE](https://arxiv.org/abs/2406.11251)) approach works directly on documents in their original visual format, preserving all of their information, including text, images, and layout, and thus avoids tedious parsing and the information loss it can cause.

The model, `Tevatron/dse-phi35-vidore-ft`, is trained on 1/10 of the `Tevatron/docmatix-ir` dataset, a variant of `HuggingFaceM4/Docmatix` adapted specifically for training PDF retrievers with Vision Language Models in open-domain question answering scenarios. For details on dataset filtering and hard negative mining, see the [docmatix-ir](https://huggingface.co/datasets/Tevatron/docmatix-ir/blob/main/README.md) dataset page.

It is then fine-tuned on the [ViDoRe](https://huggingface.co/datasets/vidore/colpali_train_set) training set. Before this training, the checkpoint is warmed up on text retrieval and webpage retrieval.

On the [ViDoRe](https://huggingface.co/spaces/vidore/vidore-leaderboard) leaderboard, DSE-Phi3-Docmatix-V2 achieves **82.9** nDCG@5.
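
nDCG@5 (normalized discounted cumulative gain over the top 5 results) rewards placing relevant pages near the top of the ranking. As a rough illustration of the metric, here is a minimal sketch of nDCG@k for a single query, assuming the standard linear-gain formulation with a log2 rank discount; the ViDoRe leaderboard's own evaluation code may differ in details. The function names `dcg_at_k` and `ndcg_at_k` are hypothetical.

```python
import math
from typing import Dict, List


def dcg_at_k(gains: List[float], k: int) -> float:
    """Discounted cumulative gain over the first k gains (rank 1 is index 0)."""
    return sum(g / math.log2(rank + 2) for rank, g in enumerate(gains[:k]))


def ndcg_at_k(ranked_doc_ids: List[str], qrels: Dict[str, float], k: int = 5) -> float:
    """nDCG@k for one query.

    ranked_doc_ids: retrieved document ids, best first.
    qrels: relevant document id -> graded relevance (e.g. 1 for relevant).
    """
    gains = [qrels.get(doc_id, 0) for doc_id in ranked_doc_ids]
    # Ideal ranking: all judged documents sorted by relevance, best first.
    idcg = dcg_at_k(sorted(qrels.values(), reverse=True), k)
    return dcg_at_k(gains, k) / idcg if idcg > 0 else 0.0


qrels = {"page_7": 1}
print(ndcg_at_k(["page_7", "page_2", "page_9"], qrels))  # 1.0: relevant page ranked first
print(ndcg_at_k(["page_2", "page_7", "page_9"], qrels))  # 1 / log2(3), about 0.631
```

A relevant page that falls outside the top 5 contributes nothing, so the score only reflects the first five results.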

## How to train the model from scratch

See the DSE example in the Tevatron repository: https://github.com/texttron/tevatron/tree/main/examples/dse

## How to Use the Model

### Load the Model and Processor

```python
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

# trust_remote_code=True lets transformers load the model's custom modeling/processing code.
processor = AutoProcessor.from_pretrained('MrLight/dse-phi35-vidore-ft', trust_remote_code=True)
# flash_attention_2 requires the flash-attn package to be installed.
model = AutoModelForCausalLM.from_pretrained('MrLight/dse-phi35-vidore-ft', trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, use_cache=False).to('cuda:0')


def get_embedding(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Last-token pooling: assumes right padding, so the last real token of each
    # sequence sits at index (number of non-padding tokens - 1).
    sequence_lengths = attention_mask.sum(dim=1) - 1
    bs = last_hidden_state.shape[0]
    reps = last_hidden_state[torch.arange(bs, device=last_hidden_state.device), sequence_lengths]
    # L2-normalize so that dot product equals cosine similarity.
    reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps
```
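
`get_embedding` takes the hidden state of the last non-padding token as the embedding (last-token pooling) and L2-normalizes it. The toy check below runs the same idea on hand-made CPU tensors, so the indexing can be verified without loading the model. It also sketches a variant that handles left padding, in case a tokenizer is configured to pad on the left; that branch and the name `last_token_pool` are illustrative assumptions, not part of the original example.

```python
import torch
import torch.nn.functional as F


def last_token_pool(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: last-token pooling followed by L2 normalization.

    hidden: (batch, seq_len, dim) hidden states.
    mask:   (batch, seq_len) attention mask, 1 for real tokens, 0 for padding.
    """
    if bool((mask[:, -1] == 1).all()):
        # Every row ends with a real token (left padding or no padding),
        # so the final position holds the last real token.
        reps = hidden[:, -1]
    else:
        # Right padding: same indexing as get_embedding.
        last = mask.sum(dim=1) - 1
        reps = hidden[torch.arange(hidden.size(0), device=hidden.device), last]
    return F.normalize(reps, p=2, dim=-1)


# Toy batch: 2 sequences, 4 positions, 2-dimensional hidden states.
hidden = torch.tensor([
    [[1.0, 0.0], [0.0, 1.0], [3.0, 4.0], [9.0, 9.0]],
    [[1.0, 1.0], [2.0, 0.0], [0.0, 5.0], [0.0, 2.0]],
])
right_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])  # row 0 has one trailing pad
# Picks [3, 4] and [0, 2], which normalize to [0.6, 0.8] and [0.0, 1.0].
print(last_token_pool(hidden, right_mask))
```

With right padding, using position `-1` instead would pick the padding vector `[9, 9]` for row 0, which is why the index is computed from the mask.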

### Encode Text Query

```python
# Queries use a "query: " prefix and end with "</s>".
queries = ["query: Where can we see Llama?</s>", "query: What is LLaMA model?</s>"]
query_inputs = processor(queries, return_tensors="pt", padding="longest", max_length=128, truncation=True).to('cuda:0')

with torch.no_grad():
    output = model(**query_inputs, return_dict=True, output_hidden_states=True)
query_embeddings = get_embedding(output.hidden_states[-1], query_inputs["attention_mask"])
```
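
The example queries follow a fixed template: a `query: ` prefix and a trailing `</s>`. When encoding your own queries, a small helper can keep that template consistent and split long query lists into batches. This is a sketch; `format_query` and `batched` are hypothetical names, and the batch size is an arbitrary choice.

```python
from typing import Iterator, List


def format_query(text: str) -> str:
    """Wrap raw query text in the template used by the example queries."""
    return f"query: {text}</s>"


def batched(items: List[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


raw_queries = ["Where can we see Llama?", "What is LLaMA model?", "Who released Llama 3.1?"]
for batch in batched([format_query(q) for q in raw_queries], batch_size=2):
    # Each batch could go through processor(...), model(...) and get_embedding(...)
    # as in the block above, with per-batch embeddings joined via torch.cat.
    print(batch)
```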

### Encode Document Screenshot

```python
from io import BytesIO

import requests
from PIL import Image

# URLs of the example screenshots
url1 = "https://huggingface.co/Tevatron/dse-phi3-docmatix-v2/resolve/main/animal-llama.png"
url2 = "https://huggingface.co/Tevatron/dse-phi3-docmatix-v2/resolve/main/meta-llama.png"

# Download the images (failing loudly on HTTP errors) and resize them to 1344x1344
response1 = requests.get(url1)
response1.raise_for_status()
response2 = requests.get(url2)
response2.raise_for_status()
passage_image1 = Image.open(BytesIO(response1.content)).resize((1344, 1344))
passage_image2 = Image.open(BytesIO(response2.content)).resize((1344, 1344))

passage_images = [passage_image1, passage_image2]
# <|image_1|> refers to the first image in passage_images, <|image_2|> to the second.
passage_prompts = ["<|image_1|>\nWhat is shown in this image?</s>", "<|image_2|>\nWhat is shown in this image?</s>"]

# Process inputs and compute embeddings
passage_inputs = processor(passage_prompts, images=passage_images, return_tensors="pt", padding="longest", max_length=4096, truncation=True).to('cuda:0')
# squeeze(0) drops a leading dimension of size 1 if present; otherwise it is a no-op.
passage_inputs['input_ids'] = passage_inputs['input_ids'].squeeze(0)
passage_inputs['attention_mask'] = passage_inputs['attention_mask'].squeeze(0)
passage_inputs['image_sizes'] = passage_inputs['image_sizes'].squeeze(0)

with torch.no_grad():
    output = model(**passage_inputs, return_dict=True, output_hidden_states=True)
doc_embeddings = get_embedding(output.hidden_states[-1], passage_inputs["attention_mask"])
```
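
In the prompts above, the placeholder `<|image_N|>` points at the N-th image passed in `images`, counting from 1. For more than two screenshots, the prompt list can be generated so that the numbering always lines up with the image list. A minimal sketch; `build_passage_prompts` and `PASSAGE_TEMPLATE` are hypothetical names, and the template simply reproduces the prompt text from the example.

```python
from typing import List

# Same prompt text as the example, with the image index as a parameter.
PASSAGE_TEMPLATE = "<|image_{idx}|>\nWhat is shown in this image?</s>"


def build_passage_prompts(num_images: int) -> List[str]:
    """Return one prompt per image, numbered 1..num_images in list order."""
    return [PASSAGE_TEMPLATE.format(idx=i) for i in range(1, num_images + 1)]


prompts = build_passage_prompts(2)
print(prompts)
```

In the screenshot example, this would be used as `passage_prompts = build_passage_prompts(len(passage_images))`.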

### Compute Similarity

```python
from torch.nn.functional import cosine_similarity

num_queries = query_embeddings.size(0)
num_passages = doc_embeddings.size(0)

# Score each query against every document embedding
for i in range(num_queries):
    query_embedding = query_embeddings[i].unsqueeze(0)
    similarities = cosine_similarity(query_embedding, doc_embeddings)
    print(f"Similarities for Query {i+1}: {similarities.cpu().float().numpy()}")
```
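
Because `get_embedding` L2-normalizes its output, cosine similarity reduces to a dot product, so all query-document scores can be computed with one matrix multiplication and the best documents selected with `torch.topk`. The sketch below shows this on small random CPU tensors standing in for `query_embeddings` and `doc_embeddings`; `rank_documents` is a hypothetical name and `k` an illustrative parameter.

```python
import torch
import torch.nn.functional as F


def rank_documents(query_embs: torch.Tensor, doc_embs: torch.Tensor, k: int):
    """Return (scores, indices) of the top-k documents for each query.

    Both inputs are assumed to be L2-normalized, so the dot product
    equals cosine similarity.
    """
    scores = query_embs @ doc_embs.T  # (num_queries, num_docs)
    k = min(k, doc_embs.size(0))
    return torch.topk(scores, k=k, dim=1)


torch.manual_seed(0)
queries = F.normalize(torch.randn(2, 8), dim=-1)  # stand-in for query_embeddings
docs = F.normalize(torch.randn(5, 8), dim=-1)     # stand-in for doc_embeddings
top_scores, top_idx = rank_documents(queries, docs, k=3)
print(top_idx)  # indices of the 3 best-scoring documents for each query
```

With the real bfloat16 embeddings, calling `.float()` on both tensors first keeps the scores in full precision.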

### Encode Document Text

This DSE checkpoint is warmed up with `Tevatron/msmarco-passage-aug`, so the model can also encode documents given as plain text.

```python
# Reuses processor, model, get_embedding, query_embeddings, num_queries and
# cosine_similarity from the sections above. Text passages have no prefix and end with "</s>".
passage_prompts = [
    "The llama (/ˈlɑːmə/; Spanish pronunciation: [ˈʎama] or [ˈʝama]) (Lama glama) is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era.</s>",
    "Llama (acronym for Large Language Model Meta AI, and formerly stylized as LLaMA) is a family of autoregressive large language models (LLMs) released by Meta AI starting in February 2023.[2][3] The latest version is Llama 3.1, released in July 2024.[4]</s>",
]

# No images: the processor only tokenizes the text passages.
passage_inputs = processor(passage_prompts, images=None, return_tensors="pt", padding="longest", max_length=4096, truncation=True).to('cuda:0')

with torch.no_grad():
    output = model(**passage_inputs, return_dict=True, output_hidden_states=True)
doc_embeddings = get_embedding(output.hidden_states[-1], passage_inputs["attention_mask"])

for i in range(num_queries):
    query_embedding = query_embeddings[i].unsqueeze(0)
    similarities = cosine_similarity(query_embedding, doc_embeddings)
    print(f"Similarities for Query {i+1}: {similarities.cpu().float().numpy()}")
```

## Citation

If you find this checkpoint helpful, please consider citing Phi-3, Docmatix, ViDoRe, and our DSE work.