from typing import List, Optional, Tuple, Dict, Callable

import torch
from torch import Tensor
from torch.nn import Module

from tha3.poser.poser import PoseParameterGroup, Poser
from tha3.compute.cached_computation_func import TensorListCachedComputationFunc


class GeneralPoser02(Poser):
    """Poser that lazily loads named modules and delegates posing to a computation function.

    Modules are built on first use from ``module_loaders``. Each one is moved to ``device``
    and switched to evaluation mode with ``train(False)``. Posing packs ``[image, pose]``
    into a batch and hands it, together with the loaded modules, to ``output_list_func``.
    """

    def __init__(self,
                 module_loaders: Dict[str, Callable[[], Module]],
                 device: torch.device,
                 output_length: int,
                 pose_parameters: List[PoseParameterGroup],
                 output_list_func: TensorListCachedComputationFunc,
                 subrect: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None,
                 default_output_index: int = 0,
                 image_size: int = 256,
                 dtype: torch.dtype = torch.float):
        """
        Args:
            module_loaders: Maps a module name to a zero-argument factory that builds the
                module. The factories are not called here, only in ``get_modules``.
            device: Device every loaded module is moved to.
            output_length: Value reported by ``get_output_length``.
            pose_parameters: Pose parameter groups. ``get_num_parameters`` reports the sum
                of their ``get_arity()`` values.
            output_list_func: Called as ``output_list_func(modules, [image, pose], {})``.
                Its return value is what ``get_posing_outputs`` returns.
            subrect: Optional crop applied to the input image before posing, given as
                ``((start, end), (start, end))`` slice bounds over dimensions 2 and 3.
                NOTE(review): presumably (rows, cols) of an NCHW image -- confirm.
            default_output_index: Output index ``pose`` uses when none is given.
            image_size: Value reported by ``get_image_size``. Not otherwise used here.
            dtype: Value reported by ``get_dtype``. This class does not itself convert
                modules or inputs to it.
        """
        self.dtype = dtype
        self.image_size = image_size
        self.default_output_index = default_output_index
        self.output_list_func = output_list_func
        self.subrect = subrect
        self.pose_parameters = pose_parameters
        self.device = device
        self.module_loaders = module_loaders
        # Populated lazily by get_modules(); reset to None by free().
        self.modules: Optional[Dict[str, Module]] = None
        self.num_parameters = sum(group.get_arity() for group in self.pose_parameters)
        self.output_length = output_length

    def get_image_size(self) -> int:
        """Return the ``image_size`` given at construction."""
        return self.image_size

    def get_modules(self) -> Dict[str, Module]:
        """Return the loaded modules, building them on the first call.

        Each loader is invoked once. Its module is moved to ``self.device`` and switched
        to evaluation mode with ``train(False)``. The result is cached until ``free`` is
        called.

        The cache is published only after every loader has succeeded. If a loader
        raises, ``self.modules`` stays ``None`` and the next call retries the whole load,
        instead of silently returning a partially populated dict.
        """
        if self.modules is None:
            loaded: Dict[str, Module] = {}
            for key, loader in self.module_loaders.items():
                module = loader()
                module.to(self.device)  # nn.Module.to moves parameters in place
                module.train(False)
                loaded[key] = module
            self.modules = loaded
        return self.modules

    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
        """Return the pose parameter groups given at construction."""
        return self.pose_parameters

    def get_num_parameters(self) -> int:
        """Return the total arity of all pose parameter groups."""
        return self.num_parameters

    def pose(self, image: Tensor, pose: Tensor, output_index: Optional[int] = None) -> Tensor:
        """Pose ``image`` and return one element of ``get_posing_outputs``.

        Args:
            image: Input image tensor. A 3-D tensor gets a leading batch dimension added.
            pose: Pose parameter tensor. A 1-D tensor gets a leading batch dimension added.
            output_index: Index into the output list. Defaults to
                ``self.default_output_index``.

        Returns:
            ``get_posing_outputs(image, pose)[output_index]``.
        """
        if output_index is None:
            output_index = self.default_output_index
        output_list = self.get_posing_outputs(image, pose)
        return output_list[output_index]

    def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]:
        """Run the posing computation and return every output it produces.

        A 3-D ``image`` and a 1-D ``pose`` each get a leading dimension of size 1 added.
        If ``subrect`` was given, the image is then sliced along dimensions 2 and 3.
        The modules, the ``[image, pose]`` batch, and a fresh empty dict are passed to
        ``output_list_func``, and its result is returned unchanged.
        """
        modules = self.get_modules()
        if image.dim() == 3:
            image = image.unsqueeze(0)
        if pose.dim() == 1:
            pose = pose.unsqueeze(0)
        if self.subrect is not None:
            image = image[:, :,
                          self.subrect[0][0]:self.subrect[0][1],
                          self.subrect[1][0]:self.subrect[1][1]]
        batch = [image, pose]
        # Fresh per call; presumably used by output_list_func as its intermediate-result
        # cache -- TODO confirm against TensorListCachedComputationFunc.
        outputs = {}
        return self.output_list_func(modules, batch, outputs)

    def get_output_length(self) -> int:
        """Return the ``output_length`` given at construction."""
        return self.output_length

    def free(self):
        """Drop the cached modules. The next ``get_modules`` call reloads them."""
        self.modules = None

    def get_dtype(self) -> torch.dtype:
        """Return the ``dtype`` given at construction."""
        return self.dtype