Spaces:
Runtime error
Runtime error
import os | |
import torch | |
from model.backbone import ResEncUnet | |
from model.shader import CINN | |
from model.decoder_small import RGBADecoderNet | |
# Module-wide target device: the first CUDA device when torch can see one,
# otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def UDPClip(x):
    """Return ``x`` with every element clamped into the closed range [0, 1].

    The original author annotated the input layout as NCHW; the clamp itself
    is elementwise and works for any shape.
    """
    return x.clamp(min=0, max=1)
class CoNR():
    """Wrapper bundling the four CoNR sub-networks and their forward passes.

    Sub-networks (all moved to the module-level ``device`` on construction):
      * ``udpparsernet`` -- ``ResEncUnet`` (resnet50_danbo, 4 classes), run on
        both the character-sheet images and the pose images;
      * ``target_pose_encoder`` -- ``ResEncUnet`` (resnet18_danbo-4, 1 class),
        run on the masked target-pose map to obtain features for the shader;
      * ``shader`` -- ``CINN`` fed a ``DIM_SHADER_REFERENCE``-channel reference;
      * ``rgbadecodernet`` -- ``RGBADecoderNet`` whose output channels 0:3 are
        RGB and 3:4 are alpha.

    ``args`` must provide the attributes read below: ``local_rank``,
    ``distributed``, ``model_name`` and ``test_pose_use_parser_udp``.
    """

    # Attribute names of the sub-networks, in the order they are handled by
    # load_model / save_model / train / eval / device.
    _SUBNET_NAMES = ("udpparsernet", "target_pose_encoder",
                     "shader", "rgbadecodernet")

    def __init__(self, args):
        self.args = args
        # Only local rank 0 asks for pretrained backbone weights.
        is_main_rank = (args.local_rank == 0)
        self.udpparsernet = ResEncUnet(
            backbone_name='resnet50_danbo',
            classes=4,
            pretrained=is_main_rank,
            parametric_upsampling=True,
            decoder_filters=(512, 384, 256, 128, 32),
            map_location=device,
        )
        self.target_pose_encoder = ResEncUnet(
            backbone_name='resnet18_danbo-4',
            classes=1,
            pretrained=is_main_rank,
            parametric_upsampling=True,
            decoder_filters=(512, 384, 256, 128, 32),
            map_location=device,
        )
        # The shader reference is alpha-premultiplied RGB + alpha (see
        # shader_forward), i.e. 4 channels.
        self.DIM_SHADER_REFERENCE = 4
        self.shader = CINN(self.DIM_SHADER_REFERENCE)
        self.rgbadecodernet = RGBADecoderNet()
        self.device()
        # Cached character-sheet parser output, reused by model_step() until
        # reset_charactersheet() is called.
        self.parser_ckpt = None

    def _wrap_ddp(self, net, **ddp_kwargs):
        """Wrap ``net`` in DistributedDataParallel pinned to ``args.local_rank``."""
        local_rank = self.args.local_rank
        return torch.nn.parallel.DistributedDataParallel(
            net,
            device_ids=[local_rank],
            output_device=local_rank,
            **ddp_kwargs
        )

    def dist(self):
        """Wrap every sub-network in DDP when ``args.distributed`` is set."""
        if not self.args.distributed:
            return
        # The two UNets are wrapped with find_unused_parameters=True and
        # without buffer broadcasting; shader/decoder broadcast buffers.
        self.udpparsernet = self._wrap_ddp(
            self.udpparsernet,
            broadcast_buffers=False,
            find_unused_parameters=True,
        )
        self.target_pose_encoder = self._wrap_ddp(
            self.target_pose_encoder,
            broadcast_buffers=False,
            find_unused_parameters=True,
        )
        self.shader = self._wrap_ddp(self.shader, broadcast_buffers=True)
        self.rgbadecodernet = self._wrap_ddp(
            self.rgbadecodernet, broadcast_buffers=True)

    def load_model(self, path):
        """Load each sub-network's weights from ``<path>/<name>.pth``.

        NOTE(review): ``torch.load`` unpickles the file, so only load
        checkpoints from a trusted source.
        """
        for name in self._SUBNET_NAMES:
            state = torch.load(os.path.join(path, name + '.pth'),
                               map_location=device)
            getattr(self, name).load_state_dict(state)

    def save_model(self, ite_num):
        """Save every sub-network's state dict for iteration ``ite_num``."""
        for name in self._SUBNET_NAMES:
            self._save_pth(getattr(self, name), model_name=name,
                           ite_num=ite_num)

    def _save_pth(self, net, model_name, ite_num):
        """Write ``net``'s state dict to
        ``<cwd>/saved_models/<args.model_name>/checkpoints/itr_<ite_num>/<model_name>.pth``.

        Under distributed training only local rank 0 writes, and the weights
        are taken from the DDP-wrapped ``net.module``.
        """
        args = self.args
        if args.distributed:
            if args.local_rank != 0:
                return
            to_save = net.module.state_dict()
        else:
            to_save = net.state_dict()
        model_dir = os.path.join(os.getcwd(), 'saved_models', args.model_name,
                                 'checkpoints', 'itr_%d' % ite_num)
        os.makedirs(model_dir, exist_ok=True)
        torch.save(to_save, os.path.join(model_dir, model_name + '.pth'))

    def train(self):
        """Put every sub-network into training mode."""
        for name in self._SUBNET_NAMES:
            getattr(self, name).train()

    def eval(self):
        """Put every sub-network into evaluation mode."""
        for name in self._SUBNET_NAMES:
            getattr(self, name).eval()

    def device(self):
        """Move every sub-network to the module-level ``device``.

        Inside this method ``device`` resolves to the module global, not to
        this method.
        """
        for name in self._SUBNET_NAMES:
            getattr(self, name).to(device)

    def data_norm_image(self, data):
        """Move input tensors in ``data`` to ``device`` as float32 and add counts.

        Label tensors are only cast to float; image and mask tensors are also
        divided by 255 (presumably 0..255 input -- confirm with the loader).
        Sets ``num_pose_images`` / ``num_character_images`` from dim 1 and
        ``num_samples`` from dim 0. ``data`` is modified in place and returned.
        """
        with torch.cuda.amp.autocast(enabled=False):
            for name in ("character_labels", "pose_label"):
                if name in data:
                    data[name] = data[name].to(
                        device, non_blocking=True).float()
            for name in ("pose_images", "pose_mask",
                         "character_images", "character_masks"):
                if name in data:
                    data[name] = data[name].to(
                        device, non_blocking=True).float() / 255.0
            if "pose_images" in data:
                data["num_pose_images"] = data["pose_images"].shape[1]
                data["num_samples"] = data["pose_images"].shape[0]
            if "character_images" in data:
                data["num_character_images"] = data["character_images"].shape[1]
                data["num_samples"] = data["character_images"].shape[0]
            if "pose_images" in data and "character_images" in data:
                assert (data["pose_images"].shape[0] ==
                        data["character_images"].shape[0])
        return data

    def reset_charactersheet(self):
        """Drop the cached character-sheet parser output."""
        self.parser_ckpt = None

    def model_step(self, data, training=False):
        """Run the full forward pipeline on ``data`` and return the pred dict.

        The character-sheet parser result is computed once and cached in
        ``self.parser_ckpt`` until reset_charactersheet(). ``training`` is
        accepted for compatibility but unused: the networks are always put
        into eval mode here.
        """
        self.eval()
        with torch.cuda.amp.autocast(enabled=False):
            pred = {}
            if self.parser_ckpt is not None:
                pred["parser"] = self.parser_ckpt
            else:
                pred = self.character_parser_forward(data, pred)
                self.parser_ckpt = pred["parser"]
            pred = self.pose_parser_sc_forward(data, pred)
            pred = self.shader_pose_encoder_forward(data, pred)
            pred = self.shader_forward(data, pred)
        return pred

    def shader_forward(self, data, pred=None):
        """Run the CINN shader and RGBA decoder; store results in ``pred["shader"]``.

        Requires ``pred["parser"]`` and ``pred["shader"]["target_pose_features"]``
        from the earlier passes. The reference alpha comes from
        ``data["character_masks"]`` when present, otherwise from channel 3 of
        the parser prediction.
        """
        pred = {} if pred is None else pred
        assert ("num_character_images" in data), "ERROR: No Character Sheet input."
        character_images_rgb_nmchw = data["character_images"]
        # Resolve the reference alpha *before* validating it, so the parser
        # fallback is actually reachable when no masks were supplied.
        shader_character_a_nmchw = data.get("character_masks")
        if shader_character_a_nmchw is None:
            shader_character_a_nmchw = pred["parser"]["pred"][:, :, 3:4, :, :]
        # Mean alpha per character image (reduced over every dim except 1);
        # a sheet that is >= 95% opaque has no usable background cut-out.
        assert not torch.any(
            torch.mean(shader_character_a_nmchw, (0, 2, 3, 4)) >= 0.95), (
            "ERROR: No transparent area found in the image, PLEASE separate "
            "the foreground of input character sheets. The website "
            "waifucutout.com is recommended to automatically cut out the "
            "foreground.")
        # Reference = alpha-premultiplied RGB + alpha, concatenated on dim 2.
        x_reference_rgb_a = torch.cat(
            [shader_character_a_nmchw * character_images_rgb_nmchw,
             shader_character_a_nmchw], 2)
        assert (x_reference_rgb_a.shape[2] == self.DIM_SHADER_REFERENCE)
        x_reference_features = pred["parser"]["features"]
        retdic = self.shader(
            pred["shader"]["target_pose_features"], x_reference_rgb_a,
            x_reference_features)
        pred["shader"].update(retdic)
        # Decode RGBA and premultiply the decoded RGB by the decoded alpha.
        dec_out = self.rgbadecodernet(retdic["y_last_remote_features"])
        y_weighted_x_reference_RGB = dec_out[:, 0:3, :, :]
        y_weighted_mask_A = dec_out[:, 3:4, :, :]
        y_weighted_warp_decoded_rgba = torch.cat(
            (y_weighted_x_reference_RGB * y_weighted_mask_A, y_weighted_mask_A),
            dim=1,
        )
        assert (y_weighted_warp_decoded_rgba.shape[1] == 4)
        assert (y_weighted_warp_decoded_rgba.shape[-1] ==
                character_images_rgb_nmchw.shape[-1])
        pred["shader"]["y_weighted_warp_decoded_rgba"] = y_weighted_warp_decoded_rgba
        return pred

    def character_parser_forward(self, data, pred=None):
        """Parse the character-sheet images with ``udpparsernet``.

        Stores ``pred["parser"]["features"]`` (each feature map reshaped back
        to ``(num_samples, num_character_images, C, H, W)``) and
        ``pred["parser"]["pred"]`` (the clamped parser output, or None if the
        network returned none). Returns ``pred`` unchanged when no character
        images are present.
        """
        pred = {} if pred is None else pred
        if not ("num_character_images" in data and "character_images" in data):
            return pred
        pred["parser"] = {"pred": None}
        inputs_rgb_nmchw = data["character_images"]
        num_samples = data["num_samples"]
        num_character_images = data["num_character_images"]
        # Flatten (N, M) into one batch dimension for the 2-D network.
        inputs_rgb_fchw = inputs_rgb_nmchw.view(
            (num_samples * num_character_images, inputs_rgb_nmchw.shape[2],
             inputs_rgb_nmchw.shape[3], inputs_rgb_nmchw.shape[4]))
        encoder_out, features = self.udpparsernet(
            (inputs_rgb_fchw - 0.6) / 0.2970)
        pred["parser"]["features"] = [
            features_out.view(
                (num_samples, num_character_images, features_out.shape[1],
                 features_out.shape[2], features_out.shape[3]))
            for features_out in features
        ]
        if encoder_out is not None:
            pred["parser"]["pred"] = UDPClip(encoder_out.view(
                (num_samples, num_character_images, encoder_out.shape[1],
                 encoder_out.shape[2], encoder_out.shape[3])))
        return pred

    def pose_parser_sc_forward(self, data, pred=None):
        """Parse the pose images with ``udpparsernet``.

        Stores ``pred["pose_parser"]["pred"]``: the clamped parser output for
        the first pose image of each sample. Returns ``pred`` unchanged when
        no pose images are present.
        """
        pred = {} if pred is None else pred
        if not ("num_pose_images" in data and "pose_images" in data):
            return pred
        inputs_aug_rgb_nmchw = data["pose_images"]
        num_samples = data["num_samples"]
        num_pose_images = data["num_pose_images"]
        inputs_aug_rgb_fchw = inputs_aug_rgb_nmchw.view(
            (num_samples * num_pose_images, inputs_aug_rgb_nmchw.shape[2],
             inputs_aug_rgb_nmchw.shape[3], inputs_aug_rgb_nmchw.shape[4]))
        encoder_out, _ = self.udpparsernet(
            (inputs_aug_rgb_fchw - 0.6) / 0.2970)
        encoder_out = encoder_out.view(
            (num_samples, num_pose_images, encoder_out.shape[1],
             encoder_out.shape[2], encoder_out.shape[3]))
        pred["pose_parser"] = {"pred": UDPClip(encoder_out)[:, 0, :, :, :]}
        return pred

    def shader_pose_encoder_forward(self, data, pred=None):
        """Encode the target pose for the shader into ``pred["shader"]``.

        The target UDP channels come from ``data["pose_label"][:, :3]`` unless
        absent or ``args.test_pose_use_parser_udp`` is set, in which case the
        pose parser's channels 0:3 are used; the target alpha comes from
        ``data["pose_mask"]`` or else the pose parser's channel 3. Stores
        ``x_target_sudp_a`` (alpha-premultiplied UDP + alpha, concatenated on
        dim 1) and the encoder's ``target_pose_features``.
        """
        pred = {} if pred is None else pred
        pred["shader"] = {}
        if "pose_images" in data:
            pose_images_rgb_nmchw = data["pose_images"]
            target_gt_rgb = pose_images_rgb_nmchw[:, 0, :, :, :]
            pred["shader"]["target_gt_rgb"] = target_gt_rgb
        shader_target_a = None
        if "pose_mask" in data:
            pred["shader"]["target_gt_a"] = data["pose_mask"]
            shader_target_a = data["pose_mask"]
        shader_target_sudp = None
        if "pose_label" in data:
            shader_target_sudp = data["pose_label"][:, :3, :, :]
            if self.args.test_pose_use_parser_udp:
                shader_target_sudp = None
        if shader_target_sudp is None:
            shader_target_sudp = pred["pose_parser"]["pred"][:, 0:3, :, :]
        if shader_target_a is None:
            shader_target_a = pred["pose_parser"]["pred"][:, 3:4, :, :]
        x_target_sudp_a = torch.cat((
            shader_target_sudp * shader_target_a,
            shader_target_a,
        ), 1)
        pred["shader"].update({
            "x_target_sudp_a": x_target_sudp_a
        })
        _, features = self.target_pose_encoder(
            (x_target_sudp_a - 0.6) / 0.2970, ret_parser_out=False)
        pred["shader"]["target_pose_features"] = features
        return pred