"""Wrapper for big_vision contrastive models. Before using any of the functions, make sure to call `setup()`. Choose one of the configs in `MODEL_CONFIGS` and then call `load_model()` to get the params and model wrapper. """ import dataclasses import enum import functools import importlib import os import subprocess import sys import tempfile import flax.linen as nn import jax import jax.numpy as jnp import ml_collections import numpy as np import PIL.Image import sentencepiece from tensorflow.io import gfile import transformers def _clone_git(url, destination_folder, commit_hash=None): subprocess.run( ['git', 'clone', '--depth=1', url, destination_folder], check=True ) if commit_hash: subprocess.run( ['git', '-C', destination_folder, 'checkout', commit_hash], check=True ) def setup(commit_hash=None): """Checks out required non-pypi code from Github.""" for url, dst_name in ( ('https://github.com/google-research/big_vision', 'big_vision_repo'), ('https://github.com/google/flaxformer', 'flaxformer_repo'), ): dst_path = os.path.join(tempfile.gettempdir(), dst_name) if not os.path.exists(dst_path): _clone_git(url, dst_path, commit_hash) if dst_path not in sys.path: sys.path.insert(0, dst_path) class ContrastiveModelFamily(enum.Enum): """Defines a contrastive model family.""" LIT = 'lit' SIGLIP = 'siglip' @property def paper(self): return { self.LIT: 'https://arxiv.org/abs/2111.07991', self.SIGLIP: 'https://arxiv.org/abs/2303.15343', }[self] def __lt__(self, other): return self.value < other.value @dataclasses.dataclass(frozen=True, kw_only=True, order=True) class ContrastiveModelConfig: """Desribes a `big_vision` contrastive model.""" family: ContrastiveModelFamily variant: str res: int textvariant: str embdim: int seqlen: int tokenizer: str vocab_size: int ckpt: str @dataclasses.dataclass(frozen=True, kw_only=True) class ContrastiveModel: """Wraps a `big_vision` contrastive model.""" config: ContrastiveModelConfig flax_module: nn.Module tokenizer_sp: sentencepiece.SentencePieceProcessor | None tokenizer_bert: transformers.BertTokenizer | None def embed_images(self, params, images): assert getattr(images, 'ndim') == 4, 'Must call `.preprocess_images()`' zimg, _, out = self.flax_module.apply(dict(params=params), images, None) return zimg, out def embed_texts(self, params, texts): assert getattr(texts, 'ndim') == 2, 'Must call `.preprocess_texts()`' _, ztxt, out = self.flax_module.apply(dict(params=params), None, texts) return ztxt, out def preprocess_texts(self, texts): """Converts texts to padded tokens.""" def tokenize_pad(text, seqlen=self.config.seqlen): if self.config.family == ContrastiveModelFamily.LIT: tokens = self.tokenizer_bert.encode(text, add_special_tokens=True) tokens = tokens[:-1] # removes [SEP] tokens = tokens[:seqlen] return tokens + [0] * (seqlen - len(tokens)) if self.config.family == ContrastiveModelFamily.SIGLIP: tokens = self.tokenizer_sp.tokenize(text, add_eos=True) if len(tokens) >= seqlen: eos_id = self.tokenizer_sp.eos_id() return tokens[:seqlen - 1] + [eos_id] # "sticky" eos return tokens + [0] * (seqlen - len(tokens)) return np.array([tokenize_pad(text) for text in texts]) def preprocess_images(self, images): if not isinstance(images, (list, tuple)): images = [images] def topil(image): if not isinstance(image, PIL.Image.Image): image = PIL.Image.fromarray(image) return image return np.array([ topil(image).resize([self.config.res, self.config.res]) for image in images ]) / 127.5 - 1.0 def get_bias(self, out): assert ( self.config.family == ContrastiveModelFamily.SIGLIP ), self.config.family return out['b'].item() def get_temperature(self, out): return out['t'].item() def get_probabilities(self, zimg, ztxt, temperature, *, axis=None, bias=None): # Note: zimg, ztxt are already normalized. if self.config.family == ContrastiveModelFamily.LIT: assert bias is None assert axis in (-1, -2), 'Must specify axis: -1/-2=normalize texts/images' return jax.nn.softmax(zimg @ ztxt.T * temperature, axis=axis) if self.config.family == ContrastiveModelFamily.SIGLIP: assert axis is None assert bias is not None, 'Must specify bias.' return jax.nn.sigmoid(zimg @ ztxt.T * temperature + bias) def _make_config( family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size ): if family == 'lit': tokenizer = ckpt.replace('.npz', '.txt') else: tokenizer = 'c4_en' return ContrastiveModelConfig( family=ContrastiveModelFamily(family), variant=variant, res=res, textvariant=textvariant, embdim=embdim, seqlen=seqlen, tokenizer=tokenizer, vocab_size=vocab_size, ckpt=ckpt, ) # pylint: disable=line-too-long MODEL_CONFIGS = dict( lit_b16b=_make_config('lit', 'B/16', 224, 'B', 'gs://vit_models/lit/LiT-B16B.npz', 768, 16, 32_000), lit_l16l=_make_config('lit', 'L/16', 224, 'L', 'gs://vit_models/lit/LiT-L16L.npz', 1024, 16, 32_000), lit_b16s=_make_config('lit', 'L/16', 224, 'S', 'gs://vit_models/lit/LiT-L16S.npz', 1024, 16, 32_000), lit_b16ti=_make_config('lit', 'L/16', 224, 'Ti', 'gs://vit_models/lit/LiT-L16Ti.npz', 1024, 16, 32_000), siglip_b16b_224=_make_config('siglip', 'B/16', 224, 'B', 'gs://big_vision/siglip/webli_en_b16_224_63724782.npz', 768, 64, 32_000), siglip_b16b_256=_make_config('siglip', 'B/16', 256, 'B', 'gs://big_vision/siglip/webli_en_b16_256_60500360.npz', 768, 64, 32_000), siglip_b16b_384=_make_config('siglip', 'B/16', 384, 'B', 'gs://big_vision/siglip/webli_en_b16_384_68578854.npz', 768, 64, 32_000), siglip_b16b_512=_make_config('siglip', 'B/16', 512, 'B', 'gs://big_vision/siglip/webli_en_b16_512_68580893.npz', 768, 64, 32_000), siglip_l16l_256=_make_config('siglip', 'L/16', 256, 'L', 'gs://big_vision/siglip/webli_en_l16_256_60552751.npz', 1024, 64, 32_000), siglip_l16l_384=_make_config('siglip', 'L/16', 384, 'L', 'gs://big_vision/siglip/webli_en_l16_384_63634585.npz', 1024, 64, 32_000), siglip_so400m14so440m_224=_make_config('siglip', 'So400m/14', 224, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_224_57633886.npz', 1152, 16, 32_000), siglip_so400m14so400m_384=_make_config('siglip', 'So400m/14', 384, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_384_58765454.npz', 1152, 64, 32_000), ) # pylint: enable=line-too-long @functools.cache def load_tokenizer_sp(name_or_path): tok = sentencepiece.SentencePieceProcessor() path = { 'c4_en': 'gs://t5-data/vocabs/cc_en.32000/sentencepiece.model', }.get(name_or_path, name_or_path) tok.LoadFromSerializedProto(gfile.GFile(path, 'rb').read()) return tok @functools.cache def load_tokenizer_bert(path): if path.startswith('gs://'): dst = tempfile.mktemp() gfile.copy(path, dst) path = dst return transformers.BertTokenizer(path, do_lower_case=True) def load_model(config, check_params=False): """Loads `big_vision` model.""" assert isinstance(config, ContrastiveModelConfig), type(config) cfg = ml_collections.ConfigDict() cfg.image_model = 'vit' if config.family == ContrastiveModelFamily.LIT: cfg.text_model = 'proj.flaxformer.bert' cfg.image = dict( variant=config.variant, pool_type='tok', head_zeroinit=False ) bert_config = {'B': 'base', 'L': 'large'}[config.textvariant] cfg.text = dict(config=bert_config, head_zeroinit=False) tokenizer_bert = load_tokenizer_bert(config.tokenizer) tokenizer_sp = None if config.variant == 'L/16': cfg.out_dim = (None, config.embdim) # (image_out_dim, text_out_dim) else: # (image_out_dim, text_out_dim) cfg.out_dim = (config.embdim, config.embdim) else: cfg.image = dict(variant=config.variant, pool_type='map') # TODO(lbeyer): remove later, default cfg.text_model = 'proj.image_text.text_transformer' cfg.text = dict(variant=config.textvariant, vocab_size=config.vocab_size) cfg.bias_init = -10.0 tokenizer_sp = load_tokenizer_sp(config.tokenizer) tokenizer_bert = None cfg.out_dim = (None, config.embdim) # (image_out_dim, text_out_dim) cfg.temperature_init = 10.0 model_mod = importlib.import_module( 'big_vision.models.proj.image_text.two_towers') model = model_mod.Model(**cfg) init_params = None # Faster but bypasses loading sanity-checks. if check_params: imgs = jnp.zeros([1, config.res, config.res, 3]) txts = jnp.zeros([1, config.seqlen], jnp.int32) init_params = model.init(jax.random.PRNGKey(0), imgs, txts)['params'] params_cpu = model_mod.load(init_params, config.ckpt, cfg) return params_cpu, ContrastiveModel( config=config, flax_module=model, tokenizer_sp=tokenizer_sp, tokenizer_bert=tokenizer_bert, )