File size: 10,692 Bytes
8dd41a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 |
import math
from posixpath import basename, dirname, join
# import clip
from clip.model import convert_weights
import torch
import json
from torch import nn
from torch.nn import functional as nnf
from torch.nn.modules import activation
from torch.nn.modules.activation import ReLU
from torchvision import transforms
# Per-channel mean/std normalisation transform for 3-channel image tensors.
# NOTE(review): the constants look like the CLIP preprocessing statistics -- confirm.
# In the code shown here it is only mentioned in a commented-out line of
# VITDensePredT.forward, i.e. inputs are not normalised by this module itself.
normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
from torchvision.models import ResNet
def process_prompts(conditional, prompt_list, conditional_map):
    """Turn class ids into text prompts (DEPRECATED).

    For every id in ``conditional`` one synonym is drawn uniformly at random
    from ``conditional_map[int(id)]``; underscores in the chosen synonym are
    replaced by spaces. Each resulting phrase is then inserted into a prompt
    template via ``str.format``.

    Args:
        conditional: iterable of values convertible with ``int()``; each is a
            key into ``conditional_map``.
        prompt_list: sequence of format strings to sample from (uniformly,
            with replacement), or ``None`` to always use ``'a photo of {}'``.
        conditional_map: mapping from int id to a sequence of synonym strings.

    Returns:
        A list with one formatted prompt string per entry of ``conditional``.
    """
    # DEPRECATED
    synonym_sets = [conditional_map[int(idx)] for idx in conditional]

    # randomly sample a synonym (one uniform draw per entry, in input order)
    chosen = []
    for synonyms in synonym_sets:
        pick = torch.multinomial(torch.ones(len(synonyms)), 1, replacement=True).item()
        chosen.append(synonyms[pick])
    phrases = [word.replace('_', ' ') for word in chosen]

    if prompt_list is None:
        templates = ['a photo of {}'] * len(phrases)
    else:
        sampled = torch.multinomial(torch.ones(len(prompt_list)), len(phrases), replacement=True)
        templates = [prompt_list[k] for k in sampled]

    return [template.format(phrase) for template, phrase in zip(templates, phrases)]
class VITDenseBase(nn.Module):
    """Shared helpers for dense-prediction models built on a ViT backbone.

    The methods read attributes that this class does not create itself
    (``self.model``, ``self.clip_model``, ``self.token_shape``,
    ``self.prompt_list``, ``self.precomputed_prompts``); a subclass is
    expected to set them up before the corresponding method is called.
    """

    def rescaled_pos_emb(self, new_size):
        """Return the positional embedding resampled to a ``new_size`` token grid.

        The first embedding row is kept unchanged; the remaining rows are
        reshaped to ``self.token_shape``, bicubically interpolated to
        ``new_size`` and flattened again.

        NOTE(review): this reads ``self.model.positional_embedding`` while
        ``visual_forward`` reads ``self.model.pos_embed`` -- confirm that the
        backbone in use actually provides ``positional_embedding``.

        Args:
            new_size: pair ``(height, width)`` of the target token grid.
        """
        assert len(new_size) == 2

        # 768 is the hard-coded embedding width expected here.
        a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape)
        b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T
        return torch.cat([self.model.positional_embedding[:1], b])

    def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None):
        """Run the backbone on ``x_inp`` without gradients and collect activations.

        The input is resized to 384x384 before being embedded. ``skip`` and
        ``mask`` are accepted but not used by this implementation.

        Args:
            x_inp: image batch passed to ``nnf.interpolate`` and
                ``self.model.patch_embed``.
            extract_layers: indices of ``self.model.blocks`` whose outputs
                should be returned.
            skip: unused.
            mask: unused.

        Returns:
            Tuple ``(x, activations, None)`` where ``x`` is
            ``self.model.head(self.model.pre_logits(...))`` applied to the
            first token after the final norm, and ``activations`` is the list
            of selected block outputs, each permuted with ``(1, 0, 2)``.
        """
        with torch.no_grad():
            x_inp = nnf.interpolate(x_inp, (384, 384))

            x = self.model.patch_embed(x_inp)
            cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            if self.model.dist_token is None:
                x = torch.cat((cls_token, x), dim=1)
            else:
                x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = self.model.pos_drop(x + self.model.pos_embed)

            activations = []
            for i, block in enumerate(self.model.blocks):
                x = block(x)

                if i in extract_layers:
                    # permute to be compatible with CLIP
                    activations.append(x.permute(1, 0, 2))

            x = self.model.norm(x)
            x = self.model.head(self.model.pre_logits(x[:, 0]))

            # again for CLIP compatibility
            # x = x.permute(1, 0, 2)

        return x, activations, None

    def sample_prompts(self, words, prompt_list=None):
        """Format each word with a prompt template sampled uniformly at random.

        Args:
            words: sequence of strings to insert into the templates.
            prompt_list: format strings to sample from (with replacement);
                defaults to ``self.prompt_list`` when ``None``.

        Returns:
            List of formatted prompts, one per word.
        """
        prompt_list = prompt_list if prompt_list is not None else self.prompt_list

        prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True)
        prompts = [prompt_list[i] for i in prompt_indices]
        return [prompt.format(w) for prompt, w in zip(prompts, words)]

    def get_cond_vec(self, conditional, batch_size):
        """Resolve ``conditional`` into a conditioning tensor.

        Accepted forms:
          * a single string -- encoded once and repeated ``batch_size`` times;
          * a non-empty list/tuple of strings -- must have ``batch_size``
            entries, encoded together;
          * a 2-D tensor -- used directly;
          * any other tensor -- treated as an image batch and passed through
            ``visual_forward`` under ``torch.no_grad()``.

        Subclasses of ``str`` / ``torch.Tensor`` (e.g. ``nn.Parameter``) are
        accepted as well.

        Raises:
            ValueError: if ``conditional`` matches none of the forms above
                (including ``None`` and empty sequences).
        """
        # compute conditional from a single string
        if isinstance(conditional, str):
            cond = self.compute_conditional(conditional)
            return cond.repeat(batch_size, 1)

        # compute conditional from string list/tuple
        # (the length guard keeps an empty sequence from raising IndexError)
        if isinstance(conditional, (list, tuple)) and len(conditional) > 0 and isinstance(conditional[0], str):
            assert len(conditional) == batch_size
            return self.compute_conditional(conditional)

        if isinstance(conditional, torch.Tensor):
            # use conditional directly
            if conditional.ndim == 2:
                return conditional

            # compute conditional from image
            with torch.no_grad():
                cond, _, _ = self.visual_forward(conditional)
            return cond

        raise ValueError('invalid conditional')

    def compute_conditional(self, conditional):
        """Encode text into a conditioning vector with ``self.clip_model``.

        Args:
            conditional: a string, or a list/tuple of strings.

        Returns:
            For a list/tuple: the result of ``encode_text`` on all tokenized
            entries. For a single string: the cached vector from
            ``self.precomputed_prompts`` (cast to float and moved to the
            module's device) if present, otherwise the first row of
            ``encode_text`` for that string.
        """
        # Imported lazily; the module-level ``import clip`` is commented out.
        import clip

        # Device of the first parameter is used for tokens and cached vectors.
        dev = next(self.parameters()).device

        if isinstance(conditional, (list, tuple)):
            text_tokens = clip.tokenize(conditional).to(dev)
            cond = self.clip_model.encode_text(text_tokens)
        else:
            if conditional in self.precomputed_prompts:
                cond = self.precomputed_prompts[conditional].float().to(dev)
            else:
                text_tokens = clip.tokenize([conditional]).to(dev)
                cond = self.clip_model.encode_text(text_tokens)[0]

        return cond
class VITDensePredT(VITDenseBase):
    """Transformer decoder producing a dense single-channel map from ViT activations.

    A frozen ``timm`` ViT (``vit_base_patch16_384``) supplies intermediate
    activations; these are linearly reduced, modulated once by a conditioning
    vector (multiplicative + additive linear projections), passed through
    ``nn.TransformerEncoderLayer`` blocks and finally upsampled with a
    transposed convolution.

    Constructing an instance loads pretrained timm and CLIP weights and reads
    ``precomputed_prompt_vectors.pickle`` from the current working directory.
    """

    def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed',
                 depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False,
                 learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False,
                 add_calibration=False, process_cond=None, not_pretrained=False):
        """Build the backbone, the text encoder and the decoder layers.

        Args:
            extract_layers: backbone block indices whose activations feed the
                decoder; its length must equal ``depth``.
            cond_layer: decoder step index at which the conditioning vector
                is applied.
            reduce_dim: width of the decoder.
            n_heads: attention heads per decoder block.
            prompt: one of ``'fixed'``, ``'shuffle'``, ``'shuffle+'``,
                ``'shuffle_clip'``; selects ``self.prompt_list``. Any other
                value leaves ``self.prompt_list`` unset.
            depth: number of reduce layers / decoder blocks.
            extra_blocks: number of additional residual decoder blocks.
            reduce_cond: if given, output width of the backbone head and of
                the (frozen) linear layer applied to the conditional.
            learn_trans_conv_only: freeze everything except ``trans_conv``.
            limit_to_clip_only: stored on the instance; not read in this class.
            upsample: if true, create ``self.upsample_proj``.
            add_calibration: if true, set ``self.calibration_conds = 1``.
            process_cond: ``'clamp'``, a ``('clamp', value)`` pair, or a path
                ending in ``'.pth'``; sets ``self.process_cond`` to a callable.
                NOTE(review): ``forward`` as written here never calls it.
            fix_shift, refine, not_pretrained: accepted but not used here
                (the backbone is always created with ``pretrained=True``).
        """
        super().__init__()
        # device = 'cpu'

        self.extract_layers = extract_layers
        self.cond_layer = cond_layer
        self.limit_to_clip_only = limit_to_clip_only
        self.process_cond = None

        if add_calibration:
            self.calibration_conds = 1

        self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None

        self.add_activation1 = True

        # Width of the conditioning vector after the optional reduction.
        cond_dim = 512 if reduce_cond is None else reduce_cond

        import timm
        self.model = timm.create_model('vit_base_patch16_384', pretrained=True)
        self.model.head = nn.Linear(768, cond_dim)

        # The backbone (including the freshly created head) is kept frozen.
        for p in self.model.parameters():
            p.requires_grad_(False)

        import clip
        self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False)
        # del self.clip_model.visual

        self.token_shape = (14, 14)

        # conditional
        if reduce_cond is not None:
            self.reduce_cond = nn.Linear(512, reduce_cond)
            for p in self.reduce_cond.parameters():
                p.requires_grad_(False)
        else:
            self.reduce_cond = None

        # self.film = AVAILABLE_BLOCKS['film'](512, 128)
        self.film_mul = nn.Linear(cond_dim, reduce_dim)
        self.film_add = nn.Linear(cond_dim, reduce_dim)

        # DEPRECATED
        # self.conditional_map = {c['id']: c['synonyms'] for c in json.load(open(cond_map))}

        assert len(self.extract_layers) == depth

        self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)])
        self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))])
        self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)])

        trans_conv_ks = (16, 16)
        self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks)

        # refinement and trans conv

        if learn_trans_conv_only:
            for p in self.parameters():
                p.requires_grad_(False)

            for p in self.trans_conv.parameters():
                p.requires_grad_(True)

        if prompt == 'fixed':
            self.prompt_list = ['a photo of a {}.']
        elif prompt == 'shuffle':
            self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.']
        elif prompt == 'shuffle+':
            self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.',
                                'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.',
                                'a bad photo of a {}.', 'a photo of the {}.']
        elif prompt == 'shuffle_clip':
            from models.clip_prompts import imagenet_templates
            self.prompt_list = imagenet_templates

        if process_cond is not None:
            if process_cond == 'clamp' or process_cond[0] == 'clamp':

                # A ('clamp', value) pair carries its own bound; plain 'clamp' uses 0.2.
                val = process_cond[1] if isinstance(process_cond, (list, tuple)) else 0.2

                def clamp_vec(x):
                    return torch.clamp(x, -val, val)

                self.process_cond = clamp_vec

            elif process_cond.endswith('.pth'):

                # NOTE(review): torch.load unpickles the file -- only use trusted paths.
                shift = torch.load(process_cond)

                def add_shift(x):
                    return x + shift.to(x.device)

                self.process_cond = add_shift

        import pickle
        # NOTE(review): pickle.load can execute arbitrary code; the file must be trusted.
        # The context manager closes the handle, which the bare open() call never did.
        with open('precomputed_prompt_vectors.pickle', 'rb') as fh:
            precomp = pickle.load(fh)
        self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()}

    def forward(self, inp_image, conditional=None, return_features=False, mask=None):
        """Predict a dense map for ``inp_image`` given ``conditional``.

        Args:
            inp_image: image batch; ``shape[0]`` is taken as the batch size and
                ``shape[2:]`` as the output spatial size.
            conditional: anything accepted by ``get_cond_vec``.
            return_features: if true, also return intermediate tensors.
            mask: must be ``None``.

        Returns:
            ``(a,)`` -- a one-element tuple holding the prediction resized to
            the input's spatial size -- or, when ``return_features`` is true,
            ``(a, visual_q, cond, [activation1] + activations)``.

        Raises:
            ValueError: if ``mask`` is not ``None``.
        """
        assert isinstance(return_features, bool)

        # inp_image = inp_image.to(self.model.positional_embedding.device)

        if mask is not None:
            raise ValueError('mask not supported')

        # x_inp = normalize(inp_image)
        x_inp = inp_image

        bs = inp_image.shape[0]
        inp_image_size = inp_image.shape[2:]

        cond = self.get_cond_vec(conditional, bs)

        # Block 0 is always extracted in addition to the configured layers.
        visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers))

        activation1 = activations[0]
        activations = activations[1:]

        a = None
        # Walk the extracted activations from the last one to the first.
        for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)):

            if a is not None:
                a = reduce(activation) + a
            else:
                a = reduce(activation)

            if i == self.cond_layer:
                if self.reduce_cond is not None:
                    cond = self.reduce_cond(cond)

                # Conditioning: scale and shift by linear projections of cond.
                a = self.film_mul(cond) * a + self.film_add(cond)

            a = block(a)

        for block in self.extra_blocks:
            a = a + block(a)

        a = a[1:].permute(1, 2, 0)  # rm cls token and -> BS, Feats, Tokens

        # Tokens are assumed to form a square grid.
        size = int(math.sqrt(a.shape[2]))

        a = a.view(bs, a.shape[1], size, size)

        if self.trans_conv is not None:
            a = self.trans_conv(a)

        if self.upsample_proj is not None:
            a = self.upsample_proj(a)
            a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear')

        a = nnf.interpolate(a, inp_image_size)

        if return_features:
            return a, visual_q, cond, [activation1] + activations
        else:
            # Deliberately a 1-tuple so both branches return a tuple.
            return a,
|