Spaces:
Sleeping
Sleeping
"""Wrapper for big_vision contrastive models. | |
Before using any of the functions, make sure to call `setup()`. | |
Choose one of the configs in `MODEL_CONFIGS` and then call `load_model()` to get | |
the params and model wrapper. | |
""" | |
import dataclasses | |
import enum | |
import functools | |
import importlib | |
import os | |
import subprocess | |
import sys | |
import tempfile | |
import flax.linen as nn | |
import jax | |
import jax.numpy as jnp | |
import ml_collections | |
import numpy as np | |
import PIL.Image | |
import sentencepiece | |
from tensorflow.io import gfile | |
import transformers | |
def _clone_git(url, destination_folder, commit_hash=None):
    """Clones a git repository, optionally pinned to a specific commit.

    Args:
      url: Repository URL handed to `git clone`.
      destination_folder: Directory the repository is cloned into.
      commit_hash: Optional revision to check out after cloning. When falsy,
        the tip of the default branch is used.

    Raises:
      subprocess.CalledProcessError: If a git command exits with non-zero
        status.
    """
    clone_cmd = ['git', 'clone']
    if not commit_hash:
        # Only the branch tip is needed, so a shallow clone is sufficient.
        clone_cmd.append('--depth=1')
    # With a pinned commit the full history is cloned: a `--depth=1` clone
    # only contains the tip commit, so checking out any other hash would fail.
    subprocess.run(clone_cmd + [url, destination_folder], check=True)
    if commit_hash:
        subprocess.run(
            ['git', '-C', destination_folder, 'checkout', commit_hash],
            check=True,
        )
def setup(commit_hash=None): | |
"""Checks out required non-pypi code from Github.""" | |
for url, dst_name in ( | |
('https://github.com/google-research/big_vision', 'big_vision_repo'), | |
('https://github.com/google/flaxformer', 'flaxformer_repo'), | |
): | |
dst_path = os.path.join(tempfile.gettempdir(), dst_name) | |
if not os.path.exists(dst_path): | |
_clone_git(url, dst_path, commit_hash) | |
if dst_path not in sys.path: | |
sys.path.insert(0, dst_path) | |
class ContrastiveModelFamily(enum.Enum):
    """Defines a contrastive model family."""

    LIT = 'lit'
    SIGLIP = 'siglip'

    def paper(self):
        """Returns the arXiv URL of the paper associated with this family."""
        # The mapping lives inside the method: a dict assigned in the class
        # body would be turned into an enum member.
        urls = {
            ContrastiveModelFamily.LIT: 'https://arxiv.org/abs/2111.07991',
            ContrastiveModelFamily.SIGLIP: 'https://arxiv.org/abs/2303.15343',
        }
        return urls[self]

    def __lt__(self, other):
        """Orders families by their string value."""
        if not isinstance(other, ContrastiveModelFamily):
            # Let Python try the reflected operation / raise TypeError instead
            # of failing with an AttributeError on `other.value`.
            return NotImplemented
        return self.value < other.value
# The dataclass decorator generates the keyword `__init__` that
# `_make_config()` relies on; without it the class accepts no arguments.
# `frozen=True` keeps configs immutable and hashable, `order=True` makes them
# sortable field by field.
@dataclasses.dataclass(frozen=True, order=True)
class ContrastiveModelConfig:
    """Describes a `big_vision` contrastive model."""

    # Model family (LiT or SigLIP); selects tokenizer and model setup.
    family: ContrastiveModelFamily
    # Image tower variant, e.g. 'B/16'.
    variant: str
    # Side length (pixels) that input images are resized to.
    res: int
    # Text tower variant, e.g. 'B'.
    textvariant: str
    # Embedding dimension, used for the model's `out_dim`.
    embdim: int
    # Number of tokens every text is truncated/padded to.
    seqlen: int
    # Tokenizer name or path, passed to the tokenizer loading functions.
    tokenizer: str
    # Vocabulary size passed to the SigLIP text tower.
    vocab_size: int
    # Path of the checkpoint that `load_model()` loads the parameters from.
    ckpt: str
# The dataclass decorator generates the keyword `__init__` that
# `load_model()` uses to build the wrapper.
@dataclasses.dataclass(frozen=True)
class ContrastiveModel:
    """Wraps a `big_vision` contrastive model.

    Bundles the model config, the flax module and the tokenizer matching the
    model family, and offers helpers for preprocessing inputs, computing
    embeddings and converting embeddings to probabilities.
    """

    config: ContrastiveModelConfig
    flax_module: nn.Module
    # Exactly one of the tokenizers is expected to be set, depending on the
    # family: SentencePiece for SigLIP, BERT for LiT.
    tokenizer_sp: sentencepiece.SentencePieceProcessor | None
    tokenizer_bert: transformers.BertTokenizer | None

    def embed_images(self, params, images):
        """Embeds preprocessed images.

        Args:
          params: Model parameters.
          images: 4-dimensional array as returned by `preprocess_images()`.

        Returns:
          Tuple `(zimg, out)` of the image embeddings and the auxiliary
          outputs of the model.
        """
        # The `None` default turns a missing `ndim` (e.g. a plain list) into
        # the explanatory assertion message instead of an AttributeError.
        assert getattr(images, 'ndim', None) == 4, (
            'Must call `.preprocess_images()`')
        zimg, _, out = self.flax_module.apply(
            dict(params=params), images, None)
        return zimg, out

    def embed_texts(self, params, texts):
        """Embeds preprocessed texts.

        Args:
          params: Model parameters.
          texts: 2-dimensional token array as returned by
            `preprocess_texts()`.

        Returns:
          Tuple `(ztxt, out)` of the text embeddings and the auxiliary
          outputs of the model.
        """
        assert getattr(texts, 'ndim', None) == 2, (
            'Must call `.preprocess_texts()`')
        _, ztxt, out = self.flax_module.apply(
            dict(params=params), None, texts)
        return ztxt, out

    def preprocess_texts(self, texts):
        """Converts texts to padded tokens.

        Args:
          texts: Iterable of strings.

        Returns:
          Array with one row of `config.seqlen` token ids per text; shorter
          texts are padded with 0.

        Raises:
          ValueError: If the configured family is not supported.
        """
        family = self.config.family

        def tokenize_pad(text, seqlen=self.config.seqlen):
            if family == ContrastiveModelFamily.LIT:
                tokens = self.tokenizer_bert.encode(
                    text, add_special_tokens=True)
                tokens = tokens[:-1]  # removes [SEP]
                tokens = tokens[:seqlen]
            elif family == ContrastiveModelFamily.SIGLIP:
                tokens = self.tokenizer_sp.tokenize(text, add_eos=True)
                if len(tokens) >= seqlen:
                    eos_id = self.tokenizer_sp.eos_id()
                    return tokens[:seqlen - 1] + [eos_id]  # "sticky" eos
            else:
                # Previously fell through and silently produced `None` rows.
                raise ValueError(f'Unsupported model family: {family!r}')
            return tokens + [0] * (seqlen - len(tokens))

        return np.array([tokenize_pad(text) for text in texts])

    def preprocess_images(self, images):
        """Converts images to a batch of resized, rescaled arrays.

        Args:
          images: A single image or a list/tuple of images; every image is
            either a `PIL.Image.Image` or an array accepted by
            `PIL.Image.fromarray`.

        Returns:
          Array of all images resized to `config.res` x `config.res`, with
          pixel values mapped from [0, 255] to [-1, 1].
        """
        if not isinstance(images, (list, tuple)):
            images = [images]
        size = (self.config.res, self.config.res)

        def topil(image):
            if not isinstance(image, PIL.Image.Image):
                image = PIL.Image.fromarray(image)
            # The model is set up for 3-channel inputs (see the dummy input in
            # `load_model`), so grayscale/RGBA/palette images are converted.
            return image.convert('RGB')

        batch = np.array([topil(image).resize(size) for image in images])
        return batch / 127.5 - 1.0

    def get_bias(self, out):
        """Returns the bias stored under `out['b']` (SigLIP models only)."""
        assert (
            self.config.family == ContrastiveModelFamily.SIGLIP
        ), self.config.family
        return out['b'].item()

    def get_temperature(self, out):
        """Returns the temperature stored under `out['t']` as a scalar."""
        return out['t'].item()

    def get_probabilities(
        self, zimg, ztxt, temperature, *, axis=None, bias=None
    ):
        """Converts image/text embeddings to pairwise probabilities.

        Args:
          zimg: Image embeddings (already normalized).
          ztxt: Text embeddings (already normalized).
          temperature: Factor the similarities are multiplied with.
          axis: LiT only; -1 normalizes over texts, -2 over images.
          bias: SigLIP only; value added to the scaled similarities.

        Returns:
          LiT: softmax of the scaled similarity matrix along `axis`.
          SigLIP: element-wise sigmoid of the scaled similarities plus bias.

        Raises:
          ValueError: If the configured family is not supported.
        """
        # Note: zimg, ztxt are already normalized.
        family = self.config.family
        if family == ContrastiveModelFamily.LIT:
            assert bias is None
            assert axis in (-1, -2), (
                'Must specify axis: -1/-2=normalize texts/images')
            return jax.nn.softmax(zimg @ ztxt.T * temperature, axis=axis)
        if family == ContrastiveModelFamily.SIGLIP:
            assert axis is None
            assert bias is not None, 'Must specify bias.'
            return jax.nn.sigmoid(zimg @ ztxt.T * temperature + bias)
        # Previously fell through and silently returned `None`.
        raise ValueError(f'Unsupported model family: {family!r}')
def _make_config(
    family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size
):
    """Builds a `ContrastiveModelConfig`, deriving the tokenizer from `family`.

    For the 'lit' family the tokenizer is the checkpoint path with '.npz'
    replaced by '.txt'; every other family uses the 'c4_en' tokenizer.
    """
    is_lit = family == 'lit'
    tokenizer_name = ckpt.replace('.npz', '.txt') if is_lit else 'c4_en'
    return ContrastiveModelConfig(
        family=ContrastiveModelFamily(family),
        variant=variant,
        res=res,
        textvariant=textvariant,
        embdim=embdim,
        seqlen=seqlen,
        tokenizer=tokenizer_name,
        vocab_size=vocab_size,
        ckpt=ckpt,
    )
# Available model configs, keyed by name. Every row holds the positional
# arguments of `_make_config`:
# (family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size).
# NOTE(review): `lit_b16s` / `lit_b16ti` reference L/16 checkpoints and
# `so440m` in `siglip_so400m14so440m_224` looks like a typo for `so400m`; the
# keys are reproduced verbatim because they are looked up by name.
# pylint: disable=line-too-long
_MODEL_CONFIG_ROWS = (
    ('lit_b16b', ('lit', 'B/16', 224, 'B', 'gs://vit_models/lit/LiT-B16B.npz', 768, 16, 32_000)),
    ('lit_l16l', ('lit', 'L/16', 224, 'L', 'gs://vit_models/lit/LiT-L16L.npz', 1024, 16, 32_000)),
    ('lit_b16s', ('lit', 'L/16', 224, 'S', 'gs://vit_models/lit/LiT-L16S.npz', 1024, 16, 32_000)),
    ('lit_b16ti', ('lit', 'L/16', 224, 'Ti', 'gs://vit_models/lit/LiT-L16Ti.npz', 1024, 16, 32_000)),
    ('siglip_b16b_224', ('siglip', 'B/16', 224, 'B', 'gs://big_vision/siglip/webli_en_b16_224_63724782.npz', 768, 64, 32_000)),
    ('siglip_b16b_256', ('siglip', 'B/16', 256, 'B', 'gs://big_vision/siglip/webli_en_b16_256_60500360.npz', 768, 64, 32_000)),
    ('siglip_b16b_384', ('siglip', 'B/16', 384, 'B', 'gs://big_vision/siglip/webli_en_b16_384_68578854.npz', 768, 64, 32_000)),
    ('siglip_b16b_512', ('siglip', 'B/16', 512, 'B', 'gs://big_vision/siglip/webli_en_b16_512_68580893.npz', 768, 64, 32_000)),
    ('siglip_l16l_256', ('siglip', 'L/16', 256, 'L', 'gs://big_vision/siglip/webli_en_l16_256_60552751.npz', 1024, 64, 32_000)),
    ('siglip_l16l_384', ('siglip', 'L/16', 384, 'L', 'gs://big_vision/siglip/webli_en_l16_384_63634585.npz', 1024, 64, 32_000)),
    ('siglip_so400m14so440m_224', ('siglip', 'So400m/14', 224, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_224_57633886.npz', 1152, 16, 32_000)),
    ('siglip_so400m14so400m_384', ('siglip', 'So400m/14', 384, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_384_58765454.npz', 1152, 64, 32_000)),
)
# pylint: enable=line-too-long
MODEL_CONFIGS = {
    name: _make_config(*row) for name, row in _MODEL_CONFIG_ROWS
}
def load_tokenizer_sp(name_or_path):
    """Loads a SentencePiece tokenizer.

    Args:
      name_or_path: Either a known tokenizer name (currently only 'c4_en') or
        a path to a serialized SentencePiece model readable through `gfile`.

    Returns:
      A `sentencepiece.SentencePieceProcessor` loaded from the model file.
    """
    known_models = {
        'c4_en': 'gs://t5-data/vocabs/cc_en.32000/sentencepiece.model',
    }
    # Unknown names are treated as paths.
    path = known_models.get(name_or_path, name_or_path)
    # The context manager closes the file handle, which was previously leaked.
    with gfile.GFile(path, 'rb') as f:
        serialized_model = f.read()
    tok = sentencepiece.SentencePieceProcessor()
    tok.LoadFromSerializedProto(serialized_model)
    return tok
def load_tokenizer_bert(path):
    """Loads a lower-casing BERT tokenizer from a vocabulary file.

    Args:
      path: Path of the vocabulary file. A 'gs://' path is first copied to a
        local temporary file.

    Returns:
      A `transformers.BertTokenizer` built with `do_lower_case=True`.
    """
    if path.startswith('gs://'):
        # `tempfile.mktemp()` is deprecated and race-prone (another process
        # can create the returned name first). A freshly created private
        # directory gives a destination that is guaranteed not to exist yet.
        # The local copy is left in place, as before.
        local_path = os.path.join(tempfile.mkdtemp(), 'vocab.txt')
        gfile.copy(path, local_path)
        path = local_path
    return transformers.BertTokenizer(path, do_lower_case=True)
def load_model(config, check_params=False):
    """Loads `big_vision` model.

    Args:
      config: The `ContrastiveModelConfig` to load (e.g. one of the values of
        `MODEL_CONFIGS`).
      check_params: If true, the model is initialized with dummy inputs and
        the resulting parameters are handed to the checkpoint loader;
        otherwise `None` is passed, which is faster but bypasses loading
        sanity-checks.

    Returns:
      Tuple `(params, model)` of the loaded parameters and a
      `ContrastiveModel` wrapping the flax module and the matching tokenizer.
    """
    assert isinstance(config, ContrastiveModelConfig), type(config)

    tokenizer_sp = None
    tokenizer_bert = None

    model_cfg = ml_collections.ConfigDict()
    model_cfg.image_model = 'vit'
    if config.family != ContrastiveModelFamily.LIT:
        model_cfg.image = dict(variant=config.variant, pool_type='map')
        # TODO(lbeyer): remove later, default
        model_cfg.text_model = 'proj.image_text.text_transformer'
        model_cfg.text = dict(
            variant=config.textvariant, vocab_size=config.vocab_size
        )
        model_cfg.bias_init = -10.0
        tokenizer_sp = load_tokenizer_sp(config.tokenizer)
        # (image_out_dim, text_out_dim)
        model_cfg.out_dim = (None, config.embdim)
    else:
        model_cfg.text_model = 'proj.flaxformer.bert'
        model_cfg.image = dict(
            variant=config.variant, pool_type='tok', head_zeroinit=False
        )
        bert_names = {'B': 'base', 'L': 'large'}
        model_cfg.text = dict(
            config=bert_names[config.textvariant], head_zeroinit=False
        )
        tokenizer_bert = load_tokenizer_bert(config.tokenizer)
        # (image_out_dim, text_out_dim): no image output dim for 'L/16'.
        image_out_dim = None if config.variant == 'L/16' else config.embdim
        model_cfg.out_dim = (image_out_dim, config.embdim)
    model_cfg.temperature_init = 10.0

    # Imported lazily: the module docstring asks callers to run `setup()`
    # before using any of the functions.
    two_towers = importlib.import_module(
        'big_vision.models.proj.image_text.two_towers')
    flax_module = two_towers.Model(**model_cfg)

    if check_params:
        dummy_images = jnp.zeros([1, config.res, config.res, 3])
        dummy_tokens = jnp.zeros([1, config.seqlen], jnp.int32)
        variables = flax_module.init(
            jax.random.PRNGKey(0), dummy_images, dummy_tokens
        )
        init_params = variables['params']
    else:
        init_params = None  # Faster but bypasses loading sanity-checks.

    params_cpu = two_towers.load(init_params, config.ckpt, model_cfg)
    wrapper = ContrastiveModel(
        config=config,
        flax_module=flax_module,
        tokenizer_sp=tokenizer_sp,
        tokenizer_bert=tokenizer_bert,
    )
    return params_cpu, wrapper