TomatoCocotree
上传
6a62ffb
raw
history blame
4.58 kB
from abc import ABC, abstractmethod
from enum import Enum
from typing import Tuple, List, Optional
import torch
from torch import Tensor
class PoseParameterCategory(Enum):
    """Category tag attached to every pose parameter group.

    Member values are spelled out explicitly as the integers 1 through 8, in
    declaration order.
    """

    # Facial-feature categories.
    EYEBROW = 1
    EYE = 2
    IRIS_MORPH = 3
    IRIS_ROTATION = 4
    MOUTH = 5

    # Head / body motion categories.
    FACE_ROTATION = 6
    BODY_ROTATION = 7
    BREATHING = 8
class PoseParameterGroup:
    """A named group of one or two pose parameters that share their settings.

    A group of arity 1 holds a single parameter named after the group. A group
    of arity 2 holds a pair named ``<group_name>_left`` and
    ``<group_name>_right``.
    """

    def __init__(self,
                 group_name: str,
                 parameter_index: int,
                 category: 'PoseParameterCategory',
                 arity: int = 1,
                 discrete: bool = False,
                 default_value: float = 0.0,
                 range: Optional[Tuple[float, float]] = None):
        """Create a parameter group.

        Args:
            group_name: Base name of the group; also the parameter name when
                ``arity`` is 1.
            parameter_index: Index stored for the group and returned by
                ``get_parameter_index`` (presumably the flat offset of the
                group's first parameter -- confirm with callers).
            category: Category returned by ``get_category``.
            arity: Number of parameters in the group; must be 1 or 2.
            discrete: Flag returned by ``is_discrete``.
            default_value: Value returned by ``get_default_value``.
            range: ``(min, max)`` pair returned by ``get_range``; defaults to
                ``(0.0, 1.0)``. The name shadows the builtin but is kept for
                keyword-argument compatibility.

        Raises:
            ValueError: If ``arity`` is not 1 or 2.
        """
        # Explicit check instead of ``assert`` so validation still happens
        # under ``python -O``.
        if arity not in (1, 2):
            raise ValueError("arity must be 1 or 2, got %r" % (arity,))
        if range is None:
            range = (0.0, 1.0)
        if arity == 1:
            parameter_names = [group_name]
        else:
            parameter_names = [group_name + "_left", group_name + "_right"]
        self.parameter_names = parameter_names
        self.range = range
        self.default_value = default_value
        self.discrete = discrete
        self.arity = arity
        self.category = category
        self.parameter_index = parameter_index
        self.group_name = group_name

    def __repr__(self) -> str:
        return "%s(%r, %r, %r, arity=%r, discrete=%r, default_value=%r, range=%r)" % (
            type(self).__name__, self.group_name, self.parameter_index,
            self.category, self.arity, self.discrete, self.default_value,
            self.range)

    def get_arity(self) -> int:
        """Return the number of parameters in the group (1 or 2)."""
        return self.arity

    def get_group_name(self) -> str:
        """Return the group's base name."""
        return self.group_name

    def get_parameter_names(self) -> List[str]:
        """Return the parameter names (the stored list, not a copy)."""
        return self.parameter_names

    def is_discrete(self) -> bool:
        """Return the ``discrete`` flag given at construction."""
        return self.discrete

    def get_range(self) -> Tuple[float, float]:
        """Return the ``(min, max)`` range; ``(0.0, 1.0)`` unless overridden."""
        return self.range

    def get_default_value(self) -> float:
        """Return the default value given at construction."""
        return self.default_value

    def get_parameter_index(self) -> int:
        """Return the index given at construction."""
        return self.parameter_index

    def get_category(self) -> 'PoseParameterCategory':
        """Return the group's category."""
        return self.category
class PoseParameters:
    """Ordered collection of pose parameter groups.

    Each parameter has a flat index: parameters are numbered consecutively in
    group order and, within a group, in ``get_parameter_names()`` order.
    """

    def __init__(self, pose_parameter_groups: List['PoseParameterGroup']):
        """Wrap ``pose_parameter_groups`` (stored as given, not copied)."""
        self.pose_parameter_groups = pose_parameter_groups

    def _iter_parameter_names(self):
        """Yield every parameter name in flat-index order."""
        for group in self.pose_parameter_groups:
            yield from group.get_parameter_names()

    def get_parameter_index(self, name: str) -> int:
        """Return the flat index of the parameter called ``name``.

        Raises:
            RuntimeError: If no group contains a parameter with that name.
        """
        for index, param_name in enumerate(self._iter_parameter_names()):
            if name == param_name:
                return index
        raise RuntimeError("Cannot find parameter with name %s" % name)

    def get_parameter_name(self, index: int) -> str:
        """Return the name of the parameter at flat position ``index``.

        Raises:
            IndexError: If ``index`` is negative or not less than
                ``get_parameter_count()``.
        """
        # Explicit check rather than ``assert``: under ``python -O`` a negative
        # index would otherwise silently return a name from the first group.
        if not 0 <= index < self.get_parameter_count():
            raise IndexError("Parameter index %r out of range" % (index,))
        remaining = index
        for group in self.pose_parameter_groups:
            arity = group.get_arity()
            if remaining < arity:
                return group.get_parameter_names()[remaining]
            remaining -= arity
        # Only reachable if a group's arity disagrees with its name list.
        raise RuntimeError("Something is wrong here!!!")

    def get_pose_parameter_groups(self):
        """Return the stored list of groups."""
        return self.pose_parameter_groups

    def get_parameter_count(self):
        """Return the total number of parameters across all groups."""
        return sum(group.get_arity() for group in self.pose_parameter_groups)

    class Builder:
        """Fluent builder that assigns consecutive flat indices to groups."""

        def __init__(self):
            # Flat index that the next added group will start at.
            self.index = 0
            self.pose_parameter_groups = []

        def add_parameter_group(self,
                                group_name: str,
                                category: 'PoseParameterCategory',
                                arity: int = 1,
                                discrete: bool = False,
                                default_value: float = 0.0,
                                range: Optional[Tuple[float, float]] = None):
            """Append a new group starting at the current index; return ``self``."""
            self.pose_parameter_groups.append(
                PoseParameterGroup(
                    group_name,
                    self.index,
                    category,
                    arity,
                    discrete,
                    default_value,
                    range))
            self.index += arity
            return self

        def build(self) -> 'PoseParameters':
            """Create a ``PoseParameters`` from the groups added so far."""
            # Copy so groups added after build() do not leak into instances
            # that were already built.
            return PoseParameters(list(self.pose_parameter_groups))
class Poser(ABC):
    """Abstract interface that maps an (image, pose) pair to output tensors.

    Subclasses must implement every abstract method; only ``get_dtype`` has a
    default implementation.
    """

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the image size this poser works with.

        Presumably the side length in pixels of a square image -- confirm
        against concrete implementations.
        """

    @abstractmethod
    def get_output_length(self) -> int:
        """Return the output length (presumably the number of tensors from
        ``get_posing_outputs`` -- confirm with implementations)."""

    @abstractmethod
    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
        """Return the pose parameter groups this poser accepts."""

    @abstractmethod
    def get_num_parameters(self) -> int:
        """Return the number of pose parameters (presumably the length of the
        ``pose`` tensor -- confirm with implementations)."""

    @abstractmethod
    def pose(self, image: Tensor, pose: Tensor, output_index: int = 0) -> Tensor:
        """Return a single output tensor for ``image`` posed by ``pose``.

        ``output_index`` presumably selects one of the tensors that
        ``get_posing_outputs`` would return -- confirm with implementations.
        """

    @abstractmethod
    def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]:
        """Return the list of output tensors for ``image`` posed by ``pose``."""

    def get_dtype(self) -> torch.dtype:
        """Return the tensor dtype used by the poser; 32-bit float by default."""
        # torch.float32 is the same dtype object as torch.float.
        return torch.float32