kd5678 commited on
Commit
8a6cb91
·
verified ·
1 Parent(s): 43d2c2f

Upload InpaintReward.py

Browse files
Files changed (1) hide show
  1. InpaintReward.py +228 -0
InpaintReward.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import clip
4
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
5
+ from torchvision import transforms
6
+
7
+
8
+ class MLP(nn.Module):
9
+ def __init__(self, input_size):
10
+ super().__init__()
11
+ self.input_size = input_size
12
+
13
+ self.layers = nn.Sequential(
14
+ nn.Linear(self.input_size, 512),
15
+ nn.Dropout(0.2),
16
+ nn.Linear(512, 256),
17
+ )
18
+ self.last_layer = nn.Linear(256, 1, bias=False)
19
+ self.last_layer_weight = self.last_layer.weight
20
+ # initial MLP param
21
+ for name, param in self.layers.named_parameters():
22
+ if 'weight' in name:
23
+ nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
24
+ if 'bias' in name:
25
+ nn.init.constant_(param, val=0)
26
+
27
+ for name, param in self.last_layer.named_parameters():
28
+ if 'weight' in name:
29
+ nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
30
+ if 'bias' in name:
31
+ nn.init.constant_(param, val=0)
32
+
33
+ def forward(self, input):
34
+ features = self.layers(input)
35
+ out = self.last_layer(features)
36
+ return out, features
37
+
38
+
39
+ class ViTBlock(nn.Module):
40
+ def __init__(self, feature_dim, num_heads, mlp_dim, dropout=0.1):
41
+ super(ViTBlock, self).__init__()
42
+ # Transformer encoder layer
43
+ self.encoder_layer = TransformerEncoderLayer(
44
+ d_model=feature_dim,
45
+ nhead=num_heads,
46
+ dim_feedforward=mlp_dim,
47
+ dropout=dropout,
48
+ batch_first=True # Input shape: (batch_size, seq_length, feature_dim)
49
+ )
50
+ self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=1)
51
+
52
+ def forward(self, x):
53
+ x = self.transformer_encoder(x)
54
+ return x
55
+
56
+
57
+ class ImageReward(nn.Module):
58
+ def __init__(self, config, device='cpu'):
59
+ super().__init__()
60
+ self.config = config
61
+ self.device = device
62
+
63
+ self.clip_model, self.preprocess = clip.load("ViT-B/32")
64
+ self.clip_model = self.clip_model.float()
65
+ self.mlp = MLP(self.config['ImageReward']['mlp_dim'])
66
+ self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
67
+
68
+ self.toImage = transforms.ToPILImage()
69
+
70
+ self.mean = 0.4064 #0.65823
71
+ self.std = 2.3021 #8.5400
72
+
73
+ if self.config.fix_base:
74
+ self.clip_model.requires_grad_(False)
75
+
76
+ def score(self, inpaint_list, masks_rgb):
77
+
78
+ inpaint_embeds_bs, mask_rgb_embeds_bs = [], []
79
+
80
+ for bs in range(len(inpaint_list)):
81
+ if isinstance(inpaint_list[bs], torch.Tensor):
82
+ inpaint = self.toImage(inpaint_list[bs])
83
+ else:
84
+ inpaint = inpaint_list[bs]
85
+ inpaint = self.preprocess(inpaint).unsqueeze(0)
86
+ if isinstance(masks_rgb[bs], torch.Tensor):
87
+ mask_rgb = self.toImage(masks_rgb[bs])
88
+ else:
89
+ mask_rgb = masks_rgb[bs]
90
+ mask_rgb = self.preprocess(masks_rgb[bs]).unsqueeze(0)
91
+ inpt, msk = inpaint.to(self.device), mask_rgb.to(self.device)
92
+
93
+ inpt_embeds = self.clip_model.encode_image(inpt).to(torch.float32)
94
+ msk_embeds = self.clip_model.encode_image(msk).to(torch.float32)
95
+
96
+ inpaint_embeds_bs.append(inpt_embeds.squeeze(0))
97
+ mask_rgb_embeds_bs.append(msk_embeds.squeeze(0))
98
+
99
+
100
+ emb_inpaint = torch.stack(inpaint_embeds_bs, dim=0)
101
+ emb_mask_rgb = torch.stack(mask_rgb_embeds_bs, dim=0)
102
+
103
+ emb_feature = torch.cat((emb_inpaint, emb_mask_rgb), dim=-1)
104
+ emb_feature = emb_feature.unsqueeze(1)
105
+ emb_feature = self.vit_block(emb_feature) # 1024
106
+
107
+ scores, last_features = self.mlp(emb_feature)
108
+ scores = torch.squeeze(scores)
109
+ last_features = torch.squeeze(last_features)
110
+
111
+ if self.config.group:
112
+ scores = (scores - self.mean) / self.std
113
+
114
+
115
+ return scores.detach().cpu().numpy().tolist(), last_features.detach().cpu().numpy().tolist()
116
+
117
+
118
+
119
+ def load_model(self, model, ckpt_path = None):
120
+
121
+ print('load checkpoint from %s'%ckpt_path)
122
+ state_dict = {k: v for k, v in torch.load(ckpt_path, map_location='cpu').items()}
123
+ new_dict = {key.replace("module.", ""): value for key, value in state_dict.items()}
124
+ model.load_state_dict(new_dict)
125
+
126
+ return model
127
+
128
+
129
+
130
+ class ImageRewardGroup(nn.Module):
131
+ def __init__(self, config, device='cpu'):
132
+ super().__init__()
133
+ self.config = config
134
+ self.device = device
135
+
136
+ self.clip_model, self.preprocess = clip.load("ViT-B/32", device="cuda") #clip.load(config['clip_model'], device="cuda" if torch.cuda.is_available() else "cpu")
137
+
138
+ self.clip_model = self.clip_model.float()
139
+ self.mlp = MLP(config['ImageReward']['mlp_dim'])
140
+ self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
141
+
142
+ if self.config.fix_base:
143
+ self.clip_model.requires_grad_(False)
144
+
145
+ for name, parms in self.clip_model.named_parameters():
146
+ if '_proj' in name:
147
+ parms.requires_grad_(False)
148
+
149
+ # fix certain ratio of layers
150
+ self.image_layer_num = 12
151
+ if self.config.fix_base > 0:
152
+ image_fix_num = "resblocks.{}".format(int(self.image_layer_num * self.config.fix_base))
153
+ for name, parms in self.clip_model.visual.named_parameters():
154
+ parms.requires_grad_(False)
155
+ if image_fix_num in name:
156
+ break
157
+
158
+
159
+ def loose_layer(self, fix_rate):
160
+ text_layer_id = [f"layer.{id}" for id in range(int(12 * fix_rate), 13)]
161
+ image_layer_id = [f"blocks.{id}" for id in range(int(24 * fix_rate), 25)]
162
+ for name, parms in self.blip.text_encoder.named_parameters():
163
+ for text_id in text_layer_id:
164
+ if text_id in name:
165
+ parms.requires_grad_(True)
166
+ for name, parms in self.blip.visual_encoder.named_parameters():
167
+ for image_id in image_layer_id:
168
+ if image_id in name:
169
+ parms.requires_grad_(True)
170
+
171
+
172
+ def forward(self, batch_data):
173
+
174
+ b_emb_inpt, b_emb_msk, w_emb_inpt, w_emb_msk = self.encode_pair(batch_data) # Nan
175
+ # forward
176
+ b_emb_feature = torch.cat((b_emb_inpt, b_emb_msk), dim=-1)
177
+ b_emb_feature = self.vit_block(b_emb_feature) # 1024
178
+ w_emb_feature = torch.cat((w_emb_inpt, w_emb_msk), dim=-1)
179
+ w_emb_feature = self.vit_block(w_emb_feature) # 1024
180
+
181
+ reward_better = self.mlp(b_emb_feature).squeeze(-1)
182
+ reward_worse = self.mlp(w_emb_feature).squeeze(-1)
183
+ reward = torch.concat((reward_better, reward_worse), dim=1)
184
+
185
+ return reward
186
+
187
+
188
+ def encode_pair(self, batch_data):
189
+ better_inpaint_embeds_bs, better_mask_rgb_embeds_bs = [], []
190
+ worse_inpaint_embeds_bs, worse_mask_rgb_embeds_bs = [], []
191
+ for bs in range(len(batch_data)):
192
+ better_inpt, better_msk = batch_data[bs]['better_inpt'], batch_data[bs]['better_msk']
193
+ better_inpt, better_msk = better_inpt.to(self.device), better_msk.to(self.device)
194
+
195
+ worse_inpt, worse_msk = batch_data[bs]['worse_inpt'], batch_data[bs]['worse_msk']
196
+ worse_inpt, worse_msk = worse_inpt.to(self.device), worse_msk.to(self.device)
197
+ # with torch.no_grad():
198
+ better_inpaint_embeds = self.clip_model.encode_image(better_inpt).to(torch.float32)
199
+ better_mask_rgb_embeds = self.clip_model.encode_image(better_msk).to(torch.float32)
200
+ worse_inpaint_embeds = self.clip_model.encode_image(worse_inpt).to(torch.float32)
201
+ worse_mask_rgb_embeds = self.clip_model.encode_image(worse_msk).to(torch.float32)
202
+
203
+ better_inpaint_embeds_bs.append(better_inpaint_embeds)
204
+ better_mask_rgb_embeds_bs.append(better_mask_rgb_embeds)
205
+ worse_inpaint_embeds_bs.append(worse_inpaint_embeds)
206
+ worse_mask_rgb_embeds_bs.append(worse_mask_rgb_embeds)
207
+
208
+ b_inpt = torch.stack(better_inpaint_embeds_bs, dim=0)
209
+ b_msk = torch.stack(better_mask_rgb_embeds_bs, dim=0)
210
+ w_inpt = torch.stack(worse_inpaint_embeds_bs, dim=0)
211
+ w_msk = torch.stack(worse_mask_rgb_embeds_bs, dim=0)
212
+
213
+
214
+ return b_inpt, b_msk, w_inpt, w_msk
215
+
216
+
217
+ def load_model(self, model, ckpt_path = None):
218
+
219
+ print('load checkpoint from %s'%ckpt_path)
220
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
221
+ state_dict = checkpoint
222
+ msg = model.load_state_dict(state_dict,strict=False)
223
+ print("missing keys:", msg.missing_keys)
224
+
225
+ return model
226
+
227
+
228
+