Spaces:
Runtime error
Runtime error
| from abc import ABC, abstractmethod | |
| from enum import Enum | |
| from typing import Tuple, List, Optional | |
| import torch | |
| from torch import Tensor | |
class PoseParameterCategory(Enum):
    """Category tag attached to each PoseParameterGroup.

    Member values are explicit consecutive integers starting at 1.
    """

    EYEBROW = 1
    EYE = 2
    IRIS_MORPH = 3
    IRIS_ROTATION = 4
    MOUTH = 5
    FACE_ROTATION = 6
    BODY_ROTATION = 7
    BREATHING = 8
class PoseParameterGroup:
    """A named group of one or two pose parameters sharing the same metadata.

    A group of arity 1 exposes a single parameter named ``group_name``; a group
    of arity 2 exposes ``group_name + "_left"`` and ``group_name + "_right"``.
    All parameters in the group share the category, range, default value and
    discrete flag.
    """

    def __init__(self,
                 group_name: str,
                 parameter_index: int,
                 category: PoseParameterCategory,
                 arity: int = 1,
                 discrete: bool = False,
                 default_value: float = 0.0,
                 range: Optional[Tuple[float, float]] = None):
        """Create a parameter group.

        Args:
            group_name: Base name of the group; parameter names derive from it.
            parameter_index: Index stored for the group's first parameter.
                NOTE(review): presumably the offset into the flat pose vector
                (PoseParameters.Builder passes a running sum of arities) —
                confirm for hand-built groups.
            category: Category tag for the group.
            arity: Number of parameters in the group; must be 1 or 2.
            discrete: Flag stored as-is (presumably marks non-continuous
                parameters — TODO confirm with consumers).
            default_value: Default value shared by the group's parameters.
            range: ``(low, high)`` tuple; ``None`` means ``(0.0, 1.0)``.

        Raises:
            ValueError: If ``arity`` is not 1 or 2.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let an arity/name-count mismatch through.
        if arity not in (1, 2):
            raise ValueError("arity must be 1 or 2, got %r" % (arity,))
        if range is None:
            range = (0.0, 1.0)
        if arity == 1:
            parameter_names = [group_name]
        else:
            parameter_names = [group_name + "_left", group_name + "_right"]
        self.parameter_names = parameter_names
        self.range = range
        self.default_value = default_value
        self.discrete = discrete
        self.arity = arity
        self.category = category
        self.parameter_index = parameter_index
        self.group_name = group_name

    def __repr__(self) -> str:
        return "%s(group_name=%r, parameter_index=%r, category=%r, arity=%r, discrete=%r, default_value=%r, range=%r)" % (
            type(self).__name__,
            self.group_name,
            self.parameter_index,
            self.category,
            self.arity,
            self.discrete,
            self.default_value,
            self.range,
        )

    def get_arity(self) -> int:
        """Return the number of parameters in this group (1 or 2)."""
        return self.arity

    def get_group_name(self) -> str:
        """Return the group's base name."""
        return self.group_name

    def get_parameter_names(self) -> List[str]:
        """Return the names of the group's parameters, in index order."""
        return self.parameter_names

    def is_discrete(self) -> bool:
        """Return the ``discrete`` flag given at construction."""
        return self.discrete

    def get_range(self) -> Tuple[float, float]:
        """Return the ``(low, high)`` range; ``(0.0, 1.0)`` unless overridden."""
        return self.range

    def get_default_value(self) -> float:
        """Return the default value given at construction."""
        return self.default_value

    def get_parameter_index(self) -> int:
        """Return the index stored for the group's first parameter."""
        return self.parameter_index

    def get_category(self) -> PoseParameterCategory:
        """Return the group's category tag."""
        return self.category
class PoseParameters:
    """An ordered collection of PoseParameterGroup objects.

    Parameters are addressed by a flat index: groups appear in list order and
    each group contributes ``arity`` consecutive entries, in the order of its
    ``parameter_names``.
    """

    def __init__(self, pose_parameter_groups: List[PoseParameterGroup]):
        self.pose_parameter_groups = pose_parameter_groups

    def get_parameter_index(self, name: str) -> int:
        """Return the flat index of the parameter called *name*.

        Raises:
            RuntimeError: If no group defines a parameter with that name.
        """
        all_names = (
            param_name
            for parameter_group in self.pose_parameter_groups
            for param_name in parameter_group.parameter_names
        )
        for index, param_name in enumerate(all_names):
            if name == param_name:
                return index
        raise RuntimeError("Cannot find parameter with name %s" % name)

    def get_parameter_name(self, index: int) -> str:
        """Return the name of the parameter at flat position *index*.

        Raises:
            IndexError: If *index* is outside ``[0, get_parameter_count())``.
        """
        count = self.get_parameter_count()
        # Explicit check instead of ``assert``: under ``python -O`` a negative
        # index would otherwise silently return a wrong name.
        if not 0 <= index < count:
            raise IndexError("Parameter index %s out of range [0, %d)" % (index, count))
        for group in self.pose_parameter_groups:
            arity = group.get_arity()
            if index < arity:
                return group.get_parameter_names()[index]
            index -= arity
        # Unreachable while every group's arity matches its parameter_names.
        raise RuntimeError("Something is wrong here!!!")

    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
        """Return the underlying list of groups (not a copy)."""
        return self.pose_parameter_groups

    def get_parameter_count(self) -> int:
        """Return the total number of parameters (sum of group arities)."""
        return sum(group.arity for group in self.pose_parameter_groups)

    class Builder:
        """Fluent builder that assigns consecutive start indices to groups."""

        def __init__(self):
            # Running start index for the next group (sum of arities so far).
            self.index = 0
            self.pose_parameter_groups = []

        def add_parameter_group(self,
                                group_name: str,
                                category: PoseParameterCategory,
                                arity: int = 1,
                                discrete: bool = False,
                                default_value: float = 0.0,
                                range: Optional[Tuple[float, float]] = None) -> 'PoseParameters.Builder':
            """Append a group starting at the current running index; return self."""
            self.pose_parameter_groups.append(
                PoseParameterGroup(
                    group_name,
                    self.index,
                    category,
                    arity,
                    discrete,
                    default_value,
                    range))
            self.index += arity
            return self

        def build(self) -> 'PoseParameters':
            """Return a PoseParameters holding the groups added so far."""
            # Copy the list so groups added to this builder later do not
            # leak into an already-built PoseParameters.
            return PoseParameters(list(self.pose_parameter_groups))
class Poser(ABC):
    """Abstract interface for objects that pose an input image tensor.

    Subclasses must implement every method marked ``@abstractmethod``;
    ``get_dtype`` has a default of ``torch.float``.
    """

    # The class already derives from ABC and the module imports
    # ``abstractmethod``; without the decorator these methods silently returned
    # None and incomplete subclasses could be instantiated.

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the image size this poser works with (semantics per subclass)."""

    @abstractmethod
    def get_output_length(self) -> int:
        """Return the output length (presumably len(get_posing_outputs(...)) — TODO confirm)."""

    @abstractmethod
    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
        """Return the pose parameter groups this poser accepts."""

    @abstractmethod
    def get_num_parameters(self) -> int:
        """Return the number of pose parameters this poser accepts."""

    @abstractmethod
    def pose(self, image: Tensor, pose: Tensor, output_index: int = 0) -> Tensor:
        """Return the output selected by *output_index* for *image* posed by *pose*."""

    @abstractmethod
    def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]:
        """Return all posing outputs for *image* posed by *pose*."""

    def get_dtype(self) -> torch.dtype:
        """Return the tensor dtype used by this poser; ``torch.float`` by default."""
        return torch.float