|
''' |
|
@File : ImageReward.py |
|
@Time : 2023/02/28 19:53:00 |
|
@Author : Jiazheng Xu |
|
@Contact : xjz22@mails.tsinghua.edu.cn |
|
@Description: CLIP-based reward models (ImageReward, ImageRewardGroup) that score inpainted images together with their RGB masks. |
|
''' |
|
|
|
import os |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import clip |
|
from PIL import Image |
|
from torch.nn import TransformerEncoder, TransformerEncoderLayer |
|
from torchvision import transforms |
|
|
|
|
|
class MLP(nn.Module):
    """Feed-forward scoring head.

    The input goes through ``Linear(input_size, 512) -> Dropout(0.2) ->
    Linear(512, 256)`` (no activation between the linear layers) to produce a
    256-dim feature. A bias-free ``Linear(256, 1)`` then maps that feature to
    a single score.

    Args:
        input_size: size of the last dimension of the input tensor.
    """

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size

        # Construction order is kept fixed so parameter names
        # ("layers.0.*", "layers.2.*", "last_layer.weight") match checkpoints.
        self.layers = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
        )
        self.last_layer = nn.Linear(256, 1, bias=False)
        # Alias to the output weight. Assigning a Parameter attribute also
        # registers it on this module under this name.
        self.last_layer_weight = self.last_layer.weight

        # Re-initialise all weights from N(0, 1/(input_size+1)) and zero the biases.
        init_std = 1.0 / (input_size + 1)
        for block in (self.layers, self.last_layer):
            for pname, tensor in block.named_parameters():
                if 'weight' in pname:
                    nn.init.normal_(tensor, mean=0.0, std=init_std)
                elif 'bias' in pname:
                    nn.init.constant_(tensor, val=0)

    def forward(self, input):
        """Return ``(score, features)``.

        ``features`` is the output of ``self.layers`` (last dim 256), and
        ``score`` is ``self.last_layer(features)`` (last dim 1).
        """
        hidden = self.layers(input)
        return self.last_layer(hidden), hidden
|
|
|
|
|
class ViTBlock(nn.Module):
    """A single batch-first Transformer encoder layer.

    Args:
        feature_dim: model width (``d_model``); last dim of the input.
        num_heads: number of attention heads.
        mlp_dim: hidden size of the feed-forward sub-layer.
        dropout: dropout probability inside the encoder layer.
    """

    def __init__(self, feature_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        layer = TransformerEncoderLayer(
            d_model=feature_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            batch_first=True,
        )
        # NOTE: TransformerEncoder clones the layer it is given, so the
        # parameters stored under `encoder_layer` are not used by forward().
        # The attribute is kept so state-dict keys remain checkpoint-compatible.
        self.encoder_layer = layer
        self.transformer_encoder = TransformerEncoder(layer, num_layers=1)

    def forward(self, x):
        """Run ``x`` (batch-first) through the one-layer encoder."""
        return self.transformer_encoder(x)
|
|
|
|
|
class ImageReward(nn.Module):
    """CLIP-based reward model for scoring inpainted images.

    Each sample is an inpainted image paired with an RGB rendering of its
    mask. Both are embedded with CLIP ViT-B/32, and the two embeddings are
    concatenated. The result passes through a one-layer Transformer block
    (``ViTBlock``) and is scored by an ``MLP`` head.

    Args:
        config: configuration supporting both item access
            (``config['ImageReward']['mlp_dim']``, ``config['ViT'][...]``) and
            attribute access (``config.fix_base``, ``config.group``).
        device: device the CLIP encoder is loaded on. ``score`` also moves its
            inputs here. The MLP / ViT head is not moved by this class.
    """

    def __init__(self, config, device='cpu'):
        super().__init__()
        self.config = config
        self.device = device

        # Load CLIP explicitly on `self.device`. Otherwise clip.load picks its
        # own default (CUDA when available), which may differ from the device
        # score() moves its inputs to.
        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=self.device)
        self.clip_model = self.clip_model.float()
        self.mlp = MLP(self.config['ImageReward']['mlp_dim'])
        self.vit_block = ViTBlock(
            self.config["ViT"]["feature_dim"],
            self.config["ViT"]["num_heads"],
            self.config["ViT"]["mlp_dim"],
        )

        # Converts tensor inputs to PIL images before CLIP preprocessing.
        self.toImage = transforms.ToPILImage()

        # Offset / scale applied to raw scores when `config.group` is set.
        self.mean = 0.4064
        self.std = 2.3021

        if self.config.fix_base:
            self.clip_model.requires_grad_(False)

    def _to_clip_input(self, image):
        """Turn one image into a CLIP input batch of size one on `self.device`.

        Tensors are first converted with `self.toImage`. Every item is then
        run through `self.preprocess` and given a leading batch dimension.
        """
        if isinstance(image, torch.Tensor):
            image = self.toImage(image)
        return self.preprocess(image).unsqueeze(0).to(self.device)

    def score(self, inpaint_list, masks_rgb):
        """Score inpainted images against their RGB mask renderings.

        Args:
            inpaint_list: sequence of inpainted images. Each item is either a
                tensor (converted with ``ToPILImage``) or an image accepted by
                ``self.preprocess``.
            masks_rgb: sequence of mask images, paired by index with
                ``inpaint_list`` and accepting the same item types.

        Returns:
            ``(scores, features)`` converted to Python lists.
            - The results are ``squeeze``d, so a single input pair yields a
              float score and a flat feature list.
            - Empty input yields ``([], [])``.
            - When ``config.group`` is set, scores are normalised as
              ``(score - self.mean) / self.std``.
        """
        inpaint_embeds, mask_embeds = [], []
        # Inference only: outputs are detached and converted to lists, so
        # there is no need to build an autograd graph.
        with torch.no_grad():
            for idx, inpaint in enumerate(inpaint_list):
                inpt = self._to_clip_input(inpaint)
                # Convert the mask through toImage as well, so tensor masks
                # reach CLIP's (PIL-based) preprocess in a supported form.
                msk = self._to_clip_input(masks_rgb[idx])
                inpaint_embeds.append(
                    self.clip_model.encode_image(inpt).to(torch.float32).squeeze(0))
                mask_embeds.append(
                    self.clip_model.encode_image(msk).to(torch.float32).squeeze(0))

            if not inpaint_embeds:
                return [], []

            emb_feature = torch.cat(
                (torch.stack(inpaint_embeds, dim=0), torch.stack(mask_embeds, dim=0)),
                dim=-1,
            )
            # Add a length-1 sequence axis for the batch-first transformer block.
            emb_feature = self.vit_block(emb_feature.unsqueeze(1))
            scores, last_features = self.mlp(emb_feature)

        scores = torch.squeeze(scores)
        last_features = torch.squeeze(last_features)
        if self.config.group:
            scores = (scores - self.mean) / self.std

        return (scores.detach().cpu().numpy().tolist(),
                last_features.detach().cpu().numpy().tolist())

    def load_model(self, model, ckpt_path=None):
        """Strictly load a checkpoint into ``model`` and return it.

        A leading ``"module."`` prefix is stripped from each key. Such prefixes
        are presumably left by a (Distributed)DataParallel wrapper.

        Args:
            model: module to load the weights into (``self`` is not used).
            ckpt_path: path to a file holding a state dict.

        Raises:
            ValueError: if ``ckpt_path`` is None.
        """
        if ckpt_path is None:
            raise ValueError("ckpt_path must be provided")
        print('load checkpoint from %s' % ckpt_path)
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        prefix = "module."
        new_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        # strict=True (the default) raises if keys do not match exactly.
        model.load_state_dict(new_dict)
        return model
|
|
|
|
|
|
|
class ImageRewardGroup(nn.Module):
    """Pairwise (better / worse) training wrapper for the CLIP-based reward model.

    Args:
        config: configuration supporting item access (``config['ImageReward']``,
            ``config['ViT']``) and attribute access (``config.fix_base``).
        device: device the CLIP encoder is loaded on. ``encode_pair`` also
            moves its inputs here.
    """

    def __init__(self, config, device='cpu'):
        super().__init__()
        self.config = config
        self.device = device

        # Load CLIP on the requested device rather than a hard-coded "cuda".
        # This keeps the encoder on the device encode_pair() moves inputs to,
        # and makes the class usable on CPU-only machines.
        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=self.device)
        self.clip_model = self.clip_model.float()
        self.mlp = MLP(config['ImageReward']['mlp_dim'])
        self.vit_block = ViTBlock(
            self.config["ViT"]["feature_dim"],
            self.config["ViT"]["num_heads"],
            self.config["ViT"]["mlp_dim"],
        )

        if self.config.fix_base:
            self.clip_model.requires_grad_(False)

        # Freeze every CLIP parameter whose name contains "_proj".
        # NOTE(review): in CLIP this may also match attention projections
        # (e.g. "attn.in_proj_weight") and "text_projection", not just
        # output heads — confirm this is intended.
        for name, param in self.clip_model.named_parameters():
            if '_proj' in name:
                param.requires_grad_(False)

        # Freeze visual parameters in registration order. Stop right after the
        # first parameter whose name contains "resblocks.<k>", where
        # k = int(image_layer_num * fix_base).
        self.image_layer_num = 12
        if self.config.fix_base > 0:
            image_fix_num = "resblocks.{}".format(int(self.image_layer_num * self.config.fix_base))
            for name, param in self.clip_model.visual.named_parameters():
                param.requires_grad_(False)
                if image_fix_num in name:
                    break

    def loose_layer(self, fix_rate):
        """Unfreeze the upper residual blocks of CLIP's text and visual encoders.

        For each encoder, blocks with index >= ``int(num_blocks * fix_rate)``
        get ``requires_grad=True``. Lower blocks are left untouched.

        Args:
            fix_rate: fraction of blocks, counted from the bottom, that are
                left as they are.
        """
        # This class wraps CLIP (there is no `self.blip`), so address CLIP's
        # text transformer and visual transformer directly; each is a stack
        # of `resblocks`.
        for tower in (self.clip_model.transformer, self.clip_model.visual.transformer):
            blocks = list(tower.resblocks)
            first_trainable = max(0, int(len(blocks) * fix_rate))
            for block in blocks[first_trainable:]:
                block.requires_grad_(True)

    def forward(self, batch_data):
        """Score the better and the worse sample of every pair.

        Args:
            batch_data: sequence of dicts, as consumed by ``encode_pair``.

        Returns:
            Tensor whose last dim holds ``[better_reward, worse_reward]`` per
            pair. For inputs encoding to ``(1, D)`` per sample this is
            presumably ``(B, 2)`` — confirm against the data loader.
        """
        b_emb_inpt, b_emb_msk, w_emb_inpt, w_emb_msk = self.encode_pair(batch_data)

        b_emb_feature = self.vit_block(torch.cat((b_emb_inpt, b_emb_msk), dim=-1))
        w_emb_feature = self.vit_block(torch.cat((w_emb_inpt, w_emb_msk), dim=-1))

        # MLP.forward returns (scores, features); only the scores are needed.
        reward_better, _ = self.mlp(b_emb_feature)
        reward_worse, _ = self.mlp(w_emb_feature)

        return torch.cat((reward_better.squeeze(-1), reward_worse.squeeze(-1)), dim=1)

    def encode_pair(self, batch_data):
        """CLIP-encode the four images of every pair in ``batch_data``.

        Each item must provide ``'better_inpt'``, ``'better_msk'``,
        ``'worse_inpt'`` and ``'worse_msk'`` tensors. They are passed to
        ``encode_image`` as-is, so they presumably already carry a batch dim
        of 1 — confirm against the data loader.

        Returns:
            ``(better_inpaint, better_mask, worse_inpaint, worse_mask)``. Each
            is the float32 per-sample embeddings stacked along a new dim 0.
        """
        keys = ('better_inpt', 'better_msk', 'worse_inpt', 'worse_msk')
        per_key = {key: [] for key in keys}
        for sample in batch_data:
            for key in keys:
                image = sample[key].to(self.device)
                per_key[key].append(self.clip_model.encode_image(image).to(torch.float32))

        return tuple(torch.stack(per_key[key], dim=0) for key in keys)

    def load_model(self, model, ckpt_path=None):
        """Non-strictly load a checkpoint into ``model`` and return it.

        A leading ``"module."`` prefix (presumably from a DataParallel
        wrapper) is stripped from each key. Missing and unexpected keys are
        printed, because ``strict=False`` would otherwise ignore them silently.

        Raises:
            ValueError: if ``ckpt_path`` is None.
        """
        if ckpt_path is None:
            raise ValueError("ckpt_path must be provided")
        print('load checkpoint from %s' % ckpt_path)
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        prefix = "module."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint.items()
        }
        msg = model.load_state_dict(state_dict, strict=False)
        print("missing keys:", msg.missing_keys)
        print("unexpected keys:", msg.unexpected_keys)
        return model
|
|
|
|
|
|
|
|