kd5678 commited on
Commit
43d2c2f
·
verified ·
1 Parent(s): 23c375d

Delete InpaintReward.py

Browse files
Files changed (1) hide show
  1. InpaintReward.py +0 -266
InpaintReward.py DELETED
@@ -1,266 +0,0 @@
1
- '''
2
- @File : ImageReward.py
3
- @Time : 2023/02/28 19:53:00
4
- @Auther : Jiazheng Xu
5
- @Contact : xjz22@mails.tsinghua.edu.cn
6
- @Description: ImageReward Reward model for reward model.
7
- '''
8
-
9
- import os
10
- import torch
11
- import torch.nn as nn
12
- import torch.nn.functional as F
13
- import clip
14
- from PIL import Image
15
- from torch.nn import TransformerEncoder, TransformerEncoderLayer
16
- from torchvision import transforms
17
-
18
-
19
- class MLP(nn.Module):
20
- def __init__(self, input_size):
21
- super().__init__()
22
- self.input_size = input_size
23
-
24
- self.layers = nn.Sequential(
25
- nn.Linear(self.input_size, 512),
26
- # nn.ReLU(),
27
- nn.Dropout(0.2),
28
- nn.Linear(512, 256),
29
- # # nn.ReLU(),
30
- # nn.Dropout(0.2),
31
- # nn.Linear(256, 128)
32
- # # nn.ReLU(),
33
- # nn.Dropout(0.1),
34
- # nn.Linear(128, 64),
35
- # # nn.ReLU(),
36
- # nn.Linear(64, 1)
37
- )
38
- self.last_layer = nn.Linear(256, 1, bias=False)
39
- self.last_layer_weight = self.last_layer.weight
40
- # initial MLP param
41
- for name, param in self.layers.named_parameters():
42
- if 'weight' in name:
43
- nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
44
- if 'bias' in name:
45
- nn.init.constant_(param, val=0)
46
-
47
- for name, param in self.last_layer.named_parameters():
48
- if 'weight' in name:
49
- nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
50
- if 'bias' in name:
51
- nn.init.constant_(param, val=0)
52
-
53
- def forward(self, input):
54
- features = self.layers(input)
55
- out = self.last_layer(features)
56
- return out, features
57
-
58
-
59
- class ViTBlock(nn.Module):
60
- def __init__(self, feature_dim, num_heads, mlp_dim, dropout=0.1):
61
- super(ViTBlock, self).__init__()
62
- # Transformer encoder layer
63
- self.encoder_layer = TransformerEncoderLayer(
64
- d_model=feature_dim,
65
- nhead=num_heads,
66
- dim_feedforward=mlp_dim,
67
- dropout=dropout,
68
- batch_first=True # Input shape: (batch_size, seq_length, feature_dim)
69
- )
70
- self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=1)
71
-
72
- def forward(self, x):
73
- x = self.transformer_encoder(x)
74
- return x
75
-
76
-
77
- class ImageReward(nn.Module):
78
- def __init__(self, config, device='cpu'):
79
- super().__init__()
80
- self.config = config
81
- self.device = device
82
-
83
- self.clip_model, self.preprocess = clip.load("ViT-B/32") #, device=self.device) #clip.load(config['clip_model'], device="cuda" if torch.cuda.is_available() else "cpu")
84
-
85
- self.clip_model = self.clip_model.float()
86
- self.mlp = MLP(self.config['ImageReward']['mlp_dim'])
87
- self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
88
-
89
- self.toImage = transforms.ToPILImage()
90
-
91
- self.mean = 0.4064 #0.65823
92
- self.std = 2.3021 #8.5400
93
-
94
- if self.config.fix_base:
95
- self.clip_model.requires_grad_(False)
96
-
97
- # for name, parms in self.clip_model.named_parameters():
98
- # if '_proj' in name:
99
- # parms.requires_grad_(False)
100
-
101
- # # fix certain ratio of layers
102
- # self.image_layer_num = 12
103
- # if self.config.fix_rate > 0:
104
- # image_fix_num = "resblocks.{}".format(int(self.image_layer_num * self.config.fix_rate))
105
- # for name, parms in self.clip_model.visual.named_parameters():
106
- # parms.requires_grad_(False)
107
- # if image_fix_num in name:
108
- # break
109
-
110
-
111
- def score(self, inpaint_list, masks_rgb):
112
-
113
- inpaint_embeds_bs, mask_rgb_embeds_bs = [], []
114
-
115
- for bs in range(len(inpaint_list)):
116
- if isinstance(inpaint_list[bs], torch.Tensor):
117
- inpaint = self.toImage(inpaint_list[bs])
118
- else:
119
- inpaint = inpaint_list[bs]
120
- inpaint = self.preprocess(inpaint).unsqueeze(0)
121
- if isinstance(masks_rgb[bs], torch.Tensor):
122
- mask_rgb = self.toImage(masks_rgb[bs])
123
- else:
124
- mask_rgb = masks_rgb[bs]
125
- mask_rgb = self.preprocess(masks_rgb[bs]).unsqueeze(0)
126
- inpt, msk = inpaint.to(self.device), mask_rgb.to(self.device)
127
- # with torch.no_grad():
128
- inpt_embeds = self.clip_model.encode_image(inpt).to(torch.float32)
129
- msk_embeds = self.clip_model.encode_image(msk).to(torch.float32)
130
-
131
- inpaint_embeds_bs.append(inpt_embeds.squeeze(0))
132
- mask_rgb_embeds_bs.append(msk_embeds.squeeze(0))
133
-
134
-
135
- emb_inpaint = torch.stack(inpaint_embeds_bs, dim=0)
136
- emb_mask_rgb = torch.stack(mask_rgb_embeds_bs, dim=0)
137
-
138
- emb_feature = torch.cat((emb_inpaint, emb_mask_rgb), dim=-1)
139
- emb_feature = emb_feature.unsqueeze(1)
140
- emb_feature = self.vit_block(emb_feature) # 1024
141
-
142
- scores, last_features = self.mlp(emb_feature)
143
- scores = torch.squeeze(scores)
144
- last_features = torch.squeeze(last_features)
145
-
146
- if self.config.group:
147
- scores = (scores - self.mean) / self.std
148
-
149
-
150
- return scores.detach().cpu().numpy().tolist(), last_features.detach().cpu().numpy().tolist()
151
-
152
-
153
-
154
- def load_model(self, model, ckpt_path = None):
155
-
156
- print('load checkpoint from %s'%ckpt_path)
157
- state_dict = {k: v for k, v in torch.load(ckpt_path, map_location='cpu').items()}
158
- new_dict = {key.replace("module.", ""): value for key, value in state_dict.items()}
159
- msg = model.load_state_dict(new_dict)
160
- # checkpoint = torch.load(ckpt_path, map_location='cpu')
161
- # state_dict = checkpoint
162
- # msg = model.load_state_dict(state_dict,strict=False)
163
-
164
- return model
165
-
166
-
167
-
168
- class ImageRewardGroup(nn.Module):
169
- def __init__(self, config, device='cpu'):
170
- super().__init__()
171
- self.config = config
172
- self.device = device
173
-
174
- self.clip_model, self.preprocess = clip.load("ViT-B/32", device="cuda") #clip.load(config['clip_model'], device="cuda" if torch.cuda.is_available() else "cpu")
175
-
176
- self.clip_model = self.clip_model.float()
177
- self.mlp = MLP(config['ImageReward']['mlp_dim'])
178
- self.vit_block = ViTBlock(self.config["ViT"]["feature_dim"], self.config["ViT"]["num_heads"], self.config["ViT"]["mlp_dim"])
179
-
180
- if self.config.fix_base:
181
- self.clip_model.requires_grad_(False)
182
-
183
- for name, parms in self.clip_model.named_parameters():
184
- if '_proj' in name:
185
- parms.requires_grad_(False)
186
-
187
- # fix certain ratio of layers
188
- self.image_layer_num = 12
189
- if self.config.fix_base > 0:
190
- image_fix_num = "resblocks.{}".format(int(self.image_layer_num * self.config.fix_base))
191
- for name, parms in self.clip_model.visual.named_parameters():
192
- parms.requires_grad_(False)
193
- if image_fix_num in name:
194
- break
195
-
196
-
197
- def loose_layer(self, fix_rate):
198
- text_layer_id = [f"layer.{id}" for id in range(int(12 * fix_rate), 13)]
199
- image_layer_id = [f"blocks.{id}" for id in range(int(24 * fix_rate), 25)]
200
- for name, parms in self.blip.text_encoder.named_parameters():
201
- for text_id in text_layer_id:
202
- if text_id in name:
203
- parms.requires_grad_(True)
204
- for name, parms in self.blip.visual_encoder.named_parameters():
205
- for image_id in image_layer_id:
206
- if image_id in name:
207
- parms.requires_grad_(True)
208
-
209
-
210
- def forward(self, batch_data):
211
-
212
- b_emb_inpt, b_emb_msk, w_emb_inpt, w_emb_msk = self.encode_pair(batch_data) # Nan
213
- # forward
214
- b_emb_feature = torch.cat((b_emb_inpt, b_emb_msk), dim=-1)
215
- b_emb_feature = self.vit_block(b_emb_feature) # 1024
216
- w_emb_feature = torch.cat((w_emb_inpt, w_emb_msk), dim=-1)
217
- w_emb_feature = self.vit_block(w_emb_feature) # 1024
218
-
219
- reward_better = self.mlp(b_emb_feature).squeeze(-1)
220
- reward_worse = self.mlp(w_emb_feature).squeeze(-1)
221
- reward = torch.concat((reward_better, reward_worse), dim=1)
222
-
223
- return reward
224
-
225
-
226
- def encode_pair(self, batch_data):
227
- better_inpaint_embeds_bs, better_mask_rgb_embeds_bs = [], []
228
- worse_inpaint_embeds_bs, worse_mask_rgb_embeds_bs = [], []
229
- for bs in range(len(batch_data)):
230
- better_inpt, better_msk = batch_data[bs]['better_inpt'], batch_data[bs]['better_msk']
231
- better_inpt, better_msk = better_inpt.to(self.device), better_msk.to(self.device)
232
-
233
- worse_inpt, worse_msk = batch_data[bs]['worse_inpt'], batch_data[bs]['worse_msk']
234
- worse_inpt, worse_msk = worse_inpt.to(self.device), worse_msk.to(self.device)
235
- # with torch.no_grad():
236
- better_inpaint_embeds = self.clip_model.encode_image(better_inpt).to(torch.float32)
237
- better_mask_rgb_embeds = self.clip_model.encode_image(better_msk).to(torch.float32)
238
- worse_inpaint_embeds = self.clip_model.encode_image(worse_inpt).to(torch.float32)
239
- worse_mask_rgb_embeds = self.clip_model.encode_image(worse_msk).to(torch.float32)
240
-
241
- better_inpaint_embeds_bs.append(better_inpaint_embeds)
242
- better_mask_rgb_embeds_bs.append(better_mask_rgb_embeds)
243
- worse_inpaint_embeds_bs.append(worse_inpaint_embeds)
244
- worse_mask_rgb_embeds_bs.append(worse_mask_rgb_embeds)
245
-
246
- b_inpt = torch.stack(better_inpaint_embeds_bs, dim=0)
247
- b_msk = torch.stack(better_mask_rgb_embeds_bs, dim=0)
248
- w_inpt = torch.stack(worse_inpaint_embeds_bs, dim=0)
249
- w_msk = torch.stack(worse_mask_rgb_embeds_bs, dim=0)
250
-
251
-
252
- return b_inpt, b_msk, w_inpt, w_msk
253
-
254
-
255
- def load_model(self, model, ckpt_path = None):
256
-
257
- print('load checkpoint from %s'%ckpt_path)
258
- checkpoint = torch.load(ckpt_path, map_location='cpu')
259
- state_dict = checkpoint
260
- msg = model.load_state_dict(state_dict,strict=False)
261
- print("missing keys:", msg.missing_keys)
262
-
263
- return model
264
-
265
-
266
-