kd5678 committed
Commit 2bcca80 · verified · 1 Parent(s): 34f8bed

Upload InpaintReward.py

Files changed (1)
  1. InpaintReward.py +266 -0
InpaintReward.py ADDED
@@ -0,0 +1,266 @@
+ '''
+ @File        : InpaintReward.py
+ @Time        : 2023/02/28 19:53:00
+ @Author      : Jiazheng Xu
+ @Contact     : xjz22@mails.tsinghua.edu.cn
+ @Description : ImageReward-style reward model for inpainting.
+ '''
+
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import clip
+ from PIL import Image
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
+ from torchvision import transforms
+
+
+ class MLP(nn.Module):
+     def __init__(self, input_size):
+         super().__init__()
+         self.input_size = input_size
+
+         self.layers = nn.Sequential(
+             nn.Linear(self.input_size, 512),
+             # nn.ReLU(),
+             nn.Dropout(0.2),
+             nn.Linear(512, 256),
+             # nn.ReLU(),
+             # nn.Dropout(0.2),
+             # nn.Linear(256, 128),
+             # nn.ReLU(),
+             # nn.Dropout(0.1),
+             # nn.Linear(128, 64),
+             # nn.ReLU(),
+             # nn.Linear(64, 1)
+         )
+         # Final scoring head; its weight is exposed for inspection.
+         self.last_layer = nn.Linear(256, 1, bias=False)
+         self.last_layer_weight = self.last_layer.weight
+
+         # Initialize MLP parameters.
+         for name, param in self.layers.named_parameters():
+             if 'weight' in name:
+                 nn.init.normal_(param, mean=0.0, std=1.0 / (self.input_size + 1))
+             if 'bias' in name:
+                 nn.init.constant_(param, val=0)
+
+         for name, param in self.last_layer.named_parameters():
+             if 'weight' in name:
+                 nn.init.normal_(param, mean=0.0, std=1.0 / (self.input_size + 1))
+             if 'bias' in name:
+                 nn.init.constant_(param, val=0)
+
+     def forward(self, input):
+         # Return both the scalar score and the 256-d penultimate features.
+         features = self.layers(input)
+         out = self.last_layer(features)
+         return out, features
+
+
+ class ViTBlock(nn.Module):
+     def __init__(self, feature_dim, num_heads, mlp_dim, dropout=0.1):
+         super(ViTBlock, self).__init__()
+         # Transformer encoder layer
+         self.encoder_layer = TransformerEncoderLayer(
+             d_model=feature_dim,
+             nhead=num_heads,
+             dim_feedforward=mlp_dim,
+             dropout=dropout,
+             batch_first=True  # Input shape: (batch_size, seq_length, feature_dim)
+         )
+         self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=1)
+
+     def forward(self, x):
+         x = self.transformer_encoder(x)
+         return x
+
+
+ class ImageReward(nn.Module):
+     def __init__(self, config, device='cpu'):
+         super().__init__()
+         self.config = config
+         self.device = device
+
+         self.clip_model, self.preprocess = clip.load("ViT-B/32")  # , device=self.device)  # clip.load(config['clip_model'], device="cuda" if torch.cuda.is_available() else "cpu")
+
+         self.clip_model = self.clip_model.float()
+         self.mlp = MLP(self.config['ImageReward']['mlp_dim'])
+         self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
+
+         self.toImage = transforms.ToPILImage()
+
+         # Statistics used to standardize raw scores when config.group is set.
+         self.mean = 0.4064  # 0.65823
+         self.std = 2.3021   # 8.5400
+
+         if self.config.fix_base:
+             self.clip_model.requires_grad_(False)
+
+         # for name, parms in self.clip_model.named_parameters():
+         #     if '_proj' in name:
+         #         parms.requires_grad_(False)
+
+         # # fix certain ratio of layers
+         # self.image_layer_num = 12
+         # if self.config.fix_rate > 0:
+         #     image_fix_num = "resblocks.{}".format(int(self.image_layer_num * self.config.fix_rate))
+         #     for name, parms in self.clip_model.visual.named_parameters():
+         #         parms.requires_grad_(False)
+         #         if image_fix_num in name:
+         #             break
+
+
+     def score(self, inpaint_list, masks_rgb):
+
+         inpaint_embeds_bs, mask_rgb_embeds_bs = [], []
+
+         for bs in range(len(inpaint_list)):
+             # Accept either tensors or PIL images for both inputs.
+             if isinstance(inpaint_list[bs], torch.Tensor):
+                 inpaint = self.toImage(inpaint_list[bs])
+             else:
+                 inpaint = inpaint_list[bs]
+             inpaint = self.preprocess(inpaint).unsqueeze(0)
+             if isinstance(masks_rgb[bs], torch.Tensor):
+                 mask_rgb = self.toImage(masks_rgb[bs])
+             else:
+                 mask_rgb = masks_rgb[bs]
+             mask_rgb = self.preprocess(mask_rgb).unsqueeze(0)
+             inpt, msk = inpaint.to(self.device), mask_rgb.to(self.device)
+             # with torch.no_grad():
+             inpt_embeds = self.clip_model.encode_image(inpt).to(torch.float32)
+             msk_embeds = self.clip_model.encode_image(msk).to(torch.float32)
+
+             inpaint_embeds_bs.append(inpt_embeds.squeeze(0))
+             mask_rgb_embeds_bs.append(msk_embeds.squeeze(0))
+
+         emb_inpaint = torch.stack(inpaint_embeds_bs, dim=0)
+         emb_mask_rgb = torch.stack(mask_rgb_embeds_bs, dim=0)
+
+         # Concatenate inpaint and mask embeddings, add a sequence dim for the ViT block.
+         emb_feature = torch.cat((emb_inpaint, emb_mask_rgb), dim=-1)
+         emb_feature = emb_feature.unsqueeze(1)
+         emb_feature = self.vit_block(emb_feature)  # 1024
+
+         scores, last_features = self.mlp(emb_feature)
+         scores = torch.squeeze(scores)
+         last_features = torch.squeeze(last_features)
+
+         if self.config.group:
+             scores = (scores - self.mean) / self.std
+
+         return scores.detach().cpu().numpy().tolist(), last_features.detach().cpu().numpy().tolist()
+
+
+     def load_model(self, model, ckpt_path=None):
+
+         print('load checkpoint from %s' % ckpt_path)
+         state_dict = torch.load(ckpt_path, map_location='cpu')
+         # Strip DataParallel/DistributedDataParallel "module." prefixes.
+         new_dict = {key.replace("module.", ""): value for key, value in state_dict.items()}
+         msg = model.load_state_dict(new_dict)
+         # checkpoint = torch.load(ckpt_path, map_location='cpu')
+         # state_dict = checkpoint
+         # msg = model.load_state_dict(state_dict, strict=False)
+
+         return model
+
+
+ class ImageRewardGroup(nn.Module):
+     def __init__(self, config, device='cpu'):
+         super().__init__()
+         self.config = config
+         self.device = device
+
+         self.clip_model, self.preprocess = clip.load("ViT-B/32", device="cuda")  # clip.load(config['clip_model'], device="cuda" if torch.cuda.is_available() else "cpu")
+
+         self.clip_model = self.clip_model.float()
+         self.mlp = MLP(config['ImageReward']['mlp_dim'])
+         self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
+
+         if self.config.fix_base:
+             self.clip_model.requires_grad_(False)
+
+         for name, parms in self.clip_model.named_parameters():
+             if '_proj' in name:
+                 parms.requires_grad_(False)
+
+         # fix certain ratio of layers
+         self.image_layer_num = 12
+         if self.config.fix_base > 0:
+             image_fix_num = "resblocks.{}".format(int(self.image_layer_num * self.config.fix_base))
+             for name, parms in self.clip_model.visual.named_parameters():
+                 parms.requires_grad_(False)
+                 if image_fix_num in name:
+                     break
+
+
+     def loose_layer(self, fix_rate):
+         # NOTE: this references self.blip, which is not defined in this CLIP-based class;
+         # it appears to be carried over from the BLIP-based ImageReward implementation.
+         text_layer_id = [f"layer.{id}" for id in range(int(12 * fix_rate), 13)]
+         image_layer_id = [f"blocks.{id}" for id in range(int(24 * fix_rate), 25)]
+         for name, parms in self.blip.text_encoder.named_parameters():
+             for text_id in text_layer_id:
+                 if text_id in name:
+                     parms.requires_grad_(True)
+         for name, parms in self.blip.visual_encoder.named_parameters():
+             for image_id in image_layer_id:
+                 if image_id in name:
+                     parms.requires_grad_(True)
+
+
+     def forward(self, batch_data):
+
+         b_emb_inpt, b_emb_msk, w_emb_inpt, w_emb_msk = self.encode_pair(batch_data)
+         # forward
+         b_emb_feature = torch.cat((b_emb_inpt, b_emb_msk), dim=-1)
+         b_emb_feature = self.vit_block(b_emb_feature)  # 1024
+         w_emb_feature = torch.cat((w_emb_inpt, w_emb_msk), dim=-1)
+         w_emb_feature = self.vit_block(w_emb_feature)  # 1024
+
+         # MLP returns (score, features); keep only the scores here.
+         reward_better, _ = self.mlp(b_emb_feature)
+         reward_worse, _ = self.mlp(w_emb_feature)
+         reward_better = reward_better.squeeze(-1)
+         reward_worse = reward_worse.squeeze(-1)
+         reward = torch.concat((reward_better, reward_worse), dim=1)
+
+         return reward
+
+
+     def encode_pair(self, batch_data):
+         better_inpaint_embeds_bs, better_mask_rgb_embeds_bs = [], []
+         worse_inpaint_embeds_bs, worse_mask_rgb_embeds_bs = [], []
+         for bs in range(len(batch_data)):
+             better_inpt, better_msk = batch_data[bs]['better_inpt'], batch_data[bs]['better_msk']
+             better_inpt, better_msk = better_inpt.to(self.device), better_msk.to(self.device)
+
+             worse_inpt, worse_msk = batch_data[bs]['worse_inpt'], batch_data[bs]['worse_msk']
+             worse_inpt, worse_msk = worse_inpt.to(self.device), worse_msk.to(self.device)
+             # with torch.no_grad():
+             better_inpaint_embeds = self.clip_model.encode_image(better_inpt).to(torch.float32)
+             better_mask_rgb_embeds = self.clip_model.encode_image(better_msk).to(torch.float32)
+             worse_inpaint_embeds = self.clip_model.encode_image(worse_inpt).to(torch.float32)
+             worse_mask_rgb_embeds = self.clip_model.encode_image(worse_msk).to(torch.float32)
+
+             better_inpaint_embeds_bs.append(better_inpaint_embeds)
+             better_mask_rgb_embeds_bs.append(better_mask_rgb_embeds)
+             worse_inpaint_embeds_bs.append(worse_inpaint_embeds)
+             worse_mask_rgb_embeds_bs.append(worse_mask_rgb_embeds)
+
+         b_inpt = torch.stack(better_inpaint_embeds_bs, dim=0)
+         b_msk = torch.stack(better_mask_rgb_embeds_bs, dim=0)
+         w_inpt = torch.stack(worse_inpaint_embeds_bs, dim=0)
+         w_msk = torch.stack(worse_mask_rgb_embeds_bs, dim=0)
+
+         return b_inpt, b_msk, w_inpt, w_msk
+
+
+     def load_model(self, model, ckpt_path=None):
+
+         print('load checkpoint from %s' % ckpt_path)
+         checkpoint = torch.load(ckpt_path, map_location='cpu')
+         state_dict = checkpoint
+         msg = model.load_state_dict(state_dict, strict=False)
+         print("missing keys:", msg.missing_keys)
+
+         return model
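
For context, a minimal usage sketch of the uploaded scorer. The config keys mirror the ones read in the code above (ImageReward.mlp_dim, ViT.feature_dim / num_heads / mlp_dim, fix_base, group); the OmegaConf config object, the checkpoint filename, and the image paths are illustrative assumptions and are not part of this commit.

# Minimal usage sketch (assumed config values and paths; OmegaConf is used because
# the model reads its config both dict-style and attribute-style).
import torch
from PIL import Image
from omegaconf import OmegaConf
from InpaintReward import ImageReward

config = OmegaConf.create({
    "ImageReward": {"mlp_dim": 1024},  # two 512-d CLIP ViT-B/32 embeddings, concatenated
    "ViT": {"feature_dim": 1024, "num_heads": 8, "mlp_dim": 2048},
    "fix_base": True,   # freeze the CLIP backbone
    "group": False,     # set True to standardize scores with the stored mean/std
})

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ImageReward(config, device=device).to(device)
model = model.load_model(model, ckpt_path="InpaintReward.pt")  # hypothetical checkpoint file
model.eval()

inpaints = [Image.open("inpaint_0.png").convert("RGB")]   # inpainted results
masks_rgb = [Image.open("mask_0.png").convert("RGB")]      # corresponding RGB mask images

with torch.no_grad():
    scores, features = model.score(inpaints, masks_rgb)
print(scores)

The same checkpoint-loading pattern would apply to ImageRewardGroup, whose forward() consumes better/worse pairs for preference training rather than single images.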