File size: 2,931 Bytes
6a62ffb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from typing import List, Optional, Tuple, Dict, Callable

import torch
from torch import Tensor
from torch.nn import Module

from tha3.poser.poser import PoseParameterGroup, Poser
from tha3.compute.cached_computation_func import TensorListCachedComputationFunc


class GeneralPoser02(Poser):
    """Poser that lazily loads named modules and delegates posing to a
    cached computation function.

    The actual computation is done by ``output_list_func``, which receives the
    loaded modules and an ``[image, pose]`` batch and returns a list of output
    tensors. This class handles lazy module loading, batching of unbatched
    inputs, and optional cropping of the input image.
    """

    def __init__(self,
                 module_loaders: Dict[str, Callable[[], Module]],
                 device: torch.device,
                 output_length: int,
                 pose_parameters: List[PoseParameterGroup],
                 output_list_func: TensorListCachedComputationFunc,
                 subrect: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None,
                 default_output_index: int = 0,
                 image_size: int = 256,
                 dtype: torch.dtype = torch.float):
        """
        Args:
            module_loaders: Maps a module name to a zero-argument factory that
                builds the module. Factories are not called here; they run
                lazily on the first ``get_modules`` call.
            device: Device every loaded module is moved to.
            output_length: Value reported by ``get_output_length``. It is not
                checked against the list ``output_list_func`` actually returns.
            pose_parameters: Pose parameter groups; the sum of their
                ``get_arity()`` values is reported by ``get_num_parameters``.
            output_list_func: Called as
                ``output_list_func(modules, batch, outputs)`` to produce the
                list of output tensors.
            subrect: Optional ``((start, stop), (start, stop))`` slice bounds
                applied to dimensions 2 and 3 of the batched image
                (presumably rows and columns of an NCHW tensor — confirm).
            default_output_index: Index into the output list that ``pose``
                returns when no ``output_index`` is given.
            image_size: Value reported by ``get_image_size``.
            dtype: Value reported by ``get_dtype``. This class does not itself
                cast modules or inputs to this dtype.
        """
        self.dtype = dtype
        self.image_size = image_size
        self.default_output_index = default_output_index
        self.output_list_func = output_list_func
        self.subrect = subrect
        self.pose_parameters = pose_parameters
        self.device = device
        self.module_loaders = module_loaders

        # Loaded lazily by get_modules(); reset to None by free().
        self.modules = None

        self.num_parameters = sum(
            pose_parameter.get_arity() for pose_parameter in self.pose_parameters)

        self.output_length = output_length

    def get_image_size(self) -> int:
        """Return the ``image_size`` given at construction."""
        return self.image_size

    def get_modules(self) -> Dict[str, Module]:
        """Return the loaded modules, loading them on the first call.

        Each loader is called once. Its module is moved to ``self.device`` and
        switched out of training mode with ``train(False)``. The result is
        cached until ``free`` is called.

        The cache is only assigned after *every* loader has succeeded. If a
        loader (or ``.to`` / ``.train``) raises, ``self.modules`` stays
        ``None``, and the next call retries instead of returning a partial dict.
        """
        if self.modules is None:
            loaded: Dict[str, Module] = {}
            for key, loader in self.module_loaders.items():
                module = loader()
                module.to(self.device)
                module.train(False)
                loaded[key] = module
            self.modules = loaded
        return self.modules

    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
        """Return the pose parameter groups given at construction."""
        return self.pose_parameters

    def get_num_parameters(self) -> int:
        """Return the total arity of all pose parameter groups."""
        return self.num_parameters

    def pose(self, image: Tensor, pose: Tensor, output_index: Optional[int] = None) -> Tensor:
        """Run the poser and return a single output tensor.

        Args:
            image: Input image, batched or unbatched (see ``get_posing_outputs``).
            pose: Pose vector, batched or unbatched (see ``get_posing_outputs``).
            output_index: Which element of the output list to return. Defaults
                to ``default_output_index``.
        """
        if output_index is None:
            output_index = self.default_output_index
        output_list = self.get_posing_outputs(image, pose)
        return output_list[output_index]

    def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]:
        """Run the poser and return the full list of output tensors.

        A 3-D ``image`` and a 1-D ``pose`` are given a leading batch dimension
        of size 1. If ``subrect`` is set, the batched image is sliced along
        dimensions 2 and 3 before being handed to ``output_list_func``.
        """
        modules = self.get_modules()

        # Promote unbatched inputs to a batch of one.
        if image.dim() == 3:
            image = image.unsqueeze(0)
        if pose.dim() == 1:
            pose = pose.unsqueeze(0)
        if self.subrect is not None:
            # NOTE(review): assumes NCHW, so dims 2/3 are rows/cols — confirm.
            rows, cols = self.subrect
            image = image[:, :, rows[0]:rows[1], cols[0]:cols[1]]
        batch = [image, pose]

        # A fresh dict per call, passed to output_list_func. It is presumably a
        # scratch cache for intermediate results — confirm against
        # TensorListCachedComputationFunc.
        outputs = {}
        return self.output_list_func(modules, batch, outputs)

    def get_output_length(self) -> int:
        """Return the ``output_length`` given at construction."""
        return self.output_length

    def free(self):
        """Drop the cached modules; the next ``get_modules`` call reloads them."""
        self.modules = None

    def get_dtype(self) -> torch.dtype:
        """Return the ``dtype`` given at construction."""
        return self.dtype