import tensorflow as tf
# note: the Keras layer is spelled Conv2D; Identity requires TF >= 2.12
from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, LayerNormalization, MultiHeadAttention
from tensorflow.keras.layers import ZeroPadding2D, AveragePooling2D, Identity
from tensorflow.keras import Model
import numpy as np
from typing import Tuple, Union


class Bottleneck(tf.keras.layers.Layer):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super(Bottleneck, self).__init__()
        # 1x1 reduction convolution
        self.conv1 = Conv2D(planes, 1, use_bias=False)
        self.bn1 = BatchNormalization()
        self.relu1 = tf.nn.relu

        # 3x3 convolution; padding applied explicitly to mirror PyTorch's padding=1
        self.zeropadding2d = ZeroPadding2D(padding=1)
        self.conv2 = Conv2D(planes, 3, use_bias=False)
        self.bn2 = BatchNormalization()
        self.relu2 = tf.nn.relu

        # anti-aliasing: average-pool instead of striding the convolution
        self.avgpool = AveragePooling2D(stride, stride, 'valid') if stride > 1 else Identity()

        # 1x1 expansion convolution
        self.conv3 = Conv2D(planes * self.expansion, 1, use_bias=False)
        self.bn3 = BatchNormalization()
        self.relu3 = tf.nn.relu

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool;
            # the subsequent convolution keeps stride 1
            self.downsample = tf.keras.Sequential()
            self.downsample.add(AveragePooling2D(stride, stride, 'valid'))
            self.downsample.add(Conv2D(planes * self.expansion, 1, strides=1, use_bias=False))
            self.downsample.add(BatchNormalization())

    def call(self, x):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.zeropadding2d(out)
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out
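

# Usage sketch (an assumption, not part of the original file): a quick shape check
# for Bottleneck, assuming NHWC inputs.
#
#   block = Bottleneck(inplanes=64, planes=64, stride=2)
#   y = block(tf.random.normal([2, 56, 56, 64]))
#   print(y.shape)  # (2, 28, 28, 256): spatial dims halved, channels = planes * expansion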


class AttentionPool2d(tf.keras.layers.Layer):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super(AttentionPool2d, self).__init__()
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=[spacial_dim ** 2 + 1, embed_dim],
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1. / embed_dim ** 0.5),
            trainable=True
        )
        self.k_proj = Dense(embed_dim)
        self.q_proj = Dense(embed_dim)
        self.v_proj = Dense(embed_dim)
        self.c_proj = Dense(output_dim or embed_dim)
        self.num_heads = num_heads

    def call(self, x):
        # NHWC -> (HW, N, C): sequence-first, like the PyTorch reference
        batch_size, height, width, channels = x.shape
        x = tf.reshape(x, (batch_size, height * width, channels))
        x = tf.transpose(x, (1, 0, 2))
        # prepend the mean token, which serves as the attention query
        x = tf.concat([tf.reduce_mean(x, axis=0, keepdims=True), x], axis=0)
        x = x + tf.cast(self.positional_embedding[:, None, :], x.dtype)

        tgt_len, bsz, embed_dim = x.shape
        head_dim = embed_dim // self.num_heads
        query = self.q_proj(x[:1])  # (1, N, C): only the mean token queries
        key = self.k_proj(x)        # (L, N, C)
        value = self.v_proj(x)      # (L, N, C)

        # split heads: reshape in sequence-first order first, then move batch and
        # heads forward (reshaping straight to batch-first would scramble the dims);
        # scale by the per-head dim, as in the PyTorch reference
        query = tf.transpose(tf.reshape(query, [1, bsz, self.num_heads, head_dim]), [1, 2, 0, 3])
        query = query * (1.0 / tf.math.sqrt(float(head_dim)))
        key = tf.transpose(tf.reshape(key, [tgt_len, bsz, self.num_heads, head_dim]), [1, 2, 3, 0])
        value = tf.transpose(tf.reshape(value, [tgt_len, bsz, self.num_heads, head_dim]), [1, 2, 0, 3])

        w = tf.nn.softmax(tf.matmul(query, key))  # (N, heads, 1, L)
        wv = tf.matmul(w, value)                  # (N, heads, 1, head_dim)
        wv = tf.reshape(tf.transpose(wv, [0, 2, 1, 3]), [bsz, -1])
        return self.c_proj(wv)                    # (N, output_dim)
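

# Usage sketch (an assumption, not in the original file): pooling the 7x7 feature
# map that the RN50 trunk produces at 224px input.
#
#   pool = AttentionPool2d(spacial_dim=7, embed_dim=2048, num_heads=32, output_dim=1024)
#   feats = pool(tf.random.normal([2, 7, 7, 2048]))
#   print(feats.shape)  # (2, 1024)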


class ModifiedResNet(tf.keras.layers.Layer):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1.
    - The final pooling layer is a QKV attention instead of an average pool.
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super(ModifiedResNet, self).__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.zeropadding2d = ZeroPadding2D(padding=1)
        self.conv1 = Conv2D(width // 2, kernel_size=3, strides=2, use_bias=False)
        self.bn1 = BatchNormalization()
        self.relu1 = tf.nn.relu
        self.conv2 = Conv2D(width // 2, kernel_size=3, use_bias=False)
        self.bn2 = BatchNormalization()
        self.relu2 = tf.nn.relu
        self.conv3 = Conv2D(width, kernel_size=3, use_bias=False)
        self.bn3 = BatchNormalization()
        self.relu3 = tf.nn.relu
        self.avgpool = AveragePooling2D(2, 2, 'valid')

        # residual layers; _inplanes is mutated as the layers are built
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = tf.keras.Sequential()
        layers.add(Bottleneck(self._inplanes, planes, stride))

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.add(Bottleneck(self._inplanes, planes))

        return layers

    def call(self, x):
        def stem(x):
            x = self.zeropadding2d(x)
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.zeropadding2d(x)
            x = self.relu2(self.bn2(self.conv2(x)))
            x = self.zeropadding2d(x)
            x = self.relu3(self.bn3(self.conv3(x)))
            x = self.avgpool(x)
            return x

        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
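

# Usage sketch (an assumption, not in the original file): CLIP's RN50 configuration
# uses layers=(3, 4, 6, 3), width=64, heads=32, output_dim=1024.
#
#   visual = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32)
#   out = visual(tf.random.normal([2, 224, 224, 3]))
#   print(out.shape)  # (2, 1024)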


class LayerNorm(tf.keras.layers.Layer):
    """A LayerNormalization wrapper that normalizes in fp32, to stay numerically stable under fp16 inputs."""

    def __init__(self, input_size):
        # autocast=False keeps Keras from casting inputs before call(); input_size is
        # kept for call-site parity (LayerNormalization infers the size itself)
        super(LayerNorm, self).__init__(autocast=False)
        # epsilon matches PyTorch's nn.LayerNorm default (Keras defaults to 1e-3)
        self.layer_norm = LayerNormalization(epsilon=1e-5)

    def call(self, x):
        orig_type = x.dtype
        ret = self.layer_norm(tf.cast(x, tf.float32))
        return tf.cast(ret, orig_type)
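

# Usage sketch (an assumption): fp16 inputs are normalized in fp32 and cast back.
#
#   ln = LayerNorm(512)
#   y = ln(tf.random.normal([2, 4, 512], dtype=tf.float16))
#   print(y.dtype)  # float16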


class QuickGELU(tf.keras.layers.Layer):
    def __init__(self):
        super(QuickGELU, self).__init__()

    def call(self, x):
        # sigmoid approximation of GELU used by the original CLIP
        return x * tf.sigmoid(1.702 * x)
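

# Aside (an assumption, not in the original file): QuickGELU tracks the exact GELU
# closely but is not identical to it.
#
#   x = tf.linspace(-3., 3., 7)
#   print(QuickGELU()(x) - tf.nn.gelu(x))  # small but nonzero differences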


class ResidualAttentionBlock(tf.keras.layers.Layer):
    def __init__(self, d_model: int, n_head: int, attn_mask=None):
        super(ResidualAttentionBlock, self).__init__()
        # key_dim is the per-head size, so the total projection width stays d_model
        self.attn = MultiHeadAttention(n_head, d_model // n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = tf.keras.Sequential()
        self.mlp.add(Dense(d_model * 4))
        self.mlp.add(QuickGELU())
        self.mlp.add(Dense(d_model))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x):
        # Keras MultiHeadAttention returns the output tensor directly (no weights
        # tuple to index) and expects a boolean-style mask: 1/True where attention
        # is allowed, not PyTorch's additive -inf mask
        attn_mask = tf.cast(self.attn_mask, tf.bool) if self.attn_mask is not None else None
        return self.attn(x, x, attention_mask=attn_mask)

    def call(self, x):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
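

# Usage sketch (an assumption): Keras MultiHeadAttention is batch-first, so blocks
# take (batch, seq, d_model) inputs.
#
#   block = ResidualAttentionBlock(d_model=512, n_head=8)
#   y = block(tf.random.normal([2, 77, 512]))
#   print(y.shape)  # (2, 77, 512)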


class Transformer(tf.keras.layers.Layer):
    def __init__(self, width: int, layers: int, heads: int, attn_mask=None):
        super(Transformer, self).__init__()
        self.width = width
        self.layers = layers
        self.resblocks = tf.keras.Sequential()
        for _ in range(layers):
            self.resblocks.add(ResidualAttentionBlock(width, heads, attn_mask))

    def call(self, x):
        return self.resblocks(x)
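

# Usage sketch (an assumption): a small 4-layer stack over batch-first sequences.
#
#   enc = Transformer(width=512, layers=4, heads=8)
#   y = enc(tf.random.normal([2, 77, 512]))
#   print(y.shape)  # (2, 77, 512)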


class VisionTransformer(tf.keras.layers.Layer):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super(VisionTransformer, self).__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # patchify: non-overlapping patch_size x patch_size convolution
        self.conv1 = Conv2D(width, kernel_size=patch_size, strides=patch_size, use_bias=False)

        scale = width ** -0.5
        self.class_embedding = self.add_weight(
            name='class_embedding',
            shape=[width],
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=scale),
            trainable=True
        )
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=[(input_resolution // patch_size) ** 2 + 1, width],
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=scale),
            trainable=True
        )
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = tf.Variable(scale * tf.random.normal([width, output_dim]), name='proj')

    def call(self, x):
        x = self.conv1(x)                                 # (batch, grid, grid, width)
        x = tf.reshape(x, [x.shape[0], -1, x.shape[-1]])  # (batch, grid ** 2, width)
        # prepend the class token
        x = tf.concat([tf.cast(self.class_embedding, x.dtype) + tf.zeros([x.shape[0], 1, x.shape[-1]], dtype=x.dtype), x], axis=1)
        x = x + tf.cast(self.positional_embedding, x.dtype)
        x = self.ln_pre(x)

        # Keras MultiHeadAttention is batch-first, so no NLD <-> LND transposes are needed
        x = self.transformer(x)

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = tf.matmul(x, self.proj)

        return x
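

# Usage sketch (an assumption, not in the original file): CLIP's ViT-B/32
# hyperparameters.
#
#   vit = VisionTransformer(input_resolution=224, patch_size=32, width=768,
#                           layers=12, heads=12, output_dim=512)
#   feats = vit(tf.random.normal([2, 224, 224, 3]))
#   print(feats.shape)  # (2, 512)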


class CLIP(Model):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super(CLIP, self).__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            # a tuple/list of block counts selects the ModifiedResNet backbone
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            # an integer layer count selects the VisionTransformer backbone
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = self.add_weight(
            name='token_embedding',
            shape=(vocab_size, transformer_width),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
            trainable=True
        )
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=(self.context_length, transformer_width),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
            trainable=True
        )
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = self.add_weight(
            name='text_projection',
            shape=(transformer_width, embed_dim),
            initializer=tf.keras.initializers.RandomNormal(stddev=transformer_width ** -0.5),
            trainable=True
        )
        # temperature initialized so that exp(logit_scale) == 1 / 0.07
        self.logit_scale = self.add_weight(
            name='logit_scale',
            shape=[],
            initializer=tf.keras.initializers.Constant(np.log(1 / 0.07)),
            trainable=True
        )

    def build_attention_mask(self):
        # causal mask over the text sequence; Keras MultiHeadAttention expects
        # 1/True where attention is *allowed* (unlike PyTorch's additive -inf mask),
        # so keep the lower triangle, including the diagonal
        mask = tf.linalg.band_part(tf.ones((self.context_length, self.context_length)), -1, 0)
        return mask[None]  # (1, context_length, context_length), broadcast over the batch

    def encode_image(self, image):
        return self.visual(image)

    def encode_text(self, text):
        x = tf.gather(self.token_embedding, text)  # (batch_size, n_ctx, d_model)

        x = x + self.positional_embedding
        # Keras MultiHeadAttention is batch-first, so no NLD <-> LND transposes
        x = self.transformer(x)
        x = self.ln_final(x)

        # take features from the EOT embedding (the EOT token has the highest id in each sequence)
        x = tf.matmul(tf.gather_nd(x, tf.stack([tf.range(x.shape[0], dtype='int32'),
                                                tf.argmax(text, axis=-1, output_type='int32')], axis=1)),
                      self.text_projection)

        return x

    def call(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalize features
        image_features = image_features / tf.norm(image_features, axis=1, keepdims=True)
        text_features = text_features / tf.norm(text_features, axis=1, keepdims=True)

        # cosine similarity as logits; shape = (batch_size, batch_size)
        logit_scale = tf.math.exp(self.logit_scale)
        logits_per_image = tf.matmul(logit_scale * image_features, tf.transpose(text_features))
        logits_per_text = tf.transpose(logits_per_image)

        return logits_per_image, logits_per_text
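

if __name__ == '__main__':
    # Smoke test (an assumption, not part of the original file): a deliberately tiny
    # ViT-backed CLIP; the hyperparameters below are illustrative, not the released ones.
    model = CLIP(
        embed_dim=64,
        image_resolution=32,
        vision_layers=2,
        vision_width=128,
        vision_patch_size=8,
        context_length=16,
        vocab_size=1000,
        transformer_width=64,
        transformer_heads=2,
        transformer_layers=2
    )
    images = tf.random.normal([2, 32, 32, 3])
    texts = tf.random.uniform([2, 16], maxval=1000, dtype=tf.int32)
    logits_per_image, logits_per_text = model(images, texts)
    print(logits_per_image.shape, logits_per_text.shape)  # (2, 2) (2, 2)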