# CLIP-Keras / CLIP.py
# Author: NoteDance
# Commit c86eba0 (verified): "Update CLIP.py"
from typing import Tuple, Union

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, BatchNormalization, LayerNormalization, MultiHeadAttention
from tensorflow.keras.layers import ZeroPadding2D, AveragePooling2D, Identity
# Keras names the layer `Conv2D`; `tensorflow.keras.layers` has no `Conv2d`, so
# importing that name raised ImportError. The alias keeps the `Conv2d` name used below.
from tensorflow.keras.layers import Conv2D as Conv2d
class Bottleneck(tf.keras.layers.Layer):
    """ResNet bottleneck block with anti-aliased (average-pool) downsampling, as in CLIP.

    All convolutions have stride 1. When ``stride > 1`` an average pool is applied
    after the 3x3 convolution, and the shortcut branch is average-pooled before its
    1x1 projection. Assumes Keras' default channels-last layout.

    Args:
        inplanes: number of input channels.
        planes: bottleneck width; the block outputs ``planes * Bottleneck.expansion``
            channels.
        stride: spatial downsampling factor.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, **kwargs):
        super(Bottleneck, self).__init__(**kwargs)
        out_channels = planes * self.expansion

        self.conv1 = Conv2d(planes, 1, use_bias=False)
        self.bn1 = BatchNormalization()
        self.relu1 = tf.nn.relu

        # Explicit 1-pixel zero padding + 'valid' 3x3 conv reproduces PyTorch's padding=1.
        self.zeropadding2d = ZeroPadding2D(padding=1)
        self.conv2 = Conv2d(planes, 3, use_bias=False)
        self.bn2 = BatchNormalization()
        self.relu2 = tf.nn.relu

        self.avgpool = AveragePooling2D(stride, stride, 'VALID') if stride > 1 else Identity()

        self.conv3 = Conv2d(out_channels, 1, use_bias=False)
        self.bn3 = BatchNormalization()
        self.relu3 = tf.nn.relu

        self.downsample = None
        self.stride = stride
        if stride > 1 or inplanes != out_channels:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = tf.keras.Sequential([
                AveragePooling2D(stride, stride, 'VALID'),
                Conv2d(out_channels, 1, strides=1, use_bias=False),
                BatchNormalization(),
            ])

    def call(self, x, training=None):
        """Apply the block to ``x``.

        Implemented as ``call`` (not ``__call__``) so Keras' layer machinery runs and
        the ``training`` flag reaches the BatchNormalization layers.
        """
        identity = x

        out = self.relu1(self.bn1(self.conv1(x), training=training))
        out = self.zeropadding2d(out)
        out = self.relu2(self.bn2(self.conv2(out), training=training))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out), training=training)

        if self.downsample is not None:
            identity = self.downsample(x, training=training)

        out = out + identity
        return self.relu3(out)
class AttentionPool2d(tf.keras.layers.Layer):
    """QKV attention-pooling head of CLIP's ModifiedResNet.

    The spatial grid is flattened into ``H * W`` tokens and their mean is prepended
    as an extra token. Learned positional embeddings are added, and the mean token
    (the only query) attends over all tokens with multi-head attention. The attended
    mean token is projected by ``c_proj`` and returned.

    Args:
        spacial_dim: side length of the square input grid. The positional embedding
            has ``spacial_dim ** 2 + 1`` rows.
        embed_dim: channel count of the input feature map.
        num_heads: number of attention heads; must divide ``embed_dim``.
        output_dim: output width (defaults to ``embed_dim``).

    Input: ``(N, H, W, embed_dim)`` with ``H * W == spacial_dim ** 2``.
    Output: ``(N, output_dim)``.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, **kwargs):
        super().__init__(**kwargs)
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
        self.spacial_dim = spacial_dim
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=[spacial_dim ** 2 + 1, embed_dim],
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1. / embed_dim ** 0.5),
            trainable=True,
        )
        self.k_proj = Dense(embed_dim)
        self.q_proj = Dense(embed_dim)
        self.v_proj = Dense(embed_dim)
        self.c_proj = Dense(output_dim or embed_dim)

    def _split_heads(self, t, batch_size):
        """Reshape ``(N, L, embed_dim)`` into ``(N, num_heads, L, head_dim)``."""
        t = tf.reshape(t, [batch_size, -1, self.num_heads, self.head_dim])
        return tf.transpose(t, [0, 2, 1, 3])

    def call(self, x):
        # Dynamic batch size so the layer also works when the batch dim is unknown.
        batch_size = tf.shape(x)[0]

        # Batch-first token layout: (N, H, W, C) -> (N, HW, C). C must equal
        # embed_dim for the positional embedding to line up.
        x = tf.reshape(x, [batch_size, -1, self.embed_dim])
        x = tf.concat([tf.reduce_mean(x, axis=1, keepdims=True), x], axis=1)  # (N, HW+1, C)
        x = x + tf.cast(self.positional_embedding, x.dtype)[tf.newaxis]

        # Only the prepended mean token is used as a query.
        query = self._split_heads(self.q_proj(x[:, :1]), batch_size)  # (N, heads, 1, d)
        key = self._split_heads(self.k_proj(x), batch_size)  # (N, heads, HW+1, d)
        value = self._split_heads(self.v_proj(x), batch_size)  # (N, heads, HW+1, d)

        # Scaled dot-product attention: scale by 1/sqrt(head_dim), not 1/sqrt(embed_dim).
        query = query * tf.cast(self.head_dim ** -0.5, query.dtype)
        weights = tf.nn.softmax(tf.matmul(query, key, transpose_b=True), axis=-1)  # (N, heads, 1, HW+1)
        attended = tf.matmul(weights, value)  # (N, heads, 1, d)

        # Merge heads back into one embed_dim-wide vector per example.
        attended = tf.reshape(tf.transpose(attended, [0, 2, 1, 3]), [batch_size, self.embed_dim])
        return self.c_proj(attended)
class ModifiedResNet(tf.keras.layers.Layer):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool

    Implemented as a Keras ``Layer`` so that an enclosing model tracks the weights
    of its sub-layers. As a plain Python object, those weights were invisible to it.

    Args:
        layers: number of Bottleneck blocks in each of the four residual stages.
        output_dim: width of the attention-pooled output.
        heads: number of heads in the attention-pooling head.
        input_resolution: input image side length. The attention pool is built for
            an ``input_resolution // 32`` grid.
        width: channel count after the stem. The four stages use ``width``,
            ``2 * width``, ``4 * width`` and ``8 * width`` bottleneck planes.
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem; one stateless zero-padding layer is reused before each
        # 3x3 'valid' conv to reproduce PyTorch's padding=1
        self.zeropadding2d = ZeroPadding2D(padding=1)
        self.conv1 = Conv2d(width // 2, kernel_size=3, strides=2, use_bias=False)
        self.bn1 = BatchNormalization()
        self.relu1 = tf.nn.relu
        self.conv2 = Conv2d(width // 2, kernel_size=3, use_bias=False)
        self.bn2 = BatchNormalization()
        self.relu2 = tf.nn.relu
        self.conv3 = Conv2d(width, kernel_size=3, use_bias=False)
        self.bn3 = BatchNormalization()
        self.relu3 = tf.nn.relu
        self.avgpool = AveragePooling2D(2, 2, 'VALID')

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` Bottlenecks; only the first one uses ``stride``."""
        stage = tf.keras.Sequential()
        stage.add(Bottleneck(self._inplanes, planes, stride))
        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            stage.add(Bottleneck(self._inplanes, planes))
        return stage

    def _stem(self, x, training=None):
        """Three padded 3x3 conv-BN-ReLU layers (the first with stride 2), then a 2x2 average pool."""
        for conv, bn, relu in ((self.conv1, self.bn1, self.relu1),
                               (self.conv2, self.bn2, self.relu2),
                               (self.conv3, self.bn3, self.relu3)):
            x = self.zeropadding2d(x)
            x = relu(bn(conv(x), training=training))
        return self.avgpool(x)

    def call(self, x, training=None):
        x = self._stem(x, training=training)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x, training=training)
        return self.attnpool(x)
class LayerNorm(tf.keras.layers.Layer):
    """Layer normalization computed in float32, with the result cast back to the input dtype.

    Keras counterpart of the fp16-safe LayerNorm used by the PyTorch CLIP reference.
    It is a Keras ``Layer`` so that an enclosing layer or model tracks the wrapped
    LayerNormalization's weights.

    Args:
        input_size: normalized width. It is stored for signature compatibility only;
            Keras infers the width from the input.
        epsilon: variance epsilon. It defaults to 1e-5 (PyTorch ``nn.LayerNorm``'s
            default) rather than Keras' 1e-3.
    """

    def __init__(self, input_size, epsilon=1e-5, **kwargs):
        # autocast=False: `call` must see the caller's dtype (e.g. float16) so it can
        # restore it; Keras would otherwise cast inputs to the layer's compute dtype.
        kwargs.setdefault('autocast', False)
        super().__init__(**kwargs)
        self.input_size = input_size
        # dtype='float32' keeps the normalization in float32 even under a mixed-precision policy.
        self.layer_norm = LayerNormalization(epsilon=epsilon, dtype='float32')

    def call(self, x):
        orig_type = x.dtype
        ret = self.layer_norm(tf.cast(x, tf.float32))
        return tf.cast(ret, orig_type)
class QuickGELU(tf.keras.layers.Layer):
    """Sigmoid approximation of GELU used by CLIP: ``x * sigmoid(1.702 * x)``."""

    def __init__(self, **kwargs):
        super(QuickGELU, self).__init__(**kwargs)

    def call(self, x):
        # `call` (not `__call__`) so Keras' layer machinery wraps the computation.
        return x * tf.nn.sigmoid(1.702 * x)
class ResidualAttentionBlock(tf.keras.layers.Layer):
    """Pre-norm transformer block: ``x + attn(ln_1(x))``, then ``x + mlp(ln_2(x))``.

    Input and output are sequence-first ``(L, N, d_model)`` tensors, which is how the
    callers in this module invoke it (they transpose NLD -> LND first). Keras
    ``MultiHeadAttention`` is batch-first, so the block transposes around it.

    Args:
        d_model: model width.
        n_head: number of attention heads; must divide ``d_model``.
        attn_mask: optional mask broadcastable to ``(L, L)``, in one of two conventions:
            - boolean: Keras convention, True = may attend;
            - numeric: PyTorch additive convention, 0 = may attend and any other value
              (e.g. -inf or -1e9) = blocked.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask=None, **kwargs):
        super(ResidualAttentionBlock, self).__init__(**kwargs)
        if d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})")
        # Keras' key_dim is the per-head size; d_model // n_head matches PyTorch's
        # nn.MultiheadAttention (passing d_model made every head d_model wide).
        self.attn = MultiHeadAttention(num_heads=n_head, key_dim=d_model // n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = tf.keras.Sequential([
            Dense(d_model * 4),
            QuickGELU(),
            Dense(d_model),
        ])
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def _keras_mask(self):
        """Return ``attn_mask`` as a Keras boolean mask (True = attend), or None."""
        if self.attn_mask is None:
            return None
        mask = tf.convert_to_tensor(self.attn_mask)
        if mask.dtype != tf.bool:
            # Additive (PyTorch-style) mask: only zero entries may be attended.
            mask = tf.equal(mask, tf.zeros_like(mask))
        return mask

    def attention(self, x, training=None):
        """Self-attention on sequence-first ``x``; returns the same layout."""
        x_nld = tf.transpose(x, (1, 0, 2))  # LND -> NLD for Keras MHA
        # Keras MHA returns a single tensor (not an (output, weights) tuple), so no [0].
        out = self.attn(x_nld, x_nld, attention_mask=self._keras_mask(), training=training)
        return tf.transpose(out, (1, 0, 2))  # NLD -> LND

    def call(self, x, training=None):
        x = x + self.attention(self.ln_1(x), training=training)
        x = x + self.mlp(self.ln_2(x), training=training)
        return x
class Transformer(tf.keras.layers.Layer):
    """Stack of ``layers`` ResidualAttentionBlocks of width ``width`` with ``heads`` heads.

    A Keras ``Layer`` (rather than a plain object) so an enclosing model tracks the
    blocks' weights. ``attn_mask`` is passed unchanged to every block.
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask=None, **kwargs):
        super().__init__(**kwargs)
        self.width = width
        self.layers = layers
        self.resblocks = tf.keras.Sequential(
            [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def call(self, x, training=None):
        return self.resblocks(x, training=training)
class VisionTransformer(tf.keras.layers.Layer):
    """CLIP's ViT image encoder (channels-last Keras port).

    The steps are:
    1. Split the image into non-overlapping ``patch_size`` patches with a strided conv.
    2. Prepend a learned class token and add learned positional embeddings.
    3. Run a Transformer over the tokens.
    4. Layer-normalize the class token and project it to ``output_dim``.

    Args:
        input_resolution: input image side length. The positional embedding is built
            for ``(input_resolution // patch_size) ** 2 + 1`` tokens.
        patch_size: patch side length (also the conv kernel size and stride).
        width: transformer width.
        layers: number of transformer blocks.
        heads: number of attention heads.
        output_dim: size of the returned embedding.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, **kwargs):
        super().__init__(**kwargs)
        self.input_resolution = input_resolution
        self.patch_size = patch_size
        self.width = width
        self.output_dim = output_dim
        self.conv1 = Conv2d(width, kernel_size=patch_size, strides=patch_size, use_bias=False)

        scale = width ** -0.5
        self.scale = scale

        def scaled_normal():
            # A fresh initializer per weight: Keras may return identical values when one
            # unseeded initializer instance is reused. stddev=scale mirrors `scale * randn`.
            return tf.keras.initializers.RandomNormal(mean=0., stddev=scale)

        self.class_embedding = self.add_weight(
            name='class_embedding',
            shape=[width],
            initializer=scaled_normal(),
            trainable=True,
        )
        grid = input_resolution // patch_size
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=[grid ** 2 + 1, width],
            initializer=scaled_normal(),
            trainable=True,
        )
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)
        self.proj = self.add_weight(
            name='proj',
            shape=[width, output_dim],
            initializer=scaled_normal(),
            trainable=True,
        )

    def call(self, x, train_flag=True):
        """Encode images ``x`` (channels-last) into ``(N, output_dim)`` embeddings.

        ``train_flag`` is unused; it is kept for compatibility with the original signature.
        """
        x = self.conv1(x)  # shape = [*, grid, grid, width] (channels-last)
        batch_size = tf.shape(x)[0]
        # Channels-last output is already patch-major: no transpose needed.
        x = tf.reshape(x, [batch_size, -1, self.width])  # shape = [*, grid ** 2, width]
        cls_token = tf.cast(self.class_embedding, x.dtype) + tf.zeros([batch_size, 1, self.width], dtype=x.dtype)
        x = tf.concat([cls_token, x], axis=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + tf.cast(self.positional_embedding, x.dtype)
        x = self.ln_pre(x)
        x = tf.transpose(x, (1, 0, 2))  # NLD -> LND
        x = self.transformer(x)
        x = tf.transpose(x, (1, 0, 2))  # LND -> NLD
        x = self.ln_post(x[:, 0, :])
        if self.proj is not None:
            x = tf.matmul(x, self.proj)
        return x
class CLIP(Model):
    """Contrastive image-text model: an image encoder, a causal text transformer and a learned logit scale.

    ``vision_layers`` selects the image encoder. A tuple or list of four stage depths
    builds a ModifiedResNet; an int builds a VisionTransformer with that many layers.
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super(CLIP, self).__init__()
        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = self.add_weight(
            name='token_embedding',
            shape=(vocab_size, transformer_width),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
            trainable=True
        )
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=(self.context_length, transformer_width),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
            trainable=True
        )
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = self.add_weight(
            name='text_projection',
            shape=(transformer_width, embed_dim),
            initializer=tf.keras.initializers.RandomNormal(stddev=transformer_width ** -0.5),
            trainable=True
        )
        self.logit_scale = self.add_weight(
            name='logit_scale',
            shape=[],
            initializer=tf.keras.initializers.Constant(np.log(1 / 0.07)),
            trainable=True
        )

    def build_attention_mask(self):
        """Return a causal ``(context_length, context_length)`` boolean mask.

        Entry ``[i, j]`` is True when query position ``i`` may attend to key
        position ``j``, i.e. ``j <= i``. This is the Keras ``MultiHeadAttention``
        convention (True/1 = attend). The diagonal stays visible, matching
        PyTorch's ``triu_(1)`` additive mask.
        """
        positions = tf.range(self.context_length)
        return positions[:, tf.newaxis] >= positions[tf.newaxis, :]

    def encode_image(self, image):
        return self.visual(image)

    def encode_text(self, text):
        """Embed integer token ids ``text`` of shape ``[batch_size, context_length]``."""
        x = tf.gather(self.token_embedding, text)  # [batch_size, n_ctx, d_model]
        x = x + tf.cast(self.positional_embedding, x.dtype)
        x = tf.transpose(x, (1, 0, 2))  # NLD -> LND
        x = self.transformer(x)
        x = tf.transpose(x, (1, 0, 2))  # LND -> NLD
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        eot_indices = tf.argmax(text, axis=-1, output_type=tf.int32)
        x = tf.gather(x, eot_indices, batch_dims=1)  # [batch_size, transformer.width]
        return tf.matmul(x, self.text_projection)

    def call(self, image, text=None):
        """Return ``(logits_per_image, logits_per_text)`` for a batch of images and texts.

        May be called as ``model(image, text)`` or with a single ``(image, text)`` pair
        (the form Keras uses when the model's input ``x`` is a tuple).
        """
        if text is None:
            image, text = image

        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / tf.norm(image_features, axis=1, keepdims=True)
        text_features = text_features / tf.norm(text_features, axis=1, keepdims=True)

        # cosine similarity as logits
        logit_scale = tf.math.exp(self.logit_scale)
        logits_per_image = tf.matmul(logit_scale * image_features, text_features, transpose_b=True)
        logits_per_text = tf.transpose(logits_per_image)

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text