File size: 2,931 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
from typing import List, Optional, Tuple, Dict, Callable
import torch
from torch import Tensor
from torch.nn import Module
from tha3.poser.poser import PoseParameterGroup, Poser
from tha3.compute.cached_computation_func import TensorListCachedComputationFunc
class GeneralPoser02(Poser):
    """Poser that lazily loads a set of named modules and delegates posing to a computation function.

    Modules are created on first use by calling the zero-argument loaders in
    ``module_loaders``. Posing builds a ``[image, pose]`` batch and passes it,
    together with the loaded modules and a fresh ``outputs`` dict, to
    ``output_list_func``, which returns the list of output tensors.
    """

    def __init__(self,
                 module_loaders: Dict[str, Callable[[], Module]],
                 device: torch.device,
                 output_length: int,
                 pose_parameters: List[PoseParameterGroup],
                 output_list_func: TensorListCachedComputationFunc,
                 subrect: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None,
                 default_output_index: int = 0,
                 image_size: int = 256,
                 dtype: torch.dtype = torch.float):
        """Store configuration; no module is loaded until ``get_modules()`` is called.

        Args:
            module_loaders: Name -> zero-argument callable returning a ``Module``.
                The loaders are not called here.
            device: Device every loaded module is moved to.
            output_length: Value reported by ``get_output_length()``. Presumably the
                number of tensors ``output_list_func`` returns; this is not checked here.
            pose_parameters: Pose parameter groups. Their arities are summed into
                ``num_parameters``.
            output_list_func: Called as ``output_list_func(modules, [image, pose], {})``
                to produce the list of output tensors.
            subrect: Optional ``((start2, end2), (start3, end3))`` crop. It is applied to
                dims 2 and 3 of the batched image before posing (presumably height and
                width for NCHW input; confirm with callers).
            default_output_index: Index returned by ``pose()`` when none is given.
            image_size: Value reported by ``get_image_size()``.
            dtype: Value reported by ``get_dtype()``. NOTE(review): this class never
                converts the modules to this dtype. The loaders are assumed to already
                produce modules of this dtype; confirm.
        """
        self.dtype = dtype
        self.image_size = image_size
        self.default_output_index = default_output_index
        self.output_list_func = output_list_func
        self.subrect = subrect
        self.pose_parameters = pose_parameters
        self.device = device
        self.module_loaders = module_loaders
        # Populated lazily by get_modules(); reset to None by free().
        self.modules: Optional[Dict[str, Module]] = None
        self.num_parameters = sum(group.get_arity() for group in self.pose_parameters)
        self.output_length = output_length

    def get_image_size(self) -> int:
        """Return the ``image_size`` given at construction."""
        return self.image_size

    def get_modules(self) -> Dict[str, Module]:
        """Return the loaded modules, loading them on first use.

        Each loader in ``module_loaders`` is invoked once. The resulting module is
        moved to ``self.device`` and put in inference mode via ``train(False)``.

        The cache is assigned only after *every* loader has succeeded. If one loader
        raises, no partially filled dict is left behind, so a later call retries the
        full load instead of silently returning an incomplete module set.
        """
        if self.modules is None:
            loaded: Dict[str, Module] = {}
            for key, loader in self.module_loaders.items():
                module = loader()
                module.to(self.device)
                module.train(False)
                loaded[key] = module
            # Publish only once the whole set is ready.
            self.modules = loaded
        return self.modules

    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
        """Return the pose parameter groups given at construction."""
        return self.pose_parameters

    def get_num_parameters(self) -> int:
        """Return the total arity of all pose parameter groups."""
        return self.num_parameters

    def pose(self, image: Tensor, pose: Tensor, output_index: Optional[int] = None) -> Tensor:
        """Pose ``image`` and return a single output tensor.

        Args:
            image: Input image, batched or unbatched (see ``get_posing_outputs``).
            pose: Pose vector, batched or unbatched.
            output_index: Which entry of the output list to return. Defaults to
                ``default_output_index`` when ``None``.
        """
        if output_index is None:
            output_index = self.default_output_index
        output_list = self.get_posing_outputs(image, pose)
        return output_list[output_index]

    def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]:
        """Run the computation function and return its full list of outputs.

        A leading batch dimension is added to a 3-dim ``image`` or a 1-dim ``pose``.
        If ``subrect`` is set, the (batched) image is cropped on dims 2 and 3 before
        being passed on.
        """
        modules = self.get_modules()
        if len(image.shape) == 3:
            image = image.unsqueeze(0)
        if len(pose.shape) == 1:
            pose = pose.unsqueeze(0)
        if self.subrect is not None:
            span2, span3 = self.subrect[0], self.subrect[1]
            image = image[:, :, span2[0]:span2[1], span3[0]:span3[1]]
        batch = [image, pose]
        # A fresh dict is passed on every call. It is presumably the computation's
        # per-call cache of intermediate results (TensorListCachedComputationFunc);
        # verify against that class.
        outputs = {}
        return self.output_list_func(modules, batch, outputs)

    def get_output_length(self) -> int:
        """Return the ``output_length`` given at construction."""
        return self.output_length

    def free(self):
        """Drop the references to the loaded modules.

        The next ``get_modules()`` call reloads them from ``module_loaders``.
        """
        self.modules = None

    def get_dtype(self) -> torch.dtype:
        """Return the ``dtype`` given at construction."""
        return self.dtype
|