from enum import Enum from typing import List, Dict, Optional import torch from torch import Tensor from torch.nn import Module from torch.nn.functional import interpolate from tha3.nn.eyebrow_decomposer.eyebrow_decomposer_03 import EyebrowDecomposer03Factory, \ EyebrowDecomposer03Args, EyebrowDecomposer03 from tha3.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_03 import \ EyebrowMorphingCombiner03Factory, EyebrowMorphingCombiner03Args, EyebrowMorphingCombiner03 from tha3.nn.face_morpher.face_morpher_09 import FaceMorpher09Factory, FaceMorpher09Args from tha3.poser.general_poser_02 import GeneralPoser02 from tha3.poser.poser import PoseParameterCategory, PoseParameters from tha3.nn.editor.editor_07 import Editor07, Editor07Args from tha3.nn.two_algo_body_rotator.two_algo_face_body_rotator_05 import TwoAlgoFaceBodyRotator05, \ TwoAlgoFaceBodyRotator05Args from tha3.util import torch_load from tha3.compute.cached_computation_func import TensorListCachedComputationFunc from tha3.compute.cached_computation_protocol import CachedComputationProtocol from tha3.nn.nonlinearity_factory import ReLUFactory, LeakyReLUFactory from tha3.nn.normalization import InstanceNorm2dFactory from tha3.nn.util import BlockArgs class Network(Enum): eyebrow_decomposer = 1 eyebrow_morphing_combiner = 2 face_morpher = 3 two_algo_face_body_rotator = 4 editor = 5 @property def outputs_key(self): return f"{self.name}_outputs" class Branch(Enum): face_morphed_half = 1 face_morphed_full = 2 all_outputs = 3 NUM_EYEBROW_PARAMS = 12 NUM_FACE_PARAMS = 27 NUM_ROTATION_PARAMS = 6 class FiveStepPoserComputationProtocol(CachedComputationProtocol): def __init__(self, eyebrow_morphed_image_index: int): super().__init__() self.eyebrow_morphed_image_index = eyebrow_morphed_image_index self.cached_batch_0 = None self.cached_eyebrow_decomposer_output = None def compute_func(self) -> TensorListCachedComputationFunc: def func(modules: Dict[str, Module], batch: List[Tensor], outputs: Dict[str, List[Tensor]]): if self.cached_batch_0 is None: new_batch_0 = True elif batch[0].shape[0] != self.cached_batch_0.shape[0]: new_batch_0 = True else: new_batch_0 = torch.max((batch[0] - self.cached_batch_0).abs()).item() > 0 if not new_batch_0: outputs[Network.eyebrow_decomposer.outputs_key] = self.cached_eyebrow_decomposer_output output = self.get_output(Branch.all_outputs.name, modules, batch, outputs) if new_batch_0: self.cached_batch_0 = batch[0] self.cached_eyebrow_decomposer_output = outputs[Network.eyebrow_decomposer.outputs_key] return output return func def compute_output(self, key: str, modules: Dict[str, Module], batch: List[Tensor], outputs: Dict[str, List[Tensor]]) -> List[Tensor]: if key == Network.eyebrow_decomposer.outputs_key: input_image = batch[0][:, :, 64:192, 64 + 128:192 + 128] return modules[Network.eyebrow_decomposer.name].forward(input_image) elif key == Network.eyebrow_morphing_combiner.outputs_key: eyebrow_decomposer_output = self.get_output(Network.eyebrow_decomposer.outputs_key, modules, batch, outputs) background_layer = eyebrow_decomposer_output[EyebrowDecomposer03.BACKGROUND_LAYER_INDEX] eyebrow_layer = eyebrow_decomposer_output[EyebrowDecomposer03.EYEBROW_LAYER_INDEX] eyebrow_pose = batch[1][:, :NUM_EYEBROW_PARAMS] return modules[Network.eyebrow_morphing_combiner.name].forward( background_layer, eyebrow_layer, eyebrow_pose) elif key == Network.face_morpher.outputs_key: eyebrow_morphing_combiner_output = self.get_output( Network.eyebrow_morphing_combiner.outputs_key, modules, batch, outputs) eyebrow_morphed_image = eyebrow_morphing_combiner_output[self.eyebrow_morphed_image_index] input_image = batch[0][:, :, 32:32 + 192, (32 + 128):(32 + 192 + 128)].clone() input_image[:, :, 32:32 + 128, 32:32 + 128] = eyebrow_morphed_image face_pose = batch[1][:, NUM_EYEBROW_PARAMS:NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS] return modules[Network.face_morpher.name].forward(input_image, face_pose) elif key == Branch.face_morphed_full.name: face_morpher_output = self.get_output(Network.face_morpher.outputs_key, modules, batch, outputs) face_morphed_image = face_morpher_output[0] input_image = batch[0].clone() input_image[:, :, 32:32 + 192, 32 + 128:32 + 192 + 128] = face_morphed_image return [input_image] elif key == Branch.face_morphed_half.name: face_morphed_full = self.get_output(Branch.face_morphed_full.name, modules, batch, outputs)[0] return [ interpolate(face_morphed_full, size=(256, 256), mode='bilinear', align_corners=False) ] elif key == Network.two_algo_face_body_rotator.outputs_key: face_morphed_half = self.get_output(Branch.face_morphed_half.name, modules, batch, outputs)[0] rotation_pose = batch[1][:, NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS:] return modules[Network.two_algo_face_body_rotator.name].forward(face_morphed_half, rotation_pose) elif key == Network.editor.outputs_key: input_original_image = self.get_output(Branch.face_morphed_full.name, modules, batch, outputs)[0] rotator_outputs = self.get_output( Network.two_algo_face_body_rotator.outputs_key, modules, batch, outputs) half_warped_image = rotator_outputs[TwoAlgoFaceBodyRotator05.WARPED_IMAGE_INDEX] full_warped_image = interpolate( half_warped_image, size=(512, 512), mode='bilinear', align_corners=False) half_grid_change = rotator_outputs[TwoAlgoFaceBodyRotator05.GRID_CHANGE_INDEX] full_grid_change = interpolate( half_grid_change, size=(512, 512), mode='bilinear', align_corners=False) rotation_pose = batch[1][:, NUM_EYEBROW_PARAMS + NUM_FACE_PARAMS:] return modules[Network.editor.name].forward( input_original_image, full_warped_image, full_grid_change, rotation_pose) elif key == Branch.all_outputs.name: editor_output = self.get_output(Network.editor.outputs_key, modules, batch, outputs) rotater_output = self.get_output(Network.two_algo_face_body_rotator.outputs_key, modules, batch, outputs) face_morpher_output = self.get_output(Network.face_morpher.outputs_key, modules, batch, outputs) eyebrow_morphing_combiner_output = self.get_output( Network.eyebrow_morphing_combiner.outputs_key, modules, batch, outputs) eyebrow_decomposer_output = self.get_output( Network.eyebrow_decomposer.outputs_key, modules, batch, outputs) output = editor_output \ + rotater_output \ + face_morpher_output \ + eyebrow_morphing_combiner_output \ + eyebrow_decomposer_output return output else: raise RuntimeError("Unsupported key: " + key) def load_eyebrow_decomposer(file_name: str): factory = EyebrowDecomposer03Factory( EyebrowDecomposer03Args( image_size=128, image_channels=4, start_channels=64, bottleneck_image_size=16, num_bottleneck_blocks=6, max_channels=512, block_args=BlockArgs( initialization_method='he', use_spectral_norm=False, normalization_layer_factory=InstanceNorm2dFactory(), nonlinearity_factory=ReLUFactory(inplace=True)))) print("Loading the eyebrow decomposer ... ", end="") module = factory.create().half() module.load_state_dict(torch_load(file_name)) print("DONE!!!") return module def load_eyebrow_morphing_combiner(file_name: str): factory = EyebrowMorphingCombiner03Factory( EyebrowMorphingCombiner03Args( image_size=128, image_channels=4, start_channels=64, num_pose_params=12, bottleneck_image_size=16, num_bottleneck_blocks=6, max_channels=512, block_args=BlockArgs( initialization_method='he', use_spectral_norm=False, normalization_layer_factory=InstanceNorm2dFactory(), nonlinearity_factory=ReLUFactory(inplace=True)))) print("Loading the eyebrow morphing conbiner ... ", end="") module = factory.create().half() module.load_state_dict(torch_load(file_name)) print("DONE!!!") return module def load_face_morpher(file_name: str): factory = FaceMorpher09Factory( FaceMorpher09Args( image_size=192, image_channels=4, num_pose_params=27, start_channels=64, bottleneck_image_size=24, num_bottleneck_blocks=6, max_channels=512, block_args=BlockArgs( initialization_method='he', use_spectral_norm=False, normalization_layer_factory=InstanceNorm2dFactory(), nonlinearity_factory=ReLUFactory(inplace=False)))) print("Loading the face morpher ... ", end="") module = factory.create().half() module.load_state_dict(torch_load(file_name)) print("DONE!!!") return module def load_two_algo_generator(file_name) -> Module: module = TwoAlgoFaceBodyRotator05( TwoAlgoFaceBodyRotator05Args( image_size=256, image_channels=4, start_channels=64, num_pose_params=6, bottleneck_image_size=32, num_bottleneck_blocks=6, max_channels=512, upsample_mode='nearest', use_separable_convolution=True, block_args=BlockArgs( initialization_method='he', use_spectral_norm=False, normalization_layer_factory=InstanceNorm2dFactory(), nonlinearity_factory=LeakyReLUFactory(inplace=False, negative_slope=0.1)))).half() print("Loading the face-body rotator ... ", end="") module.load_state_dict(torch_load(file_name)) print("DONE!!!") return module def load_editor(file_name) -> Module: module = Editor07( Editor07Args( image_size=512, image_channels=4, num_pose_params=6, start_channels=32, bottleneck_image_size=64, num_bottleneck_blocks=6, max_channels=512, upsampling_mode='nearest', use_separable_convolution=True, block_args=BlockArgs( initialization_method='he', use_spectral_norm=False, normalization_layer_factory=InstanceNorm2dFactory(), nonlinearity_factory=LeakyReLUFactory(inplace=False, negative_slope=0.1)))).half() print("Loading the combiner ... ", end="") module.load_state_dict(torch_load(file_name)) print("DONE!!!") return module def get_pose_parameters(): return PoseParameters.Builder() \ .add_parameter_group("eyebrow_troubled", PoseParameterCategory.EYEBROW, arity=2) \ .add_parameter_group("eyebrow_angry", PoseParameterCategory.EYEBROW, arity=2) \ .add_parameter_group("eyebrow_lowered", PoseParameterCategory.EYEBROW, arity=2) \ .add_parameter_group("eyebrow_raised", PoseParameterCategory.EYEBROW, arity=2) \ .add_parameter_group("eyebrow_happy", PoseParameterCategory.EYEBROW, arity=2) \ .add_parameter_group("eyebrow_serious", PoseParameterCategory.EYEBROW, arity=2) \ .add_parameter_group("eye_wink", PoseParameterCategory.EYE, arity=2) \ .add_parameter_group("eye_happy_wink", PoseParameterCategory.EYE, arity=2) \ .add_parameter_group("eye_surprised", PoseParameterCategory.EYE, arity=2) \ .add_parameter_group("eye_relaxed", PoseParameterCategory.EYE, arity=2) \ .add_parameter_group("eye_unimpressed", PoseParameterCategory.EYE, arity=2) \ .add_parameter_group("eye_raised_lower_eyelid", PoseParameterCategory.EYE, arity=2) \ .add_parameter_group("iris_small", PoseParameterCategory.IRIS_MORPH, arity=2) \ .add_parameter_group("mouth_aaa", PoseParameterCategory.MOUTH, arity=1, default_value=1.0) \ .add_parameter_group("mouth_iii", PoseParameterCategory.MOUTH, arity=1) \ .add_parameter_group("mouth_uuu", PoseParameterCategory.MOUTH, arity=1) \ .add_parameter_group("mouth_eee", PoseParameterCategory.MOUTH, arity=1) \ .add_parameter_group("mouth_ooo", PoseParameterCategory.MOUTH, arity=1) \ .add_parameter_group("mouth_delta", PoseParameterCategory.MOUTH, arity=1) \ .add_parameter_group("mouth_lowered_corner", PoseParameterCategory.MOUTH, arity=2) \ .add_parameter_group("mouth_raised_corner", PoseParameterCategory.MOUTH, arity=2) \ .add_parameter_group("mouth_smirk", PoseParameterCategory.MOUTH, arity=1) \ .add_parameter_group("iris_rotation_x", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ .add_parameter_group("iris_rotation_y", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ .add_parameter_group("head_x", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ .add_parameter_group("head_y", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ .add_parameter_group("neck_z", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ .add_parameter_group("body_y", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ .add_parameter_group("body_z", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ .add_parameter_group("breathing", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \ .build() def create_poser( device: torch.device, module_file_names: Optional[Dict[str, str]] = None, eyebrow_morphed_image_index: int = EyebrowMorphingCombiner03.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX, default_output_index: int = 0) -> GeneralPoser02: if module_file_names is None: module_file_names = {} if Network.eyebrow_decomposer.name not in module_file_names: dir = "talkinghead/tha3/models/separable_half" file_name = dir + "/eyebrow_decomposer.pt" module_file_names[Network.eyebrow_decomposer.name] = file_name if Network.eyebrow_morphing_combiner.name not in module_file_names: dir = "talkinghead/tha3/models/separable_half" file_name = dir + "/eyebrow_morphing_combiner.pt" module_file_names[Network.eyebrow_morphing_combiner.name] = file_name if Network.face_morpher.name not in module_file_names: dir = "talkinghead/tha3/models/separable_half" file_name = dir + "/face_morpher.pt" module_file_names[Network.face_morpher.name] = file_name if Network.two_algo_face_body_rotator.name not in module_file_names: dir = "talkinghead/tha3/models/separable_half" file_name = dir + "/two_algo_face_body_rotator.pt" module_file_names[Network.two_algo_face_body_rotator.name] = file_name if Network.editor.name not in module_file_names: dir = "talkinghead/tha3/models/separable_half" file_name = dir + "/editor.pt" module_file_names[Network.editor.name] = file_name loaders = { Network.eyebrow_decomposer.name: lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]), Network.eyebrow_morphing_combiner.name: lambda: load_eyebrow_morphing_combiner(module_file_names[Network.eyebrow_morphing_combiner.name]), Network.face_morpher.name: lambda: load_face_morpher(module_file_names[Network.face_morpher.name]), Network.two_algo_face_body_rotator.name: lambda: load_two_algo_generator(module_file_names[Network.two_algo_face_body_rotator.name]), Network.editor.name: lambda: load_editor(module_file_names[Network.editor.name]), } return GeneralPoser02( image_size=512, module_loaders=loaders, pose_parameters=get_pose_parameters().get_pose_parameter_groups(), output_list_func=FiveStepPoserComputationProtocol(eyebrow_morphed_image_index).compute_func(), subrect=None, device=device, output_length=29, dtype=torch.half, default_output_index=default_output_index) if __name__ == "__main__": device = torch.device('cuda') poser = create_poser(device) image = torch.zeros(1, 4, 512, 512, device=device, dtype=torch.half) pose = torch.zeros(1, 45, device=device, dtype=torch.half) repeat = 100 acc = 0.0 for i in range(repeat + 2): start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() poser.pose(image, pose) end.record() torch.cuda.synchronize() if i >= 2: elapsed_time = start.elapsed_time(end) print("%d:" % i, elapsed_time) acc = acc + elapsed_time print("average:", acc / repeat)