import spaces
import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"

# Qwen2-VL measures image size in 28x28-pixel patch units, so these bound the
# number of visual tokens the processor produces per image.
min_pixels = 1 * 28 * 28
max_pixels = 2560 * 28 * 28


processor = AutoProcessor.from_pretrained(
    "MrLight/dse-qwen2-2b-mrl-v1", min_pixels=min_pixels, max_pixels=max_pixels
)
model = (
    Qwen2VLForConditionalGeneration.from_pretrained(
        "MrLight/dse-qwen2-2b-mrl-v1",
        # attn_implementation="eager",
        attn_implementation="flash_attention_2"
        if device == "cuda"
        else "eager",  # flash_attn is required but is a pain to install on spaces
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    )
    .to(device)
    .eval()
)
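# Left padding keeps the final position of every padded sequence a real token,
# which is the position get_embedding pools from below.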
processor.tokenizer.padding_side = "left"
model.padding_side = "left"


def get_embedding(last_hidden_state: torch.Tensor, dimension: int):
    # DSE pools the hidden state of the last token (hence left padding), truncates
    # it to `dimension` (Matryoshka-style), and L2-normalizes so that dot products
    # between embeddings are cosine similarities.
    reps = last_hidden_state[:, -1]
    reps = torch.nn.functional.normalize(reps[:, :dimension], p=2, dim=-1)
    return reps.to(torch.float32).cpu().numpy()


@spaces.GPU
def encode_queries(queries: list):
    if isinstance(queries, str):
        queries = [queries]
    query_messages = []
    for query in queries:
        message = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": Image.new("RGB", (28, 28)),
                        "resized_height": 1,
                        "resized_width": 1,
                    },  # a minimal dummy image so text-only queries go through the same image+text pipeline as documents
                    {"type": "text", "text": f"Query: {query}"},
                ],
            }
        ]
        query_messages.append(message)
    query_texts = [
        processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        + "<|endoftext|>"
        for msg in query_messages
    ]
    query_image_inputs, query_video_inputs = process_vision_info(query_messages)
    query_inputs = processor(
        text=query_texts,
        images=query_image_inputs,
        videos=query_video_inputs,
        padding="longest",
        return_tensors="pt",
    ).to(device)
    query_inputs = model.prepare_inputs_for_generation(**query_inputs, use_cache=False)
    with torch.no_grad():
        output = model(**query_inputs, return_dict=True, output_hidden_states=True)
        query_embeddings = get_embedding(
            output.hidden_states[-1], 1536
        )  # adjust dimensionality for efficiency trade-off, e.g. 512
    return query_embeddings


@spaces.GPU
def encode_images(images: list):
    if isinstance(images, Image.Image):
        images = [images]
    doc_messages = []
    for image in images:
        message = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,
                    },  # optionally add "resized_height": 680, "resized_width": 680 to adjust the image size for an efficiency trade-off
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            }
        ]
        doc_messages.append(message)
    doc_texts = [
        processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        + "<|endoftext|>"
        for msg in doc_messages
    ]
    doc_image_inputs, doc_video_inputs = process_vision_info(doc_messages)
    doc_inputs = processor(
        text=doc_texts,
        images=doc_image_inputs,
        videos=doc_video_inputs,
        padding="longest",
        return_tensors="pt",
    ).to(device)
    doc_inputs = model.prepare_inputs_for_generation(**doc_inputs, use_cache=False)
    with torch.no_grad():
        output = model(**doc_inputs, return_dict=True, output_hidden_states=True)
    doc_embeddings = get_embedding(
        output.hidden_states[-1], 1536
    )  # adjust dimensionality for efficiency trade-off, e.g. 512
    return doc_embeddings
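

# Example usage: a minimal sketch (the query string and blank image below are just
# placeholders). Because get_embedding L2-normalizes, query/document relevance is
# the dot product of the two embedding matrices.
if __name__ == "__main__":
    query_embeddings = encode_queries(["What is the total revenue in 2023?"])
    doc_embeddings = encode_images([Image.new("RGB", (448, 448), "white")])
    scores = query_embeddings @ doc_embeddings.T  # (num_queries, num_docs) cosine similarities
    print(scores)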