feat-matryoshka-embeddings
#6 opened by koukandre

- configuration_clip.py +5 -1
- modeling_clip.py +38 -5
configuration_clip.py
CHANGED
@@ -6,7 +6,7 @@
 
 import os
 from copy import deepcopy
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 from transformers import PretrainedConfig, logging
 
@@ -157,6 +157,8 @@ class JinaCLIPConfig(PretrainedConfig):
         logit_scale_init_value: float = 2.6592,
         use_text_flash_attn: Optional[bool] = None,
         use_vision_xformers: Optional[bool] = None,
+        matryoshka_dimensions: Optional[List[int]] = None,
+        truncate_dim: Optional[int] = None,
         **kwargs,
     ):
         # If `_config_dict` exist, we use them for the backward compatibility.
@@ -167,6 +169,8 @@ class JinaCLIPConfig(PretrainedConfig):
         vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
         self.use_text_flash_attn = use_text_flash_attn
         self.use_vision_xformers = use_vision_xformers
+        self.matryoshka_dimensions = matryoshka_dimensions
+        self.truncate_dim = truncate_dim
 
         super().__init__(**kwargs)
 
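As a quick sanity check of the two new fields, a minimal sketch of how a checkpoint could declare them; the dimension list and the default value below are placeholders, not values defined in this PR:

    from configuration_clip import JinaCLIPConfig  # the module changed in this PR

    # Placeholder values: a real checkpoint declares its own supported dimensions.
    config = JinaCLIPConfig(
        matryoshka_dimensions=[64, 128, 256, 512, 768],  # dims that truncation may target
        truncate_dim=256,  # default used by encode_text / encode_image when no truncate_dim is passed
    )
    assert config.truncate_dim in config.matryoshka_dimensions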
modeling_clip.py
CHANGED
@@ -4,12 +4,13 @@
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py
 # and adjusted for Jina CLIP
 
+import base64
 from functools import partial
-from typing import List, Optional, Tuple, Union
 from io import BytesIO
-import
-
+from typing import List, Optional, Tuple, Union
+
 import numpy as np
+import requests
 import torch
 import torch.nn.functional as f
 import torch.utils.checkpoint
@@ -39,9 +40,14 @@ except ImportError:
 from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
 from .eva_model import EVAVisionTransformer
 from .hf_model import HFTextEncoder
+
 # needed for HF to correctly import in cache
 from .rope_embeddings import VisionRotaryEmbeddingFast  # noqa: F401
-from .transform import
+from .transform import (  # noqa: F401
+    OPENAI_DATASET_MEAN,
+    OPENAI_DATASET_STD,
+    image_transform,
+)
 
 logger = logging.get_logger(__name__)
 
@@ -280,6 +286,20 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
            )
        return self.visual_projection(self.vision_model(x=x))
 
+    def truncate_embeddings(self, embeddings, truncate_dim):
+        if not self.config.matryoshka_dimensions:
+            logger.warning(
+                "Matryoshka embeddings are not supported, so dimension truncation will not be performed."
+            )
+            return embeddings
+        elif truncate_dim in self.config.matryoshka_dimensions:
+            return embeddings[:, :truncate_dim]
+        else:
+            raise ValueError(
+                f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
+                f"Supported dimensions are {self.config.matryoshka_dimensions}."
+            )
+
     @torch.inference_mode()
     def encode_text(
         self,
@@ -290,6 +310,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         convert_to_tensor: bool = False,
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = True,
+        truncate_dim: Optional[int] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
@@ -315,6 +336,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
                 If set to true, returned vectors will have length 1. In that case,
                 the faster dot-product (util.dot_score) instead of cosine similarity
                 can be used.
+            truncate_dim(`int`, *optional*, defaults to None):
+                The dimension to truncate sentence embeddings to. `None` does no truncation.
             tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
                 Keyword arguments for the tokenizer
         Returns:
@@ -364,6 +387,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         else:
             range_iter = range(0, len(sentences), batch_size)
 
+        truncate_dim = truncate_dim or self.config.truncate_dim
         for i in range_iter:
             encoded_input = self.tokenizer(
                 sentences[i : i + batch_size],
@@ -372,6 +396,9 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             ).to(self.device)
 
             embeddings = self.get_text_features(input_ids=encoded_input)
+
+            if truncate_dim:
+                embeddings = self.truncate_embeddings(embeddings, truncate_dim)
             if normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
             if convert_to_numpy:
@@ -406,6 +433,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         convert_to_tensor: bool = False,
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = True,
+        truncate_dim: Optional[int] = None,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes image embeddings.
@@ -431,6 +459,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
                 If set to true, returned vectors will have length 1. In that case,
                 the faster dot-product (util.dot_score) instead of cosine similarity
                 can be used.
+            truncate_dim(`int`, *optional*, defaults to None):
+                The dimension to truncate sentence embeddings to. `None` does no truncation.
         Returns:
             By default, a list of tensors is returned.
             If convert_to_tensor, a stacked tensor is returned.
@@ -476,7 +506,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             range_iter = range(0, len(images), batch_size)
 
         from PIL import Image
-
+
+        truncate_dim = truncate_dim or self.config.truncate_dim
         for i in range_iter:
             batch_images = images[i:i+batch_size]
             processed_inputs = []
@@ -501,6 +532,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             processed_inputs = processed_inputs.to(self.device)
            embeddings = self.get_image_features(processed_inputs)
 
+            if truncate_dim:
+                embeddings = self.truncate_embeddings(embeddings, truncate_dim)
             if normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
             if convert_to_numpy:
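For context, a short usage sketch of the new truncate_dim flow wired in above; the repo id and image path are placeholders, and the example assumes 256 is listed in the checkpoint's matryoshka_dimensions:

    from transformers import AutoModel

    # Placeholder repo id: substitute a Jina CLIP checkpoint whose config defines
    # `matryoshka_dimensions` (and, optionally, a default `truncate_dim`).
    model = AutoModel.from_pretrained('org/jina-clip-checkpoint', trust_remote_code=True)

    # Per-call truncation; encode_* falls back to config.truncate_dim when the argument is omitted.
    text_emb = model.encode_text(['a photo of a cat'], truncate_dim=256)
    image_emb = model.encode_image(['photos/cat.jpg'], truncate_dim=256)  # placeholder image path

    # Truncation runs before the optional L2 normalization in encode_text / encode_image,
    # so with normalize_embeddings=True the truncated vectors are unit length again.

Passing a truncate_dim that is not in matryoshka_dimensions raises the ValueError added in truncate_embeddings, while a config without matryoshka_dimensions only logs a warning and returns the full-size embeddings.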