lit-demo-bv / big_vision_contrastive_models.py
andsteing's picture
Reformatting, remove TODO.
2805894
"""Wrapper for big_vision contrastive models.
Before using any of the functions, make sure to call `setup()`.
Choose one of the configs in `MODEL_CONFIGS` and then call `load_model()` to get
the params and model wrapper.
"""
import dataclasses
import enum
import functools
import importlib
import os
import subprocess
import sys
import tempfile
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import PIL.Image
import sentencepiece
from tensorflow.io import gfile
import transformers
def _clone_git(url, destination_folder, commit_hash=None):
    """Clones a git repository and optionally checks out a specific commit.

    Args:
      url: URL of the repository to clone.
      destination_folder: Directory the repository is cloned into.
      commit_hash: Optional commit to check out after cloning. If not set, only
        the tip of the default branch is fetched (shallow clone).

    Raises:
      subprocess.CalledProcessError: If any of the `git` invocations fails.
    """
    clone_cmd = ['git', 'clone']
    if not commit_hash:
        # A shallow clone only contains the tip commit, so it is only usable
        # when no specific commit has to be checked out afterwards.
        clone_cmd.append('--depth=1')
    clone_cmd += [url, destination_folder]
    subprocess.run(clone_cmd, check=True)
    if commit_hash:
        subprocess.run(
            ['git', '-C', destination_folder, 'checkout', commit_hash],
            check=True,
        )
def setup(commit_hash=None):
"""Checks out required non-pypi code from Github."""
for url, dst_name in (
('https://github.com/google-research/big_vision', 'big_vision_repo'),
('https://github.com/google/flaxformer', 'flaxformer_repo'),
):
dst_path = os.path.join(tempfile.gettempdir(), dst_name)
if not os.path.exists(dst_path):
_clone_git(url, dst_path, commit_hash)
if dst_path not in sys.path:
sys.path.insert(0, dst_path)
class ContrastiveModelFamily(enum.Enum):
    """Defines a contrastive model family."""

    LIT = 'lit'
    SIGLIP = 'siglip'

    @property
    def paper(self):
        """Returns the arXiv URL of the paper associated with this family."""
        # Members are looked up on the class instead of through `self`, which
        # does not rely on member-through-member attribute access.
        family = type(self)
        return {
            family.LIT: 'https://arxiv.org/abs/2111.07991',
            family.SIGLIP: 'https://arxiv.org/abs/2303.15343',
        }[self]

    def __lt__(self, other):
        """Orders families by their string value.

        Returns `NotImplemented` for operands that are not families, so that
        Python raises the regular `TypeError` for unsupported comparisons.
        """
        if not isinstance(other, ContrastiveModelFamily):
            return NotImplemented
        return self.value < other.value
@dataclasses.dataclass(frozen=True, kw_only=True, order=True)
class ContrastiveModelConfig:
"""Desribes a `big_vision` contrastive model."""
family: ContrastiveModelFamily
variant: str
res: int
textvariant: str
embdim: int
seqlen: int
tokenizer: str
vocab_size: int
ckpt: str
@dataclasses.dataclass(frozen=True, kw_only=True)
class ContrastiveModel:
"""Wraps a `big_vision` contrastive model."""
config: ContrastiveModelConfig
flax_module: nn.Module
tokenizer_sp: sentencepiece.SentencePieceProcessor | None
tokenizer_bert: transformers.BertTokenizer | None
def embed_images(self, params, images):
assert getattr(images, 'ndim') == 4, 'Must call `.preprocess_images()`'
zimg, _, out = self.flax_module.apply(dict(params=params), images, None)
return zimg, out
def embed_texts(self, params, texts):
assert getattr(texts, 'ndim') == 2, 'Must call `.preprocess_texts()`'
_, ztxt, out = self.flax_module.apply(dict(params=params), None, texts)
return ztxt, out
def preprocess_texts(self, texts):
"""Converts texts to padded tokens."""
def tokenize_pad(text, seqlen=self.config.seqlen):
if self.config.family == ContrastiveModelFamily.LIT:
tokens = self.tokenizer_bert.encode(text, add_special_tokens=True)
tokens = tokens[:-1] # removes [SEP]
tokens = tokens[:seqlen]
return tokens + [0] * (seqlen - len(tokens))
if self.config.family == ContrastiveModelFamily.SIGLIP:
tokens = self.tokenizer_sp.tokenize(text, add_eos=True)
if len(tokens) >= seqlen:
eos_id = self.tokenizer_sp.eos_id()
return tokens[:seqlen - 1] + [eos_id] # "sticky" eos
return tokens + [0] * (seqlen - len(tokens))
return np.array([tokenize_pad(text) for text in texts])
def preprocess_images(self, images):
if not isinstance(images, (list, tuple)):
images = [images]
def topil(image):
if not isinstance(image, PIL.Image.Image):
image = PIL.Image.fromarray(image)
return image
return np.array([
topil(image).resize([self.config.res, self.config.res])
for image in images
]) / 127.5 - 1.0
def get_bias(self, out):
assert (
self.config.family == ContrastiveModelFamily.SIGLIP
), self.config.family
return out['b'].item()
def get_temperature(self, out):
return out['t'].item()
def get_probabilities(self, zimg, ztxt, temperature, *, axis=None, bias=None):
# Note: zimg, ztxt are already normalized.
if self.config.family == ContrastiveModelFamily.LIT:
assert bias is None
assert axis in (-1, -2), 'Must specify axis: -1/-2=normalize texts/images'
return jax.nn.softmax(zimg @ ztxt.T * temperature, axis=axis)
if self.config.family == ContrastiveModelFamily.SIGLIP:
assert axis is None
assert bias is not None, 'Must specify bias.'
return jax.nn.sigmoid(zimg @ ztxt.T * temperature + bias)
def _make_config(
    family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size
):
    """Builds a `ContrastiveModelConfig` from positional shorthand values.

    The tokenizer is derived from the family: for 'lit' it is the checkpoint
    path with `.npz` replaced by `.txt`, for every other family it is 'c4_en'.

    Args:
      family: Family name; must be a valid `ContrastiveModelFamily` value.
      variant: Image tower variant.
      res: Image resolution.
      textvariant: Text tower variant.
      ckpt: Checkpoint path.
      embdim: Embedding dimension.
      seqlen: Text sequence length.
      vocab_size: Vocabulary size.

    Returns:
      The assembled `ContrastiveModelConfig`.
    """
    tokenizer = ckpt.replace('.npz', '.txt') if family == 'lit' else 'c4_en'
    fields = dict(
        family=ContrastiveModelFamily(family),
        variant=variant,
        res=res,
        textvariant=textvariant,
        embdim=embdim,
        seqlen=seqlen,
        tokenizer=tokenizer,
        vocab_size=vocab_size,
        ckpt=ckpt,
    )
    return ContrastiveModelConfig(**fields)
# pylint: disable=line-too-long
# Maps a short model name to its config. Every row below reads:
# (name, family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size)
# NOTE(review): the keys `lit_b16s` and `lit_b16ti` refer to the 'L/16' variant
# and to `LiT-L16*` checkpoints, and the key `siglip_so400m14so440m_224` spells
# "so440m" where its sibling spells "so400m" - possibly typos; the keys are
# kept as they are since callers may depend on them.
MODEL_CONFIGS = {
    name: _make_config(*spec)
    for name, *spec in (
        ('lit_b16b', 'lit', 'B/16', 224, 'B', 'gs://vit_models/lit/LiT-B16B.npz', 768, 16, 32_000),
        ('lit_l16l', 'lit', 'L/16', 224, 'L', 'gs://vit_models/lit/LiT-L16L.npz', 1024, 16, 32_000),
        ('lit_b16s', 'lit', 'L/16', 224, 'S', 'gs://vit_models/lit/LiT-L16S.npz', 1024, 16, 32_000),
        ('lit_b16ti', 'lit', 'L/16', 224, 'Ti', 'gs://vit_models/lit/LiT-L16Ti.npz', 1024, 16, 32_000),
        ('siglip_b16b_224', 'siglip', 'B/16', 224, 'B', 'gs://big_vision/siglip/webli_en_b16_224_63724782.npz', 768, 64, 32_000),
        ('siglip_b16b_256', 'siglip', 'B/16', 256, 'B', 'gs://big_vision/siglip/webli_en_b16_256_60500360.npz', 768, 64, 32_000),
        ('siglip_b16b_384', 'siglip', 'B/16', 384, 'B', 'gs://big_vision/siglip/webli_en_b16_384_68578854.npz', 768, 64, 32_000),
        ('siglip_b16b_512', 'siglip', 'B/16', 512, 'B', 'gs://big_vision/siglip/webli_en_b16_512_68580893.npz', 768, 64, 32_000),
        ('siglip_l16l_256', 'siglip', 'L/16', 256, 'L', 'gs://big_vision/siglip/webli_en_l16_256_60552751.npz', 1024, 64, 32_000),
        ('siglip_l16l_384', 'siglip', 'L/16', 384, 'L', 'gs://big_vision/siglip/webli_en_l16_384_63634585.npz', 1024, 64, 32_000),
        ('siglip_so400m14so440m_224', 'siglip', 'So400m/14', 224, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_224_57633886.npz', 1152, 16, 32_000),
        ('siglip_so400m14so400m_384', 'siglip', 'So400m/14', 384, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_384_58765454.npz', 1152, 64, 32_000),
    )
}
# pylint: enable=line-too-long
@functools.cache
def load_tokenizer_sp(name_or_path):
    """Loads a SentencePiece tokenizer (cached per argument).

    Args:
      name_or_path: Either a known tokenizer name ('c4_en') or the path of a
        serialized SentencePiece model that is readable through `gfile`.

    Returns:
      A `sentencepiece.SentencePieceProcessor` loaded from the model file.
    """
    known_models = {
        'c4_en': 'gs://t5-data/vocabs/cc_en.32000/sentencepiece.model',
    }
    # Unknown names are interpreted as paths.
    path = known_models.get(name_or_path, name_or_path)
    # The context manager closes the file handle once the model was read.
    with gfile.GFile(path, 'rb') as f:
        serialized_model = f.read()
    tok = sentencepiece.SentencePieceProcessor()
    tok.LoadFromSerializedProto(serialized_model)
    return tok
@functools.cache
def load_tokenizer_bert(path):
    """Loads a lower-casing BERT tokenizer from a vocabulary file (cached).

    Args:
      path: Path of the vocabulary file. Paths starting with 'gs://' are
        copied to a local temporary directory first.

    Returns:
      A `transformers.BertTokenizer` created with `do_lower_case=True`.
    """
    if path.startswith('gs://'):
        # `tempfile.mktemp()` is deprecated because the returned name can be
        # claimed by another process before it is used. A freshly created
        # private directory provides a destination that cannot collide.
        # The directory is not removed; the result is cached, so at most one
        # copy is made per distinct path and process.
        local_dir = tempfile.mkdtemp()
        local_path = os.path.join(local_dir, os.path.basename(path))
        gfile.copy(path, local_path)
        path = local_path
    return transformers.BertTokenizer(path, do_lower_case=True)
def load_model(config, check_params=False):
    """Loads `big_vision` model.

    Assembles the two-tower module configuration that belongs to `config`,
    instantiates the module, loads the matching tokenizer and reads the
    checkpoint `config.ckpt`. The `big_vision` code is imported at call time.

    Args:
      config: The `ContrastiveModelConfig` describing the model to load.
      check_params: If true, the module is initialized on all-zero dummy inputs
        and the resulting parameters are passed to the checkpoint loader. If
        false, `None` is passed instead, which is faster but bypasses loading
        sanity-checks.

    Returns:
      A tuple `(params, model)` with the parameters returned by the checkpoint
      loader and the `ContrastiveModel` wrapper.
    """
    assert isinstance(config, ContrastiveModelConfig), type(config)
    is_lit = config.family == ContrastiveModelFamily.LIT

    cfg = ml_collections.ConfigDict()
    cfg.image_model = 'vit'
    if is_lit:
        cfg.text_model = 'proj.flaxformer.bert'
        cfg.image = dict(
            variant=config.variant, pool_type='tok', head_zeroinit=False
        )
        # NOTE(review): text variants other than 'B'/'L' (e.g. 'S', 'Ti')
        # raise a `KeyError` here.
        bert_name = {'B': 'base', 'L': 'large'}[config.textvariant]
        cfg.text = dict(config=bert_name, head_zeroinit=False)
        tokenizer_sp, tokenizer_bert = (
            None, load_tokenizer_bert(config.tokenizer)
        )
        image_out_dim = None if config.variant == 'L/16' else config.embdim
    else:
        cfg.image = dict(variant=config.variant, pool_type='map')
        # TODO(lbeyer): remove later, default
        cfg.text_model = 'proj.image_text.text_transformer'
        cfg.text = dict(
            variant=config.textvariant, vocab_size=config.vocab_size
        )
        cfg.bias_init = -10.0
        tokenizer_sp, tokenizer_bert = (
            load_tokenizer_sp(config.tokenizer), None
        )
        image_out_dim = None
    # (image_out_dim, text_out_dim)
    cfg.out_dim = (image_out_dim, config.embdim)
    cfg.temperature_init = 10.0

    two_towers = importlib.import_module(
        'big_vision.models.proj.image_text.two_towers')
    module = two_towers.Model(**cfg)

    if check_params:
        dummy_images = jnp.zeros([1, config.res, config.res, 3])
        dummy_texts = jnp.zeros([1, config.seqlen], jnp.int32)
        variables = module.init(
            jax.random.PRNGKey(0), dummy_images, dummy_texts
        )
        init_params = variables['params']
    else:
        init_params = None  # Faster but bypasses loading sanity-checks.
    params_cpu = two_towers.load(init_params, config.ckpt, cfg)

    wrapper = ContrastiveModel(
        config=config,
        flax_module=module,
        tokenizer_sp=tokenizer_sp,
        tokenizer_bert=tokenizer_bert,
    )
    return params_cpu, wrapper