File size: 4,581 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
from abc import ABC, abstractmethod
from enum import Enum
from typing import Tuple, List, Optional
import torch
from torch import Tensor
class PoseParameterCategory(Enum):
    """Category tag attached to each PoseParameterGroup.

    Members are numbered consecutively from 1 in declaration order.
    """

    # Facial features.
    EYEBROW = 1
    EYE = 2
    IRIS_MORPH = 3
    IRIS_ROTATION = 4
    MOUTH = 5

    # Head / body movement.
    FACE_ROTATION = 6
    BODY_ROTATION = 7
    BREATHING = 8
class PoseParameterGroup:
    """A named group of one or two pose parameters that share settings.

    A group of arity 1 contributes a single parameter named ``group_name``.
    A group of arity 2 contributes the pair ``<group_name>_left`` and
    ``<group_name>_right``, in that order.

    Args:
        group_name: Base name of the group.
        parameter_index: Flat index of the group's first parameter (the
            builder passes the running total of previous groups' arities).
        category: Category tag for the group.
        arity: Number of parameters in the group; must be 1 or 2.
        discrete: Flag exposed via ``is_discrete()``; not interpreted here.
        default_value: Value exposed via ``get_default_value()``.
        range: Value range, presumably ``(min, max)``; defaults to
            ``(0.0, 1.0)`` when omitted.

    Raises:
        ValueError: if ``arity`` is not 1 or 2.
    """

    def __init__(self,
                 group_name: str,
                 parameter_index: int,
                 category: PoseParameterCategory,
                 arity: int = 1,
                 discrete: bool = False,
                 default_value: float = 0.0,
                 range: Optional[Tuple[float, float]] = None):
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, and an arity of e.g. 3 would then silently yield only
        # two names while reporting arity 3, desynchronising name lookup from
        # the parameter count.
        if arity not in (1, 2):
            raise ValueError("arity must be 1 or 2, got %r" % (arity,))
        if range is None:
            range = (0.0, 1.0)
        if arity == 1:
            parameter_names = [group_name]
        else:
            parameter_names = [group_name + "_left", group_name + "_right"]
        self.parameter_names = parameter_names
        self.range = range
        self.default_value = default_value
        self.discrete = discrete
        self.arity = arity
        self.category = category
        self.parameter_index = parameter_index
        self.group_name = group_name

    def get_arity(self) -> int:
        """Return the number of parameters in this group (1 or 2)."""
        return self.arity

    def get_group_name(self) -> str:
        """Return the base name of this group."""
        return self.group_name

    def get_parameter_names(self) -> List[str]:
        """Return the parameter names in this group, left before right."""
        return self.parameter_names

    def is_discrete(self) -> bool:
        """Return the ``discrete`` flag given at construction."""
        return self.discrete

    def get_range(self) -> Tuple[float, float]:
        """Return the value range (``(0.0, 1.0)`` unless overridden)."""
        return self.range

    def get_default_value(self) -> float:
        """Return the default value given at construction."""
        return self.default_value

    def get_parameter_index(self) -> int:
        """Return the flat index of this group's first parameter."""
        return self.parameter_index

    def get_category(self) -> PoseParameterCategory:
        """Return the category tag of this group."""
        return self.category
class PoseParameters:
    """Ordered collection of PoseParameterGroups exposing a flat parameter list.

    Parameters are numbered consecutively across groups in group order. A
    group of arity ``k`` occupies ``k`` consecutive flat positions.
    """

    def __init__(self, pose_parameter_groups: List[PoseParameterGroup]):
        self.pose_parameter_groups = pose_parameter_groups

    def get_parameter_index(self, name: str) -> int:
        """Return the flat index of the parameter called ``name``.

        Raises:
            RuntimeError: if no group defines a parameter with that name.
        """
        flat_names = (param_name
                      for group in self.pose_parameter_groups
                      for param_name in group.parameter_names)
        for index, param_name in enumerate(flat_names):
            if name == param_name:
                return index
        raise RuntimeError("Cannot find parameter with name %s" % name)

    def get_parameter_name(self, index: int) -> str:
        """Return the name of the parameter at flat position ``index``.

        Raises:
            IndexError: if ``index`` is negative or not less than
                ``get_parameter_count()``.
        """
        count = self.get_parameter_count()
        # Explicit check instead of ``assert`` so the guard survives
        # ``python -O``; without it a negative index would silently resolve
        # to a name from the first group.
        if not 0 <= index < count:
            raise IndexError(
                "Parameter index %s out of range [0, %d)" % (index, count))
        remaining = index
        for group in self.pose_parameter_groups:
            arity = group.get_arity()
            if remaining < arity:
                return group.get_parameter_names()[remaining]
            remaining -= arity
        # Unreachable when each group's names list matches its arity.
        raise RuntimeError("Something is wrong here!!!")

    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
        """Return the list of parameter groups, in flat-index order."""
        return self.pose_parameter_groups

    def get_parameter_count(self) -> int:
        """Return the total number of parameters (sum of group arities)."""
        return sum(group.get_arity() for group in self.pose_parameter_groups)

    class Builder:
        """Fluent helper that assigns consecutive flat indices to groups."""

        def __init__(self):
            self.index = 0  # flat index at which the next group will start
            self.pose_parameter_groups = []

        def add_parameter_group(self,
                                group_name: str,
                                category: PoseParameterCategory,
                                arity: int = 1,
                                discrete: bool = False,
                                default_value: float = 0.0,
                                range: Optional[Tuple[float, float]] = None
                                ) -> 'PoseParameters.Builder':
            """Append a group starting at the current flat index; return self.

            Arguments are forwarded to ``PoseParameterGroup``.
            """
            self.pose_parameter_groups.append(
                PoseParameterGroup(
                    group_name,
                    self.index,
                    category,
                    arity,
                    discrete,
                    default_value,
                    range))
            self.index += arity
            return self

        def build(self) -> 'PoseParameters':
            """Return a PoseParameters holding the groups added so far."""
            # Copy the list so that adding more groups to this builder later
            # does not mutate PoseParameters objects already built from it.
            return PoseParameters(list(self.pose_parameter_groups))
class Poser(ABC):
    """Abstract interface for posers: objects that take an image tensor and a
    pose tensor and produce output tensors.

    Every method except ``get_dtype`` is abstract and must be implemented by
    concrete subclasses.
    """

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the image size used by this poser.

        NOTE(review): presumably the side length in pixels of input/output
        images; confirm against implementations.
        """

    @abstractmethod
    def get_output_length(self) -> int:
        """Return the number of outputs this poser produces.

        NOTE(review): presumably the length of ``get_posing_outputs`` and the
        valid range of ``output_index`` in ``pose``; confirm.
        """

    @abstractmethod
    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
        """Return the pose parameter groups this poser accepts."""

    @abstractmethod
    def get_num_parameters(self) -> int:
        """Return the number of pose parameters this poser accepts."""

    @abstractmethod
    def pose(self, image: Tensor, pose: Tensor, output_index: int = 0) -> Tensor:
        """Pose ``image`` with ``pose`` and return the output at ``output_index``.

        NOTE(review): presumably equivalent to
        ``get_posing_outputs(image, pose)[output_index]``; confirm in
        implementations.
        """

    @abstractmethod
    def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]:
        """Return all output tensors produced by posing ``image`` with ``pose``."""

    def get_dtype(self) -> torch.dtype:
        """Return the dtype reported by this poser.

        The base implementation returns ``torch.float``; subclasses may
        override it.
        """
        return torch.float
|