# Source snapshot: commit f53b39e, file size 10,831 bytes.
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# CroCo model during pretraining
# --------------------------------------------------------
import torch
import torch.nn as nn
torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
from functools import partial
from models.blocks import Block, DecoderBlock, PatchEmbed
from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D
from models.masking import RandomMask
class CroCoNet(nn.Module):
    """CroCo model used during pretraining.

    ``forward`` encodes a randomly masked first image and an unmasked second
    image with the same encoder, runs the decoder on the first image's tokens
    (visible tokens completed with a learned mask token) together with the
    second image's tokens, and projects every decoded token to a flattened
    patch of ``patch_size**2 * 3`` values.
    """

    def __init__(self,
                 img_size=224,           # input image size
                 patch_size=16,          # patch_size
                 mask_ratio=0.9,         # ratios of masked tokens
                 enc_embed_dim=768,      # encoder feature dimension
                 enc_depth=12,           # encoder depth
                 enc_num_heads=12,       # encoder number of heads in the transformer block
                 dec_embed_dim=512,      # decoder feature dimension
                 dec_depth=8,            # decoder depth
                 dec_num_heads=16,       # decoder number of heads in the transformer block
                 mlp_ratio=4,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 norm_im2_in_dec=True,   # whether to apply normalization of the 'memory' = (second image) in the decoder
                 pos_embed='cosine',     # positional embedding (either cosine or RoPE100)
                 ):
        """Build the encoder, decoder, mask token and prediction head.

        Args:
            img_size: input image size, forwarded to the patch embedding.
            patch_size: side length of a patch.
            mask_ratio: ratio of masked tokens, forwarded to the mask generator.
            enc_embed_dim: encoder feature dimension.
            enc_depth: number of encoder blocks.
            enc_num_heads: number of heads in each encoder block.
            dec_embed_dim: decoder feature dimension.
            dec_depth: number of decoder blocks.
            dec_num_heads: number of heads in each decoder block.
            mlp_ratio: forwarded to every transformer block.
            norm_layer: factory called with a feature dimension to build a
                normalization layer.
            norm_im2_in_dec: forwarded to the decoder blocks as ``norm_mem``
                (normalization of the second image, the 'memory').
            pos_embed: ``'cosine'`` or ``'RoPE<freq>'`` (eg ``'RoPE100'``).

        Raises:
            ImportError: if a RoPE embedding is requested but ``RoPE2D`` is None.
            NotImplementedError: if ``pos_embed`` is neither form above.
        """
        super().__init__()

        self.enc_depth = enc_depth
        self.enc_embed_dim = enc_embed_dim
        self.dec_depth = dec_depth
        self.dec_embed_dim = dec_embed_dim

        # patch embeddings (with initialization done as in MAE)
        self._set_patch_embed(img_size, patch_size, enc_embed_dim)

        # mask generations
        self._set_mask_generator(self.patch_embed.num_patches, mask_ratio)

        self.pos_embed = pos_embed
        if pos_embed == 'cosine':
            # positional embedding of the encoder
            enc_pos_embed = get_2d_sincos_pos_embed(enc_embed_dim, self.patch_embed.grid_size, n_cls_token=0)
            self.register_buffer('enc_pos_embed', torch.from_numpy(enc_pos_embed).float())
            # positional embedding of the decoder
            dec_pos_embed = get_2d_sincos_pos_embed(dec_embed_dim, self.patch_embed.grid_size, n_cls_token=0)
            self.register_buffer('dec_pos_embed', torch.from_numpy(dec_pos_embed).float())
            # pos embedding in each block
            self.rope = None  # nothing for cosine
        elif pos_embed.startswith('RoPE'):  # eg RoPE100
            self.enc_pos_embed = None  # nothing to add in the encoder with RoPE
            self.dec_pos_embed = None  # nothing to add in the decoder with RoPE
            if RoPE2D is None:
                raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
            # everything after the 'RoPE' prefix is the frequency, eg 'RoPE100' -> 100.0
            freq = float(pos_embed[len('RoPE'):])
            self.rope = RoPE2D(freq=freq)
        else:
            raise NotImplementedError('Unknown pos_embed '+pos_embed)

        # transformer for the encoder
        self.enc_blocks = nn.ModuleList([
            Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope)
            for _ in range(enc_depth)])
        self.enc_norm = norm_layer(enc_embed_dim)

        # extra stack of Blocks at decoder width: dec_depth//2 - 2 of them
        # (empty when that count is <= 0). No method of this class uses it.
        # NOTE(review): presumably consumed by a subclass or downstream head - confirm.
        self.dec_blocks_pc = nn.ModuleList([
            Block(dec_embed_dim, dec_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope)
            for _ in range(dec_depth//2-2)])

        # masked tokens
        self._set_mask_token(dec_embed_dim)

        # decoder
        self._set_decoder(enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec)

        # prediction head
        self._set_prediction_head(dec_embed_dim, patch_size)

        # initializer weights
        self.initialize_weights()

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        """Create ``self.patch_embed`` for 3-channel images."""
        self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim)

    def _set_mask_generator(self, num_patches, mask_ratio):
        """Create ``self.mask_generator`` as a ``RandomMask``."""
        self.mask_generator = RandomMask(num_patches, mask_ratio)

    def _set_mask_token(self, dec_embed_dim):
        """Create the learnable mask token of shape (1, 1, dec_embed_dim)."""
        self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim))

    def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
        """Create the encoder-to-decoder projection, decoder blocks and final norm."""
        self.dec_depth = dec_depth
        self.dec_embed_dim = dec_embed_dim
        # transfer from encoder to decoder
        self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
        # transformer for the decoder
        self.dec_blocks = nn.ModuleList([
            DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
            for _ in range(dec_depth)])
        # final norm layer
        self.dec_norm = norm_layer(dec_embed_dim)

    def _set_prediction_head(self, dec_embed_dim, patch_size):
        """Create the linear head mapping a token to ``patch_size**2 * 3`` values."""
        self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True)

    def initialize_weights(self):
        """Initialize the patch embedding, the mask token, then all submodules."""
        # patch embed
        self.patch_embed._init_weights()
        # mask tokens
        if self.mask_token is not None:
            torch.nn.init.normal_(self.mask_token, std=.02)
        # linears and layer norms
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initializer passed to ``nn.Module.apply``.

        Linear layers get Xavier-uniform weights and zero bias; LayerNorm gets
        unit weight and zero bias. Parameters that are absent (``None``) are
        skipped. Other module types are left untouched.
        """
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # a LayerNorm built without affine parameters has nothing to reset
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def _encode_image(self, image, do_mask=False, return_all_blocks=False):
        """Patch-embed an image, optionally mask it, and run the encoder.

        Args:
            image: tensor of size B x 3 x img_size x img_size.
            do_mask: whether to perform masking or not.
            return_all_blocks: if True, return the features at the end of every
                block instead of just the features from the last block
                (eg for some prediction heads).

        Returns:
            A tuple ``(features, pos, masks)``. ``features`` is the normalized
            output of the last block, or a list with one entry per block (only
            the last entry normalized) when ``return_all_blocks`` is True.
            ``pos`` holds the positions of all patches, including masked ones.
            ``masks`` is B x Npatches, True where a token was masked out.
        """
        # embed the image into patches (x has size B x Npatches x C)
        # and get the position of each returned patch (pos has size B x Npatches x 2)
        x, pos = self.patch_embed(image)
        # add positional embedding without cls token
        if self.enc_pos_embed is not None:
            x = x + self.enc_pos_embed[None, ...]
        # apply masking
        B, N, C = x.size()
        if do_mask:
            masks = self.mask_generator(x)
            # keep visible tokens only; the view assumes every sample keeps
            # the same number of tokens
            x = x[~masks].view(B, -1, C)
            posvis = pos[~masks].view(B, -1, 2)
        else:
            # nothing is masked; allocate on the tokens' device so the mask can
            # be combined with them without a device mismatch
            masks = torch.zeros((B, N), dtype=torch.bool, device=x.device)
            posvis = pos
        # now apply the transformer encoder and normalization
        if return_all_blocks:
            out = []
            for blk in self.enc_blocks:
                x = blk(x, posvis)
                out.append(x)
            out[-1] = self.enc_norm(out[-1])
            return out, pos, masks
        for blk in self.enc_blocks:
            x = blk(x, posvis)
        x = self.enc_norm(x)
        return x, pos, masks

    def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False):
        """Decode the first image's tokens using the second image's tokens.

        Args:
            feat1: encoder features of the (visible) tokens of the first image.
            pos1: patch positions of the first image, passed to every block.
            masks1: B x Ntotal boolean mask of the first image, True where
                masked. Can be None => assume image1 fully visible.
            feat2: encoder features of the second image.
            pos2: patch positions of the second image, passed to every block.
            return_all_blocks: if True, return the features at the end of every
                block instead of just the features from the last block
                (eg for some prediction heads).

        Returns:
            The normalized output of the last decoder block, or a list with one
            entry per block (only the last entry normalized).
        """
        # encoder to decoder layer
        visf1 = self.decoder_embed(feat1)
        f2 = self.decoder_embed(feat2)
        # append masked tokens to the sequence
        B, Nenc, C = visf1.size()
        if masks1 is None:  # downstreams
            f1_ = visf1
        else:  # pretraining
            Ntotal = masks1.size(1)
            # start from mask tokens everywhere, then write the visible tokens
            # back into their unmasked slots
            f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype)
            f1_[~masks1] = visf1.view(B * Nenc, C)
        # add positional embedding
        if self.dec_pos_embed is not None:
            f1_ = f1_ + self.dec_pos_embed
            f2 = f2 + self.dec_pos_embed
        # apply Transformer blocks
        out = f1_
        out2 = f2
        if return_all_blocks:
            _out, out = out, []
            for blk in self.dec_blocks:
                _out, out2 = blk(_out, out2, pos1, pos2)
                out.append(_out)
            out[-1] = self.dec_norm(out[-1])
        else:
            for blk in self.dec_blocks:
                out, out2 = blk(out, out2, pos1, pos2)
            out = self.dec_norm(out)
        return out

    def patchify(self, imgs):
        """Split square images into flattened non-overlapping patches.

        imgs: (B, C, H, W) with H == W and H divisible by the patch size
        x: (B, L, patch_size**2 * C), values of a patch ordered (row, col, channel)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        # take the channel count from the input, mirroring unpatchify(channels=...)
        n, c = imgs.shape[0], imgs.shape[1]
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(n, c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(n, h * w, p**2 * c))
        return x

    def unpatchify(self, x, channels=3):
        """Reassemble flattened patches into square images (inverse of patchify).

        x: (N, L, patch_size**2 * channels), L must be a perfect square
        imgs: (N, channels, H, W)
        """
        patch_size = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, w * patch_size))
        return imgs

    def forward(self, img1, img2):
        """Run one pretraining pass.

        img1: tensor of size B x 3 x img_size x img_size
        img2: tensor of size B x 3 x img_size x img_size

        Returns ``(out, mask1, target)``:
        out will be B x N x (3*patch_size*patch_size),
        masks are also returned as B x N just in case,
        target is the patchified first image.
        """
        # encoder of the masked first image
        feat1, pos1, mask1 = self._encode_image(img1, do_mask=True)
        # encoder of the second image
        feat2, pos2, _ = self._encode_image(img2, do_mask=False)
        # decoder
        decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2)
        # prediction head
        out = self.prediction_head(decfeat)
        # get target
        target = self.patchify(img1)
        return out, mask1, target
|