File size: 14,507 Bytes
5646c73 c86eba0 5646c73 c86eba0 5646c73 c86eba0 5646c73 c86eba0 5646c73 c86eba0 5646c73 c86eba0 5646c73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 |
from typing import Tuple, Union

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import (
    AveragePooling2D,
    BatchNormalization,
    Conv2D,
    Conv2D as Conv2d,
    Dense,
    Identity,
    LayerNormalization,
    MultiHeadAttention,
    ZeroPadding2D,
)
class Bottleneck(tf.keras.layers.Layer):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1 convolutions) with anti-aliased striding.

    Channels-last (NHWC) port of CLIP's ``Bottleneck``. Every convolution has
    stride 1; spatial downsampling is done by an average pool after the 3x3
    convolution on the main path, and by an average pool in front of the 1x1
    projection on the shortcut path.

    Args:
        inplanes: number of channels of the block input.
        planes: width of the inner convolutions; the output has
            ``planes * expansion`` channels.
        stride: downsampling factor applied by the average pools (1 = none).
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        super(Bottleneck, self).__init__()
        out_planes = planes * self.expansion
        # epsilon=1e-5 matches torch.nn.BatchNorm2d (the Keras default is 1e-3).
        self.conv1 = Conv2D(planes, 1, use_bias=False)
        self.bn1 = BatchNormalization(epsilon=1e-5)
        self.relu1 = tf.nn.relu
        # explicit 1-pixel zero padding reproduces torch's padding=1 for the 3x3 conv
        self.zeropadding2d = ZeroPadding2D(padding=1)
        self.conv2 = Conv2D(planes, 3, use_bias=False)
        self.bn2 = BatchNormalization(epsilon=1e-5)
        self.relu2 = tf.nn.relu
        self.avgpool = AveragePooling2D(stride, stride, 'VALID') if stride > 1 else Identity()
        self.conv3 = Conv2D(out_planes, 1, use_bias=False)
        self.bn3 = BatchNormalization(epsilon=1e-5)
        self.relu3 = tf.nn.relu
        self.downsample = None
        self.stride = stride
        if stride > 1 or inplanes != out_planes:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = tf.keras.Sequential([
                AveragePooling2D(stride, stride, 'VALID'),
                Conv2D(out_planes, 1, strides=1, use_bias=False),
                BatchNormalization(epsilon=1e-5),
            ])

    def call(self, x, training=None):
        # Implemented as `call` (not `__call__`) so Keras handles building,
        # name scoping and forwarding of `training` to the BatchNorm layers.
        identity = x
        out = self.relu1(self.bn1(self.conv1(x), training=training))
        out = self.zeropadding2d(out)
        out = self.relu2(self.bn2(self.conv2(out), training=training))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out), training=training)
        if self.downsample is not None:
            identity = self.downsample(x, training=training)
        return self.relu3(out + identity)
class AttentionPool2d(tf.keras.layers.Layer):
    """QKV attention pooling of a spatial feature map (port of CLIP's AttentionPool2d).

    The ``H * W`` spatial positions plus their mean are treated as a token
    sequence. The mean token is the only query of a multi-head attention over
    all tokens, and the projected attention output is returned.

    Args:
        spacial_dim: side length of the (square) feature map; the positional
            embedding has ``spacial_dim ** 2 + 1`` rows.
        embed_dim: channel count of the input feature map and width of the
            q/k/v projections.
        num_heads: number of attention heads; must divide ``embed_dim``.
        output_dim: width of the output projection (defaults to ``embed_dim``).

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super(AttentionPool2d, self).__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
        self.spacial_dim = spacial_dim
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=[spacial_dim ** 2 + 1, embed_dim],
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1. / embed_dim ** 0.5),
            trainable=True,
        )
        self.k_proj = Dense(embed_dim)
        self.q_proj = Dense(embed_dim)
        self.v_proj = Dense(embed_dim)
        self.c_proj = Dense(output_dim or embed_dim)

    def _split_heads(self, t, batch_size):
        """Reshape (N, len, embed_dim) -> (N, num_heads, len, head_dim)."""
        t = tf.reshape(t, [batch_size, -1, self.num_heads, self.head_dim])
        return tf.transpose(t, [0, 2, 1, 3])

    def call(self, x):
        # x: channels-last feature map (N, H, W, C); C must equal embed_dim.
        # The batch size is read dynamically so the layer also works in graph mode.
        batch_size = tf.shape(x)[0]
        x = tf.reshape(x, [batch_size, -1, self.embed_dim])  # (N, HW, C)
        x = tf.concat([tf.reduce_mean(x, axis=1, keepdims=True), x], axis=1)  # (N, HW+1, C)
        x = x + tf.cast(self.positional_embedding, x.dtype)[None, :, :]

        # Only the mean token queries. Scale by 1/sqrt(head_dim), as torch's
        # multi_head_attention_forward does.
        query = self._split_heads(self.q_proj(x[:, :1]), batch_size) * (self.head_dim ** -0.5)
        key = self._split_heads(self.k_proj(x), batch_size)
        value = self._split_heads(self.v_proj(x), batch_size)

        weights = tf.nn.softmax(tf.matmul(query, key, transpose_b=True), axis=-1)  # (N, heads, 1, HW+1)
        out = tf.matmul(weights, value)  # (N, heads, 1, head_dim)
        out = tf.reshape(tf.transpose(out, [0, 2, 1, 3]), [batch_size, self.embed_dim])
        return self.c_proj(out)  # (N, output_dim or embed_dim)
class ModifiedResNet(tf.keras.layers.Layer):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool

    This port operates on channels-last (NHWC) images. It is a Keras layer so
    that the weights of all its sub-layers are tracked by the enclosing model
    (a plain Python object would hide them from `trainable_variables`).

    Args:
        layers: number of Bottleneck blocks in each of the four stages.
        output_dim: width of the attention-pool output projection.
        heads: number of heads of the attention pool.
        input_resolution: input image side length; the final feature map is
            ``input_resolution // 32`` on a side.
        width: channel width of the stem output.
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super(ModifiedResNet, self).__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem; ZeroPadding2D is stateless, so one instance is
        # reused to emulate torch's padding=1 before each 3x3 conv.
        # epsilon=1e-5 matches torch.nn.BatchNorm2d (the Keras default is 1e-3).
        self.zeropadding2d = ZeroPadding2D(padding=1)
        self.conv1 = Conv2D(width // 2, kernel_size=3, strides=2, use_bias=False)
        self.bn1 = BatchNormalization(epsilon=1e-5)
        self.relu1 = tf.nn.relu
        self.conv2 = Conv2D(width // 2, kernel_size=3, use_bias=False)
        self.bn2 = BatchNormalization(epsilon=1e-5)
        self.relu2 = tf.nn.relu
        self.conv3 = Conv2D(width, kernel_size=3, use_bias=False)
        self.bn3 = BatchNormalization(epsilon=1e-5)
        self.relu3 = tf.nn.relu
        self.avgpool = AveragePooling2D(2, 2, 'VALID')

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage of `blocks` Bottlenecks; only the first may downsample."""
        layers = tf.keras.Sequential()
        layers.add(Bottleneck(self._inplanes, planes, stride))
        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.add(Bottleneck(self._inplanes, planes))
        return layers

    def call(self, x, training=None):
        def stem(x):
            x = self.zeropadding2d(x)
            x = self.relu1(self.bn1(self.conv1(x), training=training))
            x = self.zeropadding2d(x)
            x = self.relu2(self.bn2(self.conv2(x), training=training))
            x = self.zeropadding2d(x)
            x = self.relu3(self.bn3(self.conv3(x), training=training))
            return self.avgpool(x)

        x = stem(x)
        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)
        return self.attnpool(x)
class LayerNorm(tf.keras.layers.Layer):
    """Layer normalization computed in float32 and cast back to the input dtype.

    Mirrors CLIP's fp16-safe LayerNorm: the input is cast to float32 before
    normalization and the result is cast back to the caller's dtype.

    Args:
        input_size: feature size of the normalized axis. Kept for interface
            compatibility; Keras infers it from the input when the layer is built.
        epsilon: variance epsilon. Defaults to 1e-5 to match torch.nn.LayerNorm
            (the Keras default is 1e-3).
    """

    def __init__(self, input_size, epsilon=1e-5):
        # autocast=False keeps the caller's dtype visible in call(), so the
        # result can be cast back to it. Otherwise Keras would cast float
        # inputs to the layer's compute dtype before call() runs.
        # Subclassing Layer also makes the wrapped weights trackable.
        super(LayerNorm, self).__init__(autocast=False)
        self.layer_norm = LayerNormalization(epsilon=epsilon)

    def call(self, x):
        orig_type = x.dtype
        ret = self.layer_norm(tf.cast(x, tf.float32))
        return tf.cast(ret, orig_type)
class QuickGELU(tf.keras.layers.Layer):
    """Sigmoid approximation of GELU used by CLIP: ``x * sigmoid(1.702 * x)``."""

    def __init__(self):
        super(QuickGELU, self).__init__()

    def call(self, x):
        # `call` (not `__call__`) so Keras' input conversion/autocasting and
        # name scoping apply when this runs inside a Sequential.
        return x * tf.nn.sigmoid(1.702 * x)
class ResidualAttentionBlock(tf.keras.layers.Layer):
    """Pre-LN transformer block: x + attn(ln_1(x)), then x + mlp(ln_2(x)).

    Inputs are batch-first ``(batch, seq, d_model)``. Keras MultiHeadAttention
    attends over axis 1, unlike torch's seq-first nn.MultiheadAttention.

    Args:
        d_model: model width; must be divisible by ``n_head``.
        n_head: number of attention heads.
        attn_mask: optional ``(seq, seq)`` mask. A boolean mask is used as-is
            (True = may attend, the Keras convention). Any other dtype is treated
            as a torch-style additive mask: entries >= 0 may attend, and
            negative entries (e.g. -1e9 / -inf) are blocked.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask=None):
        super(ResidualAttentionBlock, self).__init__()
        if d_model % n_head != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_head ({n_head})")
        # Keras' key_dim is the per-head size; torch's nn.MultiheadAttention(d_model, n_head)
        # splits d_model across the heads.
        self.attn = MultiHeadAttention(num_heads=n_head, key_dim=d_model // n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = tf.keras.Sequential()
        self.mlp.add(Dense(d_model * 4))
        self.mlp.add(QuickGELU())
        self.mlp.add(Dense(d_model))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x):
        mask = None
        if self.attn_mask is not None:
            mask = tf.convert_to_tensor(self.attn_mask)
            if mask.dtype != tf.bool:
                # Convert a torch-style additive mask to Keras' boolean "may attend" mask.
                mask = mask >= 0
        # Keras MHA returns only the output tensor (no attention weights), so
        # no tuple indexing here.
        return self.attn(x, x, attention_mask=mask)

    def call(self, x):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
class Transformer(tf.keras.layers.Layer):
    """Stack of ``layers`` ResidualAttentionBlocks sharing one ``attn_mask``.

    Subclasses Keras Layer so that the blocks' weights are tracked by the
    enclosing model (a plain Python object would hide them). The blocks use
    Keras MultiHeadAttention, so inputs should be batch-first
    ``(batch, seq, width)``.

    Args:
        width: model width of every block.
        layers: number of blocks.
        heads: attention heads per block.
        attn_mask: optional mask passed unchanged to every block.
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask=None):
        super(Transformer, self).__init__()
        self.width = width
        self.layers = layers
        self.resblocks = tf.keras.Sequential(
            [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        )

    def call(self, x):
        return self.resblocks(x)
class VisionTransformer(tf.keras.layers.Layer):
    """ViT image encoder (channels-last port of CLIP's VisionTransformer).

    Images are cut into ``patch_size`` patches by a strided convolution. A
    learned class token is prepended and positional embeddings are added. The
    sequence runs through the transformer, and the normalized class-token output
    is projected to ``output_dim``.

    Args:
        input_resolution: image side length (must be a multiple of patch_size).
        patch_size: side length of each square patch.
        width: transformer width.
        layers: number of transformer blocks.
        heads: attention heads per block.
        output_dim: width of the final projection.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super(VisionTransformer, self).__init__()
        self.input_resolution = input_resolution
        self.patch_size = patch_size
        self.width = width
        self.output_dim = output_dim
        self.conv1 = Conv2D(width, kernel_size=patch_size, strides=patch_size, use_bias=False)
        scale = width ** -0.5
        self.scale = scale
        # torch: scale * randn(...)  ==  Normal(0, stddev=scale)
        self.class_embedding = self.add_weight(
            name='class_embedding',
            shape=[width],
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=scale),
            trainable=True,
        )
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=[(input_resolution // patch_size) ** 2 + 1, width],
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=scale),
            trainable=True,
        )
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)
        self.proj = self.add_weight(
            name='proj',
            shape=[width, output_dim],
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=scale),
            trainable=True,
        )

    def call(self, x, train_flag=True):
        # train_flag is accepted for interface compatibility; it is unused.
        x = self.conv1(x)  # channels-last: (N, grid, grid, width)
        batch_size = tf.shape(x)[0]
        x = tf.reshape(x, [batch_size, -1, self.width])  # (N, grid ** 2, width)
        cls_token = tf.cast(self.class_embedding, x.dtype) + tf.zeros([batch_size, 1, self.width], dtype=x.dtype)
        x = tf.concat([cls_token, x], axis=1)  # (N, grid ** 2 + 1, width)
        x = x + tf.cast(self.positional_embedding, x.dtype)
        x = self.ln_pre(x)
        # Stay batch-first (NLD): Keras MultiHeadAttention attends over axis 1.
        x = self.transformer(x)
        x = self.ln_post(x[:, 0, :])
        if self.proj is not None:
            x = tf.matmul(x, self.proj)
        return x
class CLIP(Model):
    """CLIP: contrastive image/text model (TensorFlow port).

    Uses a ModifiedResNet image encoder when ``vision_layers`` is a tuple/list of
    stage depths, and a VisionTransformer when it is an int. The text encoder is
    a causally-masked transformer over token embeddings. Calling the model
    returns cosine-similarity logits scaled by ``exp(logit_scale)``.
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super(CLIP, self).__init__()
        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = self.add_weight(
            name='token_embedding',
            shape=(vocab_size, transformer_width),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
            trainable=True
        )
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=(self.context_length, transformer_width),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
            trainable=True
        )
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = self.add_weight(
            name='text_projection',
            shape=(transformer_width, embed_dim),
            initializer=tf.keras.initializers.RandomNormal(stddev=transformer_width ** -0.5),
            trainable=True
        )
        self.logit_scale = self.add_weight(
            name='logit_scale',
            shape=[],
            initializer=tf.keras.initializers.Constant(np.log(1 / 0.07)),
            trainable=True
        )

    def build_attention_mask(self):
        """Return the additive causal mask for the text transformer.

        Uses the torch convention: 0 where token i may attend to token j (j <= i,
        including itself), and -1e9 strictly above the diagonal (future tokens).
        """
        ones = tf.ones((self.context_length, self.context_length))
        strictly_upper = ones - tf.linalg.band_part(ones, -1, 0)  # 1 only where j > i
        return strictly_upper * -1e9

    def encode_image(self, image):
        return self.visual(image)

    def encode_text(self, text):
        """Encode integer token ids of shape (batch, context_length) into (batch, embed_dim)."""
        x = tf.gather(self.token_embedding, text)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        # Stay batch-first (NLD): Keras MultiHeadAttention attends over axis 1.
        x = self.transformer(x)
        x = self.ln_final(x)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        eot_index = tf.argmax(text, axis=-1, output_type=tf.int32)
        x = tf.gather(x, eot_index, batch_dims=1)  # [batch_size, transformer.width]
        return tf.matmul(x, self.text_projection)

    def call(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / tf.norm(image_features, axis=1, keepdims=True)
        text_features = text_features / tf.norm(text_features, axis=1, keepdims=True)

        # cosine similarity as logits
        logit_scale = tf.math.exp(self.logit_scale)
        logits_per_image = tf.matmul(logit_scale * image_features, tf.transpose(text_features))
        logits_per_text = tf.transpose(logits_per_image)

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text