Spaces: Running on Zero
import spaces | |
import torch | |
from PIL import Image | |
from qwen_vl_utils import process_vision_info | |
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration | |
# Prefer the GPU when CUDA is visible; everything else runs on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pixel-count bounds handed to the processor (expressed as multiples of 28 * 28).
# NOTE(review): presumably these bound the image area after resizing; confirm
# against the processor documentation.
min_pixels = 28 * 28
max_pixels = 2560 * 28 * 28

processor = AutoProcessor.from_pretrained(
    "MrLight/dse-qwen2-2b-mrl-v1",
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)

# On CUDA: FlashAttention-2 in bfloat16 (needs the flash_attn package installed).
# On CPU: plain eager attention in float32.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "MrLight/dse-qwen2-2b-mrl-v1",
    **(
        {"attn_implementation": "flash_attention_2", "torch_dtype": torch.bfloat16}
        if device == "cuda"
        else {"attn_implementation": "eager", "torch_dtype": torch.float32}
    ),
)
model = model.to(device).eval()

# Pad batches on the left so position -1 of every row holds a real token;
# the embedding helper reads the hidden state at that position.
processor.tokenizer.padding_side = "left"
# NOTE(review): nothing visible in this file reads model.padding_side; it looks
# informational only.
model.padding_side = "left"
def get_embedding(last_hidden_state: torch.Tensor, dimension: int):
    """Pool, truncate and L2-normalize final hidden states into embeddings.

    Args:
        last_hidden_state: Final-layer hidden states, indexed as
            ``(batch, sequence, hidden)``. The last sequence position is used
            as the pooled representation, so inputs are assumed to be
            left-padded (position -1 is a real token in every row).
        dimension: Number of leading features to keep before normalizing
            (smaller values trade quality for efficiency, e.g. 512).

    Returns:
        A float32 NumPy array of shape ``(batch, min(dimension, hidden))``
        whose rows have unit L2 norm.
    """
    # Last-position pooling: one vector per batch row.
    reps = last_hidden_state[:, -1]
    # Truncate first, then upcast. Normalizing in a low-precision dtype such as
    # bfloat16 leaves the returned float32 vectors noticeably off unit norm,
    # which skews dot-product / cosine scores.
    reps = reps[:, :dimension].to(torch.float32)
    reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps.cpu().numpy()
def _embed_messages(messages: list, dimension: int):
    """Run a batch of chat-formatted messages through the model and embed them.

    Shared pipeline for ``encode_queries`` and ``encode_images``:
    - each message is rendered with the chat template and suffixed with
      ``"<|endoftext|>"``;
    - the batch is padded to its longest item (left padding is configured on
      the tokenizer at module level);
    - the batch goes through one forward pass and the last-layer hidden states
      are pooled by ``get_embedding``.

    Args:
        messages: One chat message list per item to embed.
        dimension: Number of leading embedding features to keep.

    Returns:
        A float32 NumPy array with one L2-normalized row per message.
    """
    if not messages:
        # Short-circuit empty input instead of pushing an empty batch through
        # the processor and the model.
        return torch.empty((0, dimension), dtype=torch.float32).numpy()
    texts = [
        processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        + "<|endoftext|>"
        for msg in messages
    ]
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        padding="longest",
        return_tensors="pt",
    ).to(device)
    # NOTE(review): some transformers versions expect an explicit
    # `cache_position` here; confirm against the pinned version.
    inputs = model.prepare_inputs_for_generation(**inputs, use_cache=False)
    with torch.no_grad():
        output = model(**inputs, return_dict=True, output_hidden_states=True)
    return get_embedding(output.hidden_states[-1], dimension)


def encode_queries(queries: "str | list[str]", dimension: int = 1536):
    """Embed one or more text queries.

    Args:
        queries: A single query string or a list of query strings.
        dimension: Number of leading embedding features to keep; lower values
            (e.g. 512) are an efficiency trade-off. Defaults to 1536.

    Returns:
        A float32 NumPy array with one L2-normalized row per query
        (an empty ``(0, dimension)`` array for an empty list).
    """
    if isinstance(queries, str):
        queries = [queries]
    messages = [
        [
            {
                "role": "user",
                "content": [
                    # Placeholder image so queries follow the same
                    # image+text processing path as documents; it is
                    # requested at 1x1 to keep its footprint minimal.
                    {
                        "type": "image",
                        "image": Image.new("RGB", (28, 28)),
                        "resized_height": 1,
                        "resized_width": 1,
                    },
                    {"type": "text", "text": f"Query: {query}"},
                ],
            }
        ]
        for query in queries
    ]
    return _embed_messages(messages, dimension)


def encode_images(images: "Image.Image | list[Image.Image]", dimension: int = 1536):
    """Embed one or more document images.

    Args:
        images: A single PIL image or a list of PIL images.
        dimension: Number of leading embedding features to keep; lower values
            (e.g. 512) are an efficiency trade-off. Defaults to 1536.

    Returns:
        A float32 NumPy array with one L2-normalized row per image
        (an empty ``(0, dimension)`` array for an empty list).
    """
    if isinstance(images, Image.Image):
        images = [images]
    messages = [
        [
            {
                "role": "user",
                "content": [
                    # Adding "resized_height"/"resized_width" (e.g. 680) here
                    # trades image detail for speed.
                    {"type": "image", "image": image},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            }
        ]
        for image in images
    ]
    return _embed_messages(messages, dimension)