import os

import torch

from model.backbone import ResEncUnet
from model.shader import CINN
from model.decoder_small import RGBADecoderNet

# Every sub-network and every input tensor is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def UDPClip(x):
    """Return ``x`` with every element clamped to the closed range [0, 1]."""
    return torch.clamp(x, min=0, max=1)  # NCHW


class CoNR():
    """Bundle of the four sub-networks used by the pipeline.

    The object owns a character-sheet parser (``udpparsernet``), a target
    pose encoder (``target_pose_encoder``), a shader (``shader``) and an RGBA
    decoder (``rgbadecodernet``), and offers helpers to move, wrap, load,
    save and run them together.
    """

    # Attribute names of the sub-networks, in the order in which they are
    # loaded, saved, moved and switched between train/eval mode.
    _NET_NAMES = (
        "udpparsernet",
        "target_pose_encoder",
        "shader",
        "rgbadecodernet",
    )

    # Affine normalisation applied to every tensor fed to the two encoders:
    # (x - shift) / scale.
    # NOTE(review): presumably dataset mean / std -- confirm.
    _INPUT_SHIFT = 0.6
    _INPUT_SCALE = 0.2970

    def __init__(self, args):
        """Build all sub-networks and move them to ``device``.

        Args:
            args: configuration namespace. ``local_rank`` is read here;
                other methods read ``distributed``, ``model_name`` and
                ``test_pose_use_parser_udp``.
        """
        self.args = args

        # Only the process with local rank 0 asks for pretrained backbones.
        self.udpparsernet = ResEncUnet(
            backbone_name='resnet50_danbo',
            classes=4,
            pretrained=(args.local_rank == 0),
            parametric_upsampling=True,
            decoder_filters=(512, 384, 256, 128, 32),
            map_location=device
        )
        self.target_pose_encoder = ResEncUnet(
            backbone_name='resnet18_danbo-4',
            classes=1,
            pretrained=(args.local_rank == 0),
            parametric_upsampling=True,
            decoder_filters=(512, 384, 256, 128, 32),
            map_location=device
        )
        # Channel count of the reference tensor handed to the shader
        # (premultiplied RGB + alpha, see ``shader_forward``).
        self.DIM_SHADER_REFERENCE = 4
        self.shader = CINN(self.DIM_SHADER_REFERENCE)
        self.rgbadecodernet = RGBADecoderNet()
        self.device()
        # Cached output of ``character_parser_forward``; ``None`` means the
        # character sheet has to be parsed on the next ``model_step``.
        self.parser_ckpt = None

    @staticmethod
    def _normalize(x):
        """Apply the shared input normalisation ``(x - 0.6) / 0.2970``."""
        return (x - CoNR._INPUT_SHIFT) / CoNR._INPUT_SCALE

    def _wrap_ddp(self, net, broadcast_buffers, find_unused_parameters=False):
        """Wrap ``net`` in DistributedDataParallel bound to ``args.local_rank``."""
        local_rank = self.args.local_rank
        return torch.nn.parallel.DistributedDataParallel(
            net,
            device_ids=[local_rank],
            output_device=local_rank,
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=find_unused_parameters
        )

    def dist(self):
        """Wrap every sub-network in DDP when ``args.distributed`` is set."""
        if not self.args.distributed:
            return
        # The two encoders do not broadcast buffers and tolerate parameters
        # that receive no gradient; shader and decoder use the opposite
        # settings.
        self.udpparsernet = self._wrap_ddp(
            self.udpparsernet,
            broadcast_buffers=False, find_unused_parameters=True)
        self.target_pose_encoder = self._wrap_ddp(
            self.target_pose_encoder,
            broadcast_buffers=False, find_unused_parameters=True)
        self.shader = self._wrap_ddp(self.shader, broadcast_buffers=True)
        self.rgbadecodernet = self._wrap_ddp(
            self.rgbadecodernet, broadcast_buffers=True)

    def load_model(self, path):
        """Load ``<path>/<network name>.pth`` into each sub-network.

        Args:
            path: directory that contains one ``.pth`` file per sub-network.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        for net_name in self._NET_NAMES:
            state = torch.load(
                '{}/{}.pth'.format(path, net_name), map_location=device)
            getattr(self, net_name).load_state_dict(state)

    def save_model(self, ite_num):
        """Write one checkpoint file per sub-network for iteration ``ite_num``."""
        for net_name in self._NET_NAMES:
            self._save_pth(
                getattr(self, net_name), model_name=net_name, ite_num=ite_num)

    def _save_pth(self, net, model_name, ite_num):
        """Save the state dict of ``net``.

        The file is written to
        ``<cwd>/saved_models/<args.model_name>/checkpoints/itr_<ite_num>/<model_name>.pth``.
        In distributed mode only the process with local rank 0 writes, and
        the state dict is taken from ``net.module``.
        """
        args = self.args
        if args.distributed:
            if args.local_rank != 0:
                # Not the saving rank. Decided explicitly instead of through
                # the truthiness of the state dict.
                return
            state = net.module.state_dict()
        else:
            state = net.state_dict()

        model_dir = os.path.join(
            os.getcwd(), 'saved_models', args.model_name,
            "checkpoints", "itr_%d" % ite_num)
        os.makedirs(model_dir, exist_ok=True)
        torch.save(state, os.path.join(model_dir, model_name + ".pth"))

    def train(self):
        """Switch every sub-network to training mode."""
        for net_name in self._NET_NAMES:
            getattr(self, net_name).train()

    def eval(self):
        """Switch every sub-network to evaluation mode."""
        for net_name in self._NET_NAMES:
            getattr(self, net_name).eval()

    def device(self):
        """Move every sub-network to the module-level ``device``."""
        # Inside a method the bare name ``device`` resolves to the module
        # global, not to this method.
        for net_name in self._NET_NAMES:
            getattr(self, net_name).to(device)

    def data_norm_image(self, data):
        """Move the known tensors of ``data`` to ``device`` and normalise them.

        Label entries are cast to float. Image and mask entries are cast to
        float and divided by 255. ``num_pose_images``,
        ``num_character_images`` and ``num_samples`` are derived from
        dimensions 1 and 0 of the image tensors. ``data`` is modified in
        place and returned.
        """
        # (key, divisor) pairs; labels are left unscaled.
        conversions = (
            ("character_labels", None),
            ("pose_label", None),
            ("pose_images", 255.0),
            ("pose_mask", 255.0),
            ("character_images", 255.0),
            ("character_masks", 255.0),
        )
        with torch.cuda.amp.autocast(enabled=False):
            for name, divisor in conversions:
                if name not in data:
                    continue
                tensor = data[name].to(device, non_blocking=True).float()
                data[name] = tensor if divisor is None else tensor / divisor

            if "pose_images" in data:
                data["num_pose_images"] = data["pose_images"].shape[1]
                data["num_samples"] = data["pose_images"].shape[0]
            if "character_images" in data:
                data["num_character_images"] = data["character_images"].shape[1]
                data["num_samples"] = data["character_images"].shape[0]
            if "pose_images" in data and "character_images" in data:
                # Both inputs must agree on the sample (batch) dimension.
                assert (data["pose_images"].shape[0] ==
                        data["character_images"].shape[0])
        return data

    def reset_charactersheet(self):
        """Drop the cached parser output so the next step re-parses the sheet."""
        self.parser_ckpt = None

    def model_step(self, data, training=False):
        """Run the full forward pipeline and return the prediction dict.

        The character-sheet parser result is cached in ``self.parser_ckpt``
        and reused by later calls until ``reset_charactersheet`` is called.

        Args:
            data: input dict, as produced by ``data_norm_image``.
            training: accepted for interface compatibility; currently
                unused, the networks are always put in eval mode.
        """
        self.eval()
        with torch.cuda.amp.autocast(enabled=False):
            pred = {}
            if self.parser_ckpt is not None:
                pred["parser"] = self.parser_ckpt
            else:
                pred = self.character_parser_forward(data, pred)
                self.parser_ckpt = pred["parser"]
            pred = self.pose_parser_sc_forward(data, pred)
            pred = self.shader_pose_encoder_forward(data, pred)
            pred = self.shader_forward(data, pred)
        return pred

    def shader_forward(self, data, pred=None):
        """Run the shader and the RGBA decoder.

        Reads ``pred["parser"]`` and ``pred["shader"]["target_pose_features"]``.
        Merges the shader output into ``pred["shader"]`` and stores the
        decoded image under ``pred["shader"]["y_weighted_warp_decoded_rgba"]``.

        Raises:
            AssertionError: if no character sheet is present, or if a
                supplied character mask has a mean of 0.95 or more, or on
                an unexpected tensor shape.
        """
        if pred is None:
            # A fresh dict per call; a ``{}`` default would be shared.
            pred = {}
        assert ("num_character_images" in data), "ERROR: No Character Sheet input."
        character_images_rgb_nmchw = data["character_images"]

        # build x_reference_rgb_a_sudp in the draw call
        shader_character_a_nmchw = data.get("character_masks")
        if shader_character_a_nmchw is None:
            # No mask supplied: use channel 3 of the parser prediction.
            shader_character_a_nmchw = pred["parser"]["pred"][:, :, 3:4, :, :]
        else:
            # Averaging over every dimension except 1 gives one mean alpha
            # per character image; a value near 1 means no transparency.
            mean_alpha = torch.mean(shader_character_a_nmchw, (0, 2, 3, 4))
            assert not torch.any(mean_alpha >= 0.95), (
                "ERROR: No transparent area found in the image, PLEASE "
                "separate the foreground of input character sheets. "
                "The website waifucutout.com is recommended to "
                "automatically cut out the foreground.")

        # RGB premultiplied by alpha, then alpha, concatenated on channels.
        x_reference_rgb_a = torch.cat(
            [shader_character_a_nmchw * character_images_rgb_nmchw,
             shader_character_a_nmchw], 2)
        assert (x_reference_rgb_a.shape[2] == self.DIM_SHADER_REFERENCE)

        # build x_reference_features in the draw call
        x_reference_features = pred["parser"]["features"]

        # run cinn shader
        retdic = self.shader(
            pred["shader"]["target_pose_features"],
            x_reference_rgb_a,
            x_reference_features)
        pred["shader"].update(retdic)

        # decode rgba
        dec_out = self.rgbadecodernet(retdic["y_last_remote_features"])
        y_weighted_x_reference_RGB = dec_out[:, 0:3, :, :]
        y_weighted_mask_A = dec_out[:, 3:4, :, :]
        # apply decoded mask to decoded rgb, finishing the draw call
        y_weighted_warp_decoded_rgba = torch.cat(
            (y_weighted_x_reference_RGB * y_weighted_mask_A,
             y_weighted_mask_A), dim=1)
        assert (y_weighted_warp_decoded_rgba.shape[1] == 4)
        assert (y_weighted_warp_decoded_rgba.shape[-1] ==
                character_images_rgb_nmchw.shape[-1])
        pred["shader"]["y_weighted_warp_decoded_rgba"] = y_weighted_warp_decoded_rgba
        return pred

    def character_parser_forward(self, data, pred=None):
        """Parse the character sheet images with ``udpparsernet``.

        Stores the per-level features under ``pred["parser"]["features"]``
        and the clipped prediction (or ``None``) under
        ``pred["parser"]["pred"]``. Returns ``pred`` unchanged when ``data``
        has no character images.
        """
        if pred is None:
            # A fresh dict per call; a ``{}`` default would be shared.
            pred = {}
        if not ("num_character_images" in data and "character_images" in data):
            return pred
        pred["parser"] = {"pred": None}  # create output

        inputs_rgb_nmchw = data["character_images"]
        num_samples = data["num_samples"]
        num_character_images = data["num_character_images"]

        # Fold the sample and image dimensions together for the network.
        inputs_rgb_fchw = inputs_rgb_nmchw.view(
            (num_samples * num_character_images,
             inputs_rgb_nmchw.shape[2],
             inputs_rgb_nmchw.shape[3],
             inputs_rgb_nmchw.shape[4]))

        encoder_out, features = self.udpparsernet(
            self._normalize(inputs_rgb_fchw))

        # Restore the separate sample / image dimensions on every output.
        pred["parser"]["features"] = [
            features_out.view(
                (num_samples, num_character_images,
                 features_out.shape[1],
                 features_out.shape[2],
                 features_out.shape[3]))
            for features_out in features]

        if encoder_out is not None:
            pred["parser"]["pred"] = UDPClip(encoder_out.view(
                (num_samples, num_character_images,
                 encoder_out.shape[1],
                 encoder_out.shape[2],
                 encoder_out.shape[3])))
        return pred

    def pose_parser_sc_forward(self, data, pred=None):
        """Run ``udpparsernet`` on the pose images.

        Stores the clipped prediction for pose image 0 of every sample under
        ``pred["pose_parser"]["pred"]``. Returns ``pred`` unchanged when
        ``data`` has no pose images.
        """
        if pred is None:
            # A fresh dict per call; a ``{}`` default would be shared.
            pred = {}
        if not ("num_pose_images" in data and "pose_images" in data):
            return pred

        inputs_aug_rgb_nmchw = data["pose_images"]
        num_samples = data["num_samples"]
        num_pose_images = data["num_pose_images"]

        # Fold the sample and image dimensions together for the network.
        inputs_aug_rgb_fchw = inputs_aug_rgb_nmchw.view(
            (num_samples * num_pose_images,
             inputs_aug_rgb_nmchw.shape[2],
             inputs_aug_rgb_nmchw.shape[3],
             inputs_aug_rgb_nmchw.shape[4]))

        encoder_out, _ = self.udpparsernet(
            self._normalize(inputs_aug_rgb_fchw))

        encoder_out = encoder_out.view(
            (num_samples, num_pose_images,
             encoder_out.shape[1],
             encoder_out.shape[2],
             encoder_out.shape[3]))

        # apply sigmoid after eval loss
        pred["pose_parser"] = {"pred": UDPClip(encoder_out)[:, 0, :, :, :]}
        return pred

    def shader_pose_encoder_forward(self, data, pred=None):
        """Encode the target pose and store its features for the shader.

        The pose (first three channels) comes from ``data["pose_label"]``
        unless ``args.test_pose_use_parser_udp`` is set or the label is
        missing. The alpha channel comes from ``data["pose_mask"]`` when
        present. Anything missing is taken from
        ``pred["pose_parser"]["pred"]``. Results are written to a fresh
        ``pred["shader"]`` dict.
        """
        if pred is None:
            # A fresh dict per call; a ``{}`` default would be shared.
            pred = {}
        pred["shader"] = {}  # create output

        if "pose_images" in data:
            # Pose image 0 of every sample is kept for reference.
            pred["shader"]["target_gt_rgb"] = data["pose_images"][:, 0, :, :, :]

        shader_target_a = None
        if "pose_mask" in data:
            pred["shader"]["target_gt_a"] = data["pose_mask"]
            shader_target_a = data["pose_mask"]

        shader_target_sudp = None
        if "pose_label" in data:
            shader_target_sudp = data["pose_label"][:, :3, :, :]

        if self.args.test_pose_use_parser_udp:
            # Ignore the supplied label and use the parser prediction.
            shader_target_sudp = None
        if shader_target_sudp is None:
            shader_target_sudp = pred["pose_parser"]["pred"][:, 0:3, :, :]
        if shader_target_a is None:
            shader_target_a = pred["pose_parser"]["pred"][:, 3:4, :, :]

        # build x_target_sudp_a in the draw call
        x_target_sudp_a = torch.cat(
            (shader_target_sudp * shader_target_a, shader_target_a), 1)
        pred["shader"].update({"x_target_sudp_a": x_target_sudp_a})

        _, features = self.target_pose_encoder(
            self._normalize(x_target_sudp_a), ret_parser_out=False)
        pred["shader"]["target_pose_features"] = features
        return pred