Spaces:
Running
on
Zero
Running
on
Zero
SunderAli17
commited on
Create utils.py
Browse files- evaclip/utils.py +323 -0
evaclip/utils.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from itertools import repeat
|
2 |
+
import collections.abc
|
3 |
+
import logging
|
4 |
+
import math
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn as nn
|
9 |
+
from torchvision.ops.misc import FrozenBatchNorm2d
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
# open CLIP
|
13 |
+
def resize_clip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
|
14 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
15 |
+
old_pos_embed = state_dict.get('visual.positional_embedding', None)
|
16 |
+
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
|
17 |
+
return
|
18 |
+
grid_size = to_2tuple(model.visual.grid_size)
|
19 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
20 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
21 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
22 |
+
return
|
23 |
+
|
24 |
+
if extra_tokens:
|
25 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
26 |
+
else:
|
27 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
28 |
+
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
29 |
+
|
30 |
+
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
31 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
32 |
+
pos_emb_img = F.interpolate(
|
33 |
+
pos_emb_img,
|
34 |
+
size=grid_size,
|
35 |
+
mode=interpolation,
|
36 |
+
align_corners=True,
|
37 |
+
)
|
38 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
39 |
+
if pos_emb_tok is not None:
|
40 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
41 |
+
else:
|
42 |
+
new_pos_embed = pos_emb_img
|
43 |
+
state_dict['visual.positional_embedding'] = new_pos_embed
|
44 |
+
|
45 |
+
|
46 |
+
def resize_visual_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
|
47 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
48 |
+
old_pos_embed = state_dict.get('positional_embedding', None)
|
49 |
+
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
|
50 |
+
return
|
51 |
+
grid_size = to_2tuple(model.visual.grid_size)
|
52 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
53 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
54 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
55 |
+
return
|
56 |
+
|
57 |
+
if extra_tokens:
|
58 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
59 |
+
else:
|
60 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
61 |
+
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
62 |
+
|
63 |
+
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
64 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
65 |
+
pos_emb_img = F.interpolate(
|
66 |
+
pos_emb_img,
|
67 |
+
size=grid_size,
|
68 |
+
mode=interpolation,
|
69 |
+
align_corners=True,
|
70 |
+
)
|
71 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
72 |
+
if pos_emb_tok is not None:
|
73 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
74 |
+
else:
|
75 |
+
new_pos_embed = pos_emb_img
|
76 |
+
state_dict['positional_embedding'] = new_pos_embed
|
77 |
+
|
78 |
+
def resize_evaclip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
|
79 |
+
all_keys = list(state_dict.keys())
|
80 |
+
# interpolate position embedding
|
81 |
+
if 'visual.pos_embed' in state_dict:
|
82 |
+
pos_embed_checkpoint = state_dict['visual.pos_embed']
|
83 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
84 |
+
num_patches = model.visual.patch_embed.num_patches
|
85 |
+
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
|
86 |
+
# height (== width) for the checkpoint position embedding
|
87 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
88 |
+
# height (== width) for the new position embedding
|
89 |
+
new_size = int(num_patches ** 0.5)
|
90 |
+
# class_token and dist_token are kept unchanged
|
91 |
+
if orig_size != new_size:
|
92 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
93 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
94 |
+
# only the position tokens are interpolated
|
95 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
96 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
97 |
+
pos_tokens = torch.nn.functional.interpolate(
|
98 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
99 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
100 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
101 |
+
state_dict['visual.pos_embed'] = new_pos_embed
|
102 |
+
|
103 |
+
patch_embed_proj = state_dict['visual.patch_embed.proj.weight']
|
104 |
+
patch_size = model.visual.patch_embed.patch_size
|
105 |
+
state_dict['visual.patch_embed.proj.weight'] = torch.nn.functional.interpolate(
|
106 |
+
patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
|
107 |
+
|
108 |
+
|
109 |
+
def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
|
110 |
+
all_keys = list(state_dict.keys())
|
111 |
+
# interpolate position embedding
|
112 |
+
if 'pos_embed' in state_dict:
|
113 |
+
pos_embed_checkpoint = state_dict['pos_embed']
|
114 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
115 |
+
num_patches = model.visual.patch_embed.num_patches
|
116 |
+
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
|
117 |
+
# height (== width) for the checkpoint position embedding
|
118 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
119 |
+
# height (== width) for the new position embedding
|
120 |
+
new_size = int(num_patches ** 0.5)
|
121 |
+
# class_token and dist_token are kept unchanged
|
122 |
+
if orig_size != new_size:
|
123 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
124 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
125 |
+
# only the position tokens are interpolated
|
126 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
127 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
128 |
+
pos_tokens = torch.nn.functional.interpolate(
|
129 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
130 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
131 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
132 |
+
state_dict['pos_embed'] = new_pos_embed
|
133 |
+
|
134 |
+
patch_embed_proj = state_dict['patch_embed.proj.weight']
|
135 |
+
patch_size = model.visual.patch_embed.patch_size
|
136 |
+
state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
|
137 |
+
patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
|
138 |
+
|
139 |
+
|
140 |
+
def resize_rel_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
|
141 |
+
all_keys = list(state_dict.keys())
|
142 |
+
for key in all_keys:
|
143 |
+
if "relative_position_index" in key:
|
144 |
+
state_dict.pop(key)
|
145 |
+
|
146 |
+
if "relative_position_bias_table" in key:
|
147 |
+
rel_pos_bias = state_dict[key]
|
148 |
+
src_num_pos, num_attn_heads = rel_pos_bias.size()
|
149 |
+
dst_num_pos, _ = model.visual.state_dict()[key].size()
|
150 |
+
dst_patch_shape = model.visual.patch_embed.patch_shape
|
151 |
+
if dst_patch_shape[0] != dst_patch_shape[1]:
|
152 |
+
raise NotImplementedError()
|
153 |
+
num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
|
154 |
+
src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
|
155 |
+
dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
|
156 |
+
if src_size != dst_size:
|
157 |
+
print("Position interpolate for %s from %dx%d to %dx%d" % (
|
158 |
+
key, src_size, src_size, dst_size, dst_size))
|
159 |
+
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
|
160 |
+
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
|
161 |
+
|
162 |
+
def geometric_progression(a, r, n):
|
163 |
+
return a * (1.0 - r ** n) / (1.0 - r)
|
164 |
+
|
165 |
+
left, right = 1.01, 1.5
|
166 |
+
while right - left > 1e-6:
|
167 |
+
q = (left + right) / 2.0
|
168 |
+
gp = geometric_progression(1, q, src_size // 2)
|
169 |
+
if gp > dst_size // 2:
|
170 |
+
right = q
|
171 |
+
else:
|
172 |
+
left = q
|
173 |
+
|
174 |
+
# if q > 1.090307:
|
175 |
+
# q = 1.090307
|
176 |
+
|
177 |
+
dis = []
|
178 |
+
cur = 1
|
179 |
+
for i in range(src_size // 2):
|
180 |
+
dis.append(cur)
|
181 |
+
cur += q ** (i + 1)
|
182 |
+
|
183 |
+
r_ids = [-_ for _ in reversed(dis)]
|
184 |
+
|
185 |
+
x = r_ids + [0] + dis
|
186 |
+
y = r_ids + [0] + dis
|
187 |
+
|
188 |
+
t = dst_size // 2.0
|
189 |
+
dx = np.arange(-t, t + 0.1, 1.0)
|
190 |
+
dy = np.arange(-t, t + 0.1, 1.0)
|
191 |
+
|
192 |
+
print("Original positions = %s" % str(x))
|
193 |
+
print("Target positions = %s" % str(dx))
|
194 |
+
|
195 |
+
all_rel_pos_bias = []
|
196 |
+
|
197 |
+
for i in range(num_attn_heads):
|
198 |
+
z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
|
199 |
+
f = F.interpolate.interp2d(x, y, z, kind='cubic')
|
200 |
+
all_rel_pos_bias.append(
|
201 |
+
torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
|
202 |
+
|
203 |
+
rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
|
204 |
+
|
205 |
+
new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
|
206 |
+
state_dict[key] = new_rel_pos_bias
|
207 |
+
|
208 |
+
# interpolate position embedding
|
209 |
+
if 'pos_embed' in state_dict:
|
210 |
+
pos_embed_checkpoint = state_dict['pos_embed']
|
211 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
212 |
+
num_patches = model.visual.patch_embed.num_patches
|
213 |
+
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
|
214 |
+
# height (== width) for the checkpoint position embedding
|
215 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
216 |
+
# height (== width) for the new position embedding
|
217 |
+
new_size = int(num_patches ** 0.5)
|
218 |
+
# class_token and dist_token are kept unchanged
|
219 |
+
if orig_size != new_size:
|
220 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
221 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
222 |
+
# only the position tokens are interpolated
|
223 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
224 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
225 |
+
pos_tokens = torch.nn.functional.interpolate(
|
226 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
227 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
228 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
229 |
+
state_dict['pos_embed'] = new_pos_embed
|
230 |
+
|
231 |
+
patch_embed_proj = state_dict['patch_embed.proj.weight']
|
232 |
+
patch_size = model.visual.patch_embed.patch_size
|
233 |
+
state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
|
234 |
+
patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
|
235 |
+
|
236 |
+
|
237 |
+
def freeze_batch_norm_2d(module, module_match={}, name=''):
|
238 |
+
"""
|
239 |
+
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
|
240 |
+
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
|
241 |
+
returned. Otherwise, the module is walked recursively and submodules are converted in place.
|
242 |
+
Args:
|
243 |
+
module (torch.nn.Module): Any PyTorch module.
|
244 |
+
module_match (dict): Dictionary of full module names to freeze (all if empty)
|
245 |
+
name (str): Full module name (prefix)
|
246 |
+
Returns:
|
247 |
+
torch.nn.Module: Resulting module
|
248 |
+
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
|
249 |
+
"""
|
250 |
+
res = module
|
251 |
+
is_match = True
|
252 |
+
if module_match:
|
253 |
+
is_match = name in module_match
|
254 |
+
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
|
255 |
+
res = FrozenBatchNorm2d(module.num_features)
|
256 |
+
res.num_features = module.num_features
|
257 |
+
res.affine = module.affine
|
258 |
+
if module.affine:
|
259 |
+
res.weight.data = module.weight.data.clone().detach()
|
260 |
+
res.bias.data = module.bias.data.clone().detach()
|
261 |
+
res.running_mean.data = module.running_mean.data
|
262 |
+
res.running_var.data = module.running_var.data
|
263 |
+
res.eps = module.eps
|
264 |
+
else:
|
265 |
+
for child_name, child in module.named_children():
|
266 |
+
full_child_name = '.'.join([name, child_name]) if name else child_name
|
267 |
+
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
|
268 |
+
if new_child is not child:
|
269 |
+
res.add_module(child_name, new_child)
|
270 |
+
return res
|
271 |
+
|
272 |
+
|
273 |
+
# From PyTorch internals
|
274 |
+
def _ntuple(n):
|
275 |
+
def parse(x):
|
276 |
+
if isinstance(x, collections.abc.Iterable):
|
277 |
+
return x
|
278 |
+
return tuple(repeat(x, n))
|
279 |
+
return parse
|
280 |
+
|
281 |
+
|
282 |
+
to_1tuple = _ntuple(1)
|
283 |
+
to_2tuple = _ntuple(2)
|
284 |
+
to_3tuple = _ntuple(3)
|
285 |
+
to_4tuple = _ntuple(4)
|
286 |
+
to_ntuple = lambda n, x: _ntuple(n)(x)
|
287 |
+
|
288 |
+
|
289 |
+
def is_logging(args):
|
290 |
+
def is_global_master(args):
|
291 |
+
return args.rank == 0
|
292 |
+
|
293 |
+
def is_local_master(args):
|
294 |
+
return args.local_rank == 0
|
295 |
+
|
296 |
+
def is_master(args, local=False):
|
297 |
+
return is_local_master(args) if local else is_global_master(args)
|
298 |
+
return is_master
|
299 |
+
|
300 |
+
|
301 |
+
class AllGather(torch.autograd.Function):
|
302 |
+
"""An autograd function that performs allgather on a tensor.
|
303 |
+
Performs all_gather operation on the provided tensors.
|
304 |
+
*** Warning ***: torch.distributed.all_gather has no gradient.
|
305 |
+
"""
|
306 |
+
|
307 |
+
@staticmethod
|
308 |
+
def forward(ctx, tensor, rank, world_size):
|
309 |
+
tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)]
|
310 |
+
torch.distributed.all_gather(tensors_gather, tensor)
|
311 |
+
ctx.rank = rank
|
312 |
+
ctx.batch_size = tensor.shape[0]
|
313 |
+
return torch.cat(tensors_gather, 0)
|
314 |
+
|
315 |
+
@staticmethod
|
316 |
+
def backward(ctx, grad_output):
|
317 |
+
return (
|
318 |
+
grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)],
|
319 |
+
None,
|
320 |
+
None
|
321 |
+
)
|
322 |
+
|
323 |
+
allgather = AllGather.apply
|