---
language:
- en
license: mit
library_name: Tevatron
tags:
- vidore
datasets:
- Tevatron/docmatix-ir
- HuggingFaceM4/Docmatix
- Tevatron/msmarco-passage-aug
- vidore/colpali_train_set
- Tevatron/wiki-ss-nq
---

# DSE-Phi35-Vidore-ft

DSE-Phi35-Vidore-ft is a bi-encoder model designed to encode document screenshots into dense vectors for document retrieval. The Document Screenshot Embedding ([DSE](https://arxiv.org/abs/2406.11251)) approach captures documents in their original visual format, preserving all information such as text, images, and layout, thus avoiding tedious parsing and potential information loss.

The model is trained using 1/10 of the `Tevatron/docmatix-ir` dataset, a variant of `HuggingFaceM4/Docmatix` specifically adapted for training PDF retrievers with Vision Language Models in open-domain question answering scenarios, and further fine-tuned on the ViDoRe training set (`vidore/colpali_train_set`). For more information on dataset filtering and hard negative mining, refer to the [docmatix-ir](https://huggingface.co/datasets/Tevatron/docmatix-ir/blob/main/README.md) dataset page.

DSE has strong zero-shot effectiveness for document retrieval with both visual and text input.
For example, the earlier DSE-Phi3-Docmatix-V2 checkpoint achieves **82.9** nDCG@5 on the [ViDoRE](https://huggingface.co/spaces/vidore/vidore-leaderboard) leaderboard in a **zero-shot setting** (without fine-tuning on the ViDoRe training data).

## How to Train the Model from Scratch

Please see the DSE example in the Tevatron repository: https://github.com/texttron/tevatron/tree/main/examples/dse

## How to Use the Model

### Load the Model and Processor

```python
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

processor = AutoProcessor.from_pretrained('MrLight/dse-phi35-vidore-ft', trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained('MrLight/dse-phi35-vidore-ft', trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, use_cache=False).to('cuda:0')

def get_embedding(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Pool each sequence into a single vector: take the hidden state of the
    # last non-padded token (the trailing </s> in the examples below), then
    # L2-normalize so that dot product equals cosine similarity.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    bs = last_hidden_state.shape[0]
    reps = last_hidden_state[torch.arange(bs, device=last_hidden_state.device), sequence_lengths]
    reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps
```
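
To see what `get_embedding` does, here is a small illustrative check on dummy tensors (not part of the original card): the pooled embedding for each sequence is the normalized hidden state of its last attended token.

```python
# Illustrative only: dummy tensors standing in for model outputs.
dummy_hidden = torch.randn(2, 5, 8)           # (batch, seq_len, hidden)
dummy_mask = torch.tensor([[1, 1, 1, 0, 0],   # first sequence is right-padded
                           [1, 1, 1, 1, 1]])  # second sequence is unpadded
emb = get_embedding(dummy_hidden, dummy_mask)
print(emb.shape)          # torch.Size([2, 8])
print(emb.norm(dim=-1))   # ~1.0 for each row, due to L2 normalization
```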

### Encode Text Query

```python
# Queries use the "query: " prefix and end with the </s> token.
queries = ["query: Where can we see Llama?</s>", "query: What is LLaMA model?</s>"]
query_inputs = processor(queries, return_tensors="pt", padding="longest", max_length=128, truncation=True).to('cuda:0')
with torch.no_grad():
    output = model(**query_inputs, return_dict=True, output_hidden_states=True)
query_embeddings = get_embedding(output.hidden_states[-1], query_inputs["attention_mask"])
```
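
Queries must follow the `query: ...</s>` pattern shown above; a tiny hypothetical helper (not part of the original card) makes that harder to get wrong:

```python
# Hypothetical convenience helper; prefix and suffix match the example above.
def format_query(text: str) -> str:
    return f"query: {text}</s>"

queries = [format_query("Where can we see Llama?"), format_query("What is LLaMA model?")]
```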

### Encode Document Screenshot

```python
from PIL import Image
import requests
from io import BytesIO

# URLs of the example document screenshots
url1 = "https://huggingface.co/Tevatron/dse-phi3-docmatix-v2/resolve/main/animal-llama.png"
url2 = "https://huggingface.co/Tevatron/dse-phi3-docmatix-v2/resolve/main/meta-llama.png"

# Download and open the images
response1 = requests.get(url1)
response2 = requests.get(url2)

passage_image1 = Image.open(BytesIO(response1.content)).resize((1344, 1344))
passage_image2 = Image.open(BytesIO(response2.content)).resize((1344, 1344))

passage_images = [passage_image1, passage_image2]
# Each prompt references its image via the numbered <|image_k|> placeholder.
passage_prompts = ["<|image_1|>\nWhat is shown in this image?</s>", "<|image_2|>\nWhat is shown in this image?</s>"]

# Process inputs and get embeddings
passage_inputs = processor(passage_prompts, images=passage_images, return_tensors="pt", padding="longest", max_length=4096, truncation=True).to('cuda:0')
# The processor returns these fields with an extra leading dimension; drop it.
passage_inputs['input_ids'] = passage_inputs['input_ids'].squeeze(0)
passage_inputs['attention_mask'] = passage_inputs['attention_mask'].squeeze(0)
passage_inputs['image_sizes'] = passage_inputs['image_sizes'].squeeze(0)
with torch.no_grad():
    output = model(**passage_inputs, return_dict=True, output_hidden_states=True)
doc_embeddings = get_embedding(output.hidden_states[-1], passage_inputs["attention_mask"])
```
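
A quick sanity check (illustrative, not part of the original card): the two screenshot embeddings should be unit-norm and clearly distinct from each other.

```python
print(doc_embeddings.shape)                             # (2, hidden_dim)
print(doc_embeddings.norm(dim=-1))                      # ~1.0 each, due to normalization
print((doc_embeddings[0] @ doc_embeddings[1]).item())   # cosine similarity between the two pages
```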

### Compute Similarity

```python
from torch.nn.functional import cosine_similarity
num_queries = query_embeddings.size(0)
num_passages = doc_embeddings.size(0)

for i in range(num_queries):
    query_embedding = query_embeddings[i].unsqueeze(0)
    similarities = cosine_similarity(query_embedding, doc_embeddings)
    print(f"Similarities for Query {i+1}: {similarities.cpu().float().numpy()}")
```
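
Because `get_embedding` L2-normalizes its outputs, the same scores can be computed in one shot with a matrix product. This batched variant (a sketch, not from the original card) also makes top-k retrieval straightforward:

```python
# For unit-norm embeddings, dot product == cosine similarity.
scores = query_embeddings @ doc_embeddings.T     # (num_queries, num_passages)
top_scores, top_ids = scores.topk(k=1, dim=1)    # best passage per query
print(top_ids.cpu().numpy(), top_scores.cpu().float().numpy())
```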

### Encode Document Text

This DSE checkpoint was warmed up on `Tevatron/msmarco-passage-aug`, so the model can also effectively encode documents given as plain text.

```python
passage_prompts = [
    "The llama (/ˈlɑːmə/; Spanish pronunciation: [ˈʎama] or [ˈʝama]) (Lama glama) is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era.</s>",
    "Llama (acronym for Large Language Model Meta AI, and formerly stylized as LLaMA) is a family of autoregressive large language models (LLMs) released by Meta AI starting in February 2023.[2][3] The latest version is Llama 3.1, released in July 2024.[4]</s>"
]

passage_inputs = processor(passage_prompts, images=None, return_tensors="pt", padding="longest", max_length=4096, truncation=True).to('cuda:0')
with torch.no_grad():
    output = model(**passage_inputs, return_dict=True, output_hidden_states=True)
doc_embeddings = get_embedding(output.hidden_states[-1], passage_inputs["attention_mask"])

for i in range(num_queries):
    query_embedding = query_embeddings[i].unsqueeze(0)
    similarities = cosine_similarity(query_embedding, doc_embeddings)
    print(f"Similarities for Query {i+1}: {similarities.cpu().float().numpy()}")
```
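
For more than a handful of documents, the embeddings can be dropped into a nearest-neighbor index. Below is a minimal sketch assuming the `faiss` package is installed (not part of the original card); an inner-product flat index is exact, and inner product equals cosine similarity here because the embeddings are normalized.

```python
import faiss
import numpy as np

# FAISS expects float32 numpy arrays.
doc_np = doc_embeddings.cpu().float().numpy()
query_np = query_embeddings.cpu().float().numpy()

index = faiss.IndexFlatIP(doc_np.shape[1])  # exact inner-product search
index.add(doc_np)
scores, ids = index.search(query_np, 2)     # top-2 passages per query
print(ids, scores)
```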

### Citation

If you find this checkpoint helpful, please consider citing Phi-3, Docmatix, ViDoRe, and our DSE work.