import os.path
import warnings

import attr
import cattr
from typing import (
    Dict,
    Optional,
    List,
    Any,
    DefaultDict,
    Mapping,
    Tuple,
    Union,
    ClassVar,
)
from enum import Enum
import collections
import argparse
import abc
import numpy as np
import math
import copy

from mlagents.trainers.cli_utils import StoreConfigFile, DetectDefault, parser
from mlagents.trainers.cli_utils import load_config
from mlagents.trainers.exception import TrainerConfigError, TrainerConfigWarning

from mlagents_envs import logging_util
from mlagents_envs.side_channel.environment_parameters_channel import (
    EnvironmentParametersChannel,
)
from mlagents.plugins import all_trainer_settings, all_trainer_types

logger = logging_util.get_logger(__name__)


def check_and_structure(key: str, value: Any, class_type: type) -> Any:
    attr_fields_dict = attr.fields_dict(class_type)
    if key not in attr_fields_dict:
        raise TrainerConfigError(
            f"The option {key} was specified in your YAML file for {class_type.__name__}, but is invalid."
        )
    # Structure the value using the declared type of the attrs field.
    return cattr.structure(value, attr_fields_dict[key].type)


def check_hyperparam_schedules(val: Dict, trainer_type: str) -> Dict:
    # For PPO and POCA, the beta and epsilon schedules default to the learning
    # rate schedule when they are not explicitly set.
    if trainer_type == "ppo" or trainer_type == "poca":
        if "beta_schedule" not in val.keys() and "learning_rate_schedule" in val.keys():
            val["beta_schedule"] = val["learning_rate_schedule"]
        if (
            "epsilon_schedule" not in val.keys()
            and "learning_rate_schedule" in val.keys()
        ):
            val["epsilon_schedule"] = val["learning_rate_schedule"]
    return val
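

# Illustrative sketch (hypothetical values): with PPO, unset beta/epsilon
# schedules inherit the learning rate schedule.
#
#     cfg = {"learning_rate_schedule": "linear"}
#     cfg = check_hyperparam_schedules(cfg, "ppo")
#     # cfg["beta_schedule"] == "linear" and cfg["epsilon_schedule"] == "linear"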


def strict_to_cls(d: Mapping, t: type) -> Any:
    if not isinstance(d, Mapping):
        raise TrainerConfigError(f"Unsupported config {d} for {t.__name__}.")
    d_copy: Dict[str, Any] = {}
    d_copy.update(d)
    for key, val in d_copy.items():
        d_copy[key] = check_and_structure(key, val, t)
    return t(**d_copy)
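

# Illustrative sketch (hypothetical values): strict_to_cls structures known keys
# into the attrs class and rejects unknown ones. (RewardSignalSettings is
# defined further below.)
#
#     settings = strict_to_cls({"gamma": 0.9}, RewardSignalSettings)
#     # settings.gamma == 0.9; an unknown key raises TrainerConfigError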


def defaultdict_to_dict(d: DefaultDict) -> Dict:
    return {key: cattr.unstructure(val) for key, val in d.items()}


def deep_update_dict(d: Dict, update_d: Mapping) -> None:
    """
    Similar to dict.update(), but works for nested dicts of dicts as well.
    """
    for key, val in update_d.items():
        if key in d and isinstance(d[key], Mapping) and isinstance(val, Mapping):
            deep_update_dict(d[key], val)
        else:
            d[key] = val
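

# Illustrative sketch (hypothetical values): nested mappings are merged
# recursively instead of being overwritten wholesale.
#
#     base = {"network_settings": {"hidden_units": 128}, "max_steps": 500000}
#     deep_update_dict(base, {"network_settings": {"num_layers": 3}})
#     # base == {"network_settings": {"hidden_units": 128, "num_layers": 3},
#     #          "max_steps": 500000}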


class SerializationSettings:
    convert_to_onnx = True
    onnx_opset = 9


@attr.s(auto_attribs=True)
class ExportableSettings:
    def as_dict(self):
        return cattr.unstructure(self)


class EncoderType(Enum):
    FULLY_CONNECTED = "fully_connected"
    MATCH3 = "match3"
    SIMPLE = "simple"
    NATURE_CNN = "nature_cnn"
    RESNET = "resnet"


class ScheduleType(Enum):
    CONSTANT = "constant"
    LINEAR = "linear"


class ConditioningType(Enum):
    HYPER = "hyper"
    NONE = "none"


@attr.s(auto_attribs=True)
class NetworkSettings:
    @attr.s
    class MemorySettings:
        sequence_length: int = attr.ib(default=64)
        memory_size: int = attr.ib(default=128)

        @memory_size.validator
        def _check_valid_memory_size(self, attribute, value):
            if value <= 0:
                raise TrainerConfigError(
                    "When using a recurrent network, memory size must be greater than 0."
                )
            elif value % 2 != 0:
                raise TrainerConfigError(
                    "When using a recurrent network, memory size must be divisible by 2."
                )

    normalize: bool = False
    hidden_units: int = 128
    num_layers: int = 2
    vis_encode_type: EncoderType = EncoderType.SIMPLE
    memory: Optional[MemorySettings] = None
    goal_conditioning_type: ConditioningType = ConditioningType.HYPER
    deterministic: bool = parser.get_default("deterministic")


@attr.s(auto_attribs=True)
class BehavioralCloningSettings:
    demo_path: str
    steps: int = 0
    strength: float = 1.0
    samples_per_update: int = 0
    # Leaving either of these as None lets the optimizer pick values based on
    # the trainer's own hyperparameters.
    num_epoch: Optional[int] = None
    batch_size: Optional[int] = None


@attr.s(auto_attribs=True)
class HyperparamSettings:
    batch_size: int = 1024
    buffer_size: int = 10240
    learning_rate: float = 3.0e-4
    learning_rate_schedule: ScheduleType = ScheduleType.CONSTANT


@attr.s(auto_attribs=True)
class OnPolicyHyperparamSettings(HyperparamSettings):
    num_epoch: int = 3


@attr.s(auto_attribs=True)
class OffPolicyHyperparamSettings(HyperparamSettings):
    batch_size: int = 128
    buffer_size: int = 50000
    buffer_init_steps: int = 0
    steps_per_update: float = 1
    save_replay_buffer: bool = False
    reward_signal_steps_per_update: float = 4


class RewardSignalType(Enum):
    EXTRINSIC: str = "extrinsic"
    GAIL: str = "gail"
    CURIOSITY: str = "curiosity"
    RND: str = "rnd"

    def to_settings(self) -> type:
        _mapping = {
            RewardSignalType.EXTRINSIC: RewardSignalSettings,
            RewardSignalType.GAIL: GAILSettings,
            RewardSignalType.CURIOSITY: CuriositySettings,
            RewardSignalType.RND: RNDSettings,
        }
        return _mapping[self]
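

# Illustrative sketch: each enum member selects its settings class, which is how
# YAML keys like "gail" pick GAILSettings during structuring.
#
#     RewardSignalType("gail").to_settings()  # -> GAILSettings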


@attr.s(auto_attribs=True)
class RewardSignalSettings:
    gamma: float = 0.99
    strength: float = 1.0
    network_settings: NetworkSettings = attr.ib(factory=NetworkSettings)

    @staticmethod
    def structure(d: Mapping, t: type) -> Any:
        """
        Helper method to structure a Dict of RewardSignalSettings classes. Meant to be registered with
        cattr.register_structure_hook() and called with cattr.structure(). This is needed to handle
        the special Enum selection of RewardSignalSettings classes.
        """
        if not isinstance(d, Mapping):
            raise TrainerConfigError(f"Unsupported reward signal configuration {d}.")
        d_final: Dict[RewardSignalType, RewardSignalSettings] = {}
        for key, val in d.items():
            enum_key = RewardSignalType(key)
            t = enum_key.to_settings()
            d_final[enum_key] = strict_to_cls(val, t)
            # Handle the deprecated encoding_size option: warn, and if
            # network_settings was not given, carry encoding_size over to
            # hidden_units. An explicit network_settings takes precedence.
            if "encoding_size" in val:
                logger.warning(
                    "'encoding_size' was deprecated for RewardSignals. Please use network_settings."
                )
                if "network_settings" not in val:
                    d_final[enum_key].network_settings.hidden_units = val[
                        "encoding_size"
                    ]
        return d_final
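

# Illustrative sketch (hypothetical values): structuring the "reward_signals"
# section of a trainer config via the hook registered in TrainerSettings below.
#
#     raw = {"extrinsic": {"gamma": 0.99, "strength": 1.0}}
#     signals = cattr.structure(raw, Dict[RewardSignalType, RewardSignalSettings])
#     # signals == {RewardSignalType.EXTRINSIC: RewardSignalSettings(gamma=0.99, ...)}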


@attr.s(auto_attribs=True)
class GAILSettings(RewardSignalSettings):
    learning_rate: float = 3e-4
    encoding_size: Optional[int] = None
    use_actions: bool = False
    use_vail: bool = False
    demo_path: str = attr.ib(kw_only=True)


@attr.s(auto_attribs=True)
class CuriositySettings(RewardSignalSettings):
    learning_rate: float = 3e-4
    encoding_size: Optional[int] = None


@attr.s(auto_attribs=True)
class RNDSettings(RewardSignalSettings):
    learning_rate: float = 1e-4
    encoding_size: Optional[int] = None


class ParameterRandomizationType(Enum):
    UNIFORM: str = "uniform"
    GAUSSIAN: str = "gaussian"
    MULTIRANGEUNIFORM: str = "multirangeuniform"
    CONSTANT: str = "constant"

    def to_settings(self) -> type:
        _mapping = {
            ParameterRandomizationType.UNIFORM: UniformSettings,
            ParameterRandomizationType.GAUSSIAN: GaussianSettings,
            ParameterRandomizationType.MULTIRANGEUNIFORM: MultiRangeUniformSettings,
            ParameterRandomizationType.CONSTANT: ConstantSettings,
        }
        return _mapping[self]


@attr.s(auto_attribs=True)
class ParameterRandomizationSettings(abc.ABC):
    seed: int = parser.get_default("seed")

    def __str__(self) -> str:
        """
        Helper method to output sampler stats to console.
        """
        raise TrainerConfigError(f"__str__ not implemented for type {self.__class__}.")

    @staticmethod
    def structure(
        d: Union[Mapping, float], t: type
    ) -> "ParameterRandomizationSettings":
        """
        Helper method to structure a ParameterRandomizationSettings class. Meant to be registered with
        cattr.register_structure_hook() and called with cattr.structure(). This is needed to handle
        the special Enum selection of ParameterRandomizationSettings classes.
        """
        if isinstance(d, (float, int)):
            return ConstantSettings(value=d)
        if not isinstance(d, Mapping):
            raise TrainerConfigError(
                f"Unsupported parameter randomization configuration {d}."
            )
        if "sampler_type" not in d:
            raise TrainerConfigError(
                f"Sampler configuration does not contain sampler_type: {d}."
            )
        if "sampler_parameters" not in d:
            raise TrainerConfigError(
                f"Sampler configuration does not contain sampler_parameters: {d}."
            )
        enum_key = ParameterRandomizationType(d["sampler_type"])
        t = enum_key.to_settings()
        return strict_to_cls(d["sampler_parameters"], t)

    @staticmethod
    def unstructure(d: "ParameterRandomizationSettings") -> Mapping:
        """
        Helper method to unstructure a ParameterRandomizationSettings class. Meant to be registered with
        cattr.register_unstructure_hook() and called with cattr.unstructure().
        """
        _reversed_mapping = {
            UniformSettings: ParameterRandomizationType.UNIFORM,
            GaussianSettings: ParameterRandomizationType.GAUSSIAN,
            MultiRangeUniformSettings: ParameterRandomizationType.MULTIRANGEUNIFORM,
            ConstantSettings: ParameterRandomizationType.CONSTANT,
        }
        sampler_type: Optional[str] = None
        for t, name in _reversed_mapping.items():
            if isinstance(d, t):
                sampler_type = name.value
        sampler_parameters = attr.asdict(d)
        return {"sampler_type": sampler_type, "sampler_parameters": sampler_parameters}

    @abc.abstractmethod
    def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None:
        """
        Helper method to send sampler settings over EnvironmentParametersChannel
        Calls the appropriate sampler type set method.
        :param key: environment parameter to be sampled
        :param env_channel: The EnvironmentParametersChannel to communicate sampler settings to environment
        """
        pass
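

# Illustrative sketch (hypothetical values): round-tripping a sampler config.
# A bare number such as 3.0 structures to ConstantSettings(value=3.0).
#
#     raw = {
#         "sampler_type": "uniform",
#         "sampler_parameters": {"min_value": 0.5, "max_value": 8.0},
#     }
#     sampler = ParameterRandomizationSettings.structure(raw, ParameterRandomizationSettings)
#     ParameterRandomizationSettings.unstructure(sampler)  # -> a dict shaped like raw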


@attr.s(auto_attribs=True)
class ConstantSettings(ParameterRandomizationSettings):
    value: float = 0.0

    def __str__(self) -> str:
        """
        Helper method to output sampler stats to console.
        """
        return f"Float: value={self.value}"

    def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None:
        """
        Helper method to send sampler settings over EnvironmentParametersChannel
        Calls the constant sampler type set method.
        :param key: environment parameter to be sampled
        :param env_channel: The EnvironmentParametersChannel to communicate sampler settings to environment
        """
        env_channel.set_float_parameter(key, self.value)


@attr.s(auto_attribs=True)
class UniformSettings(ParameterRandomizationSettings):
    min_value: float = attr.ib()
    max_value: float = 1.0

    def __str__(self) -> str:
        """
        Helper method to output sampler stats to console.
        """
        return f"Uniform sampler: min={self.min_value}, max={self.max_value}"

    @min_value.default
    def _min_value_default(self):
        return 0.0

    @min_value.validator
    def _check_min_value(self, attribute, value):
        if self.min_value > self.max_value:
            raise TrainerConfigError(
                "Minimum value is greater than maximum value in uniform sampler."
            )

    def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None:
        """
        Helper method to send sampler settings over EnvironmentParametersChannel
        Calls the uniform sampler type set method.
        :param key: environment parameter to be sampled
        :param env_channel: The EnvironmentParametersChannel to communicate sampler settings to environment
        """
        env_channel.set_uniform_sampler_parameters(
            key, self.min_value, self.max_value, self.seed
        )


@attr.s(auto_attribs=True)
class GaussianSettings(ParameterRandomizationSettings):
    mean: float = 1.0
    st_dev: float = 1.0

    def __str__(self) -> str:
        """
        Helper method to output sampler stats to console.
        """
        return f"Gaussian sampler: mean={self.mean}, stddev={self.st_dev}"

    def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None:
        """
        Helper method to send sampler settings over EnvironmentParametersChannel
        Calls the gaussian sampler type set method.
        :param key: environment parameter to be sampled
        :param env_channel: The EnvironmentParametersChannel to communicate sampler settings to environment
        """
        env_channel.set_gaussian_sampler_parameters(
            key, self.mean, self.st_dev, self.seed
        )


@attr.s(auto_attribs=True)
class MultiRangeUniformSettings(ParameterRandomizationSettings):
    intervals: List[Tuple[float, float]] = attr.ib()

    def __str__(self) -> str:
        """
        Helper method to output sampler stats to console.
        """
        return f"MultiRangeUniform sampler: intervals={self.intervals}"

    @intervals.default
    def _intervals_default(self):
        return [[0.0, 1.0]]

    @intervals.validator
    def _check_intervals(self, attribute, value):
        for interval in self.intervals:
            if len(interval) != 2:
                raise TrainerConfigError(
                    f"The sampling interval {interval} must contain exactly two values."
                )
            min_value, max_value = interval
            if min_value > max_value:
                raise TrainerConfigError(
                    f"Minimum value is greater than maximum value in interval {interval}."
                )

    def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None:
        """
        Helper method to send sampler settings over EnvironmentParametersChannel
        Calls the multirangeuniform sampler type set method.
        :param key: environment parameter to be sampled
        :param env_channel: The EnvironmentParametersChannel to communicate sampler settings to environment
        """
        env_channel.set_multirangeuniform_sampler_parameters(
            key, self.intervals, self.seed
        )


@attr.s(auto_attribs=True)
class CompletionCriteriaSettings:
    """
    CompletionCriteriaSettings contains the information needed to figure out if the next
    lesson must start.
    """

    class MeasureType(Enum):
        PROGRESS: str = "progress"
        REWARD: str = "reward"

    behavior: str
    measure: MeasureType = attr.ib(default=MeasureType.REWARD)
    min_lesson_length: int = 0
    signal_smoothing: bool = True
    threshold: float = attr.ib(default=0.0)
    require_reset: bool = False

    @threshold.validator
    def _check_threshold_value(self, attribute, value):
        """
        Verify that the threshold has a value between 0 and 1 when the measure is
        PROGRESS.
        """
        if self.measure == self.MeasureType.PROGRESS:
            if self.threshold > 1.0:
                raise TrainerConfigError(
                    "Threshold for next lesson cannot be greater than 1 when the measure is progress."
                )
            if self.threshold < 0.0:
                raise TrainerConfigError(
                    "Threshold for next lesson cannot be negative when the measure is progress."
                )

    def need_increment(
        self, progress: float, reward_buffer: List[float], smoothing: float
    ) -> Tuple[bool, float]:
        """
        Given measures, this method returns a boolean indicating if the lesson
        needs to change now, and a float corresponding to the new smoothed value.
        """
        # Has the minimum number of episodes been reached?
        if len(reward_buffer) < self.min_lesson_length:
            return False, smoothing
        if self.measure == CompletionCriteriaSettings.MeasureType.PROGRESS:
            if progress > self.threshold:
                return True, smoothing
        if self.measure == CompletionCriteriaSettings.MeasureType.REWARD:
            if len(reward_buffer) < 1:
                return False, smoothing
            measure = np.mean(reward_buffer)
            if math.isnan(measure):
                return False, smoothing
            if self.signal_smoothing:
                measure = 0.25 * smoothing + 0.75 * measure
                smoothing = measure
            if measure > self.threshold:
                return True, smoothing
        return False, smoothing
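

# Illustrative sketch (hypothetical values): with signal_smoothing enabled, the
# reward measure is an exponential moving average, 0.25 * old + 0.75 * new.
#
#     criteria = CompletionCriteriaSettings(behavior="Walker", threshold=0.8)
#     done, smoothed = criteria.need_increment(
#         progress=0.5, reward_buffer=[1.0, 1.0], smoothing=0.0
#     )
#     # smoothed == 0.75 and done is False, since 0.75 <= 0.8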


@attr.s(auto_attribs=True)
class Lesson:
    """
    Gathers the data of one lesson for one environment parameter including its name,
    the condition that must be fulfilled for the lesson to be completed and a sampler
    for the environment parameter. If the completion_criteria is None, then this is
    the last lesson in the curriculum.
    """

    value: ParameterRandomizationSettings
    name: str
    completion_criteria: Optional[CompletionCriteriaSettings] = attr.ib(default=None)


@attr.s(auto_attribs=True)
class EnvironmentParameterSettings:
    """
    EnvironmentParameterSettings is an ordered list of lessons for one environment
    parameter.
    """

    curriculum: List[Lesson]

    @staticmethod
    def _check_lesson_chain(lessons, parameter_name):
        """
        Ensures that when using curriculum, all non-terminal lessons have a valid
        CompletionCriteria, and that the terminal lesson does not contain a CompletionCriteria.
        """
        num_lessons = len(lessons)
        for index, lesson in enumerate(lessons):
            if index < num_lessons - 1 and lesson.completion_criteria is None:
                raise TrainerConfigError(
                    f"A non-terminal lesson does not have a completion_criteria for {parameter_name}."
                )
            if index == num_lessons - 1 and lesson.completion_criteria is not None:
                warnings.warn(
                    f"Your final lesson definition contains completion_criteria for {parameter_name}. "
                    "It will be ignored.",
                    TrainerConfigWarning,
                )

    @staticmethod
    def structure(d: Mapping, t: type) -> Dict[str, "EnvironmentParameterSettings"]:
        """
        Helper method to structure a Dict of EnvironmentParameterSettings classes. Meant
        to be registered with cattr.register_structure_hook() and called with
        cattr.structure().
        """
        if not isinstance(d, Mapping):
            raise TrainerConfigError(
                f"Unsupported environment parameter settings {d}."
            )
        d_final: Dict[str, EnvironmentParameterSettings] = {}
        for environment_parameter, environment_parameter_config in d.items():
            if (
                isinstance(environment_parameter_config, Mapping)
                and "curriculum" in environment_parameter_config
            ):
                d_final[environment_parameter] = strict_to_cls(
                    environment_parameter_config, EnvironmentParameterSettings
                )
                EnvironmentParameterSettings._check_lesson_chain(
                    d_final[environment_parameter].curriculum, environment_parameter
                )
            else:
                sampler = ParameterRandomizationSettings.structure(
                    environment_parameter_config, ParameterRandomizationSettings
                )
                d_final[environment_parameter] = EnvironmentParameterSettings(
                    curriculum=[
                        Lesson(
                            completion_criteria=None,
                            value=sampler,
                            name=environment_parameter,
                        )
                    ]
                )
        return d_final
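

# Illustrative sketch (hypothetical values): a parameter without a "curriculum"
# key is wrapped in a single terminal Lesson holding its sampler.
#
#     raw = {
#         "gravity": {
#             "sampler_type": "uniform",
#             "sampler_parameters": {"min_value": 7.0, "max_value": 12.0},
#         }
#     }
#     params = EnvironmentParameterSettings.structure(raw, dict)
#     # params["gravity"].curriculum is one Lesson whose value is a UniformSettings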


@attr.s(auto_attribs=True)
class SelfPlaySettings:
    save_steps: int = 20000
    team_change: int = attr.ib()

    @team_change.default
    def _team_change_default(self):
        # By default, change teams after five snapshot saves.
        return self.save_steps * 5

    swap_steps: int = 2000
    window: int = 10
    play_against_latest_model_ratio: float = 0.5
    initial_elo: float = 1200.0


@attr.s(auto_attribs=True)
class TrainerSettings(ExportableSettings):
    default_override: ClassVar[Optional["TrainerSettings"]] = None
    trainer_type: str = "ppo"
    hyperparameters: HyperparamSettings = attr.ib()
    checkpoint_interval: int = attr.ib()

    @hyperparameters.default
    def _set_default_hyperparameters(self):
        return all_trainer_settings[self.trainer_type]()

    @checkpoint_interval.default
    def _set_default_checkpoint_interval(self):
        return 500000

    network_settings: NetworkSettings = attr.ib(factory=NetworkSettings)
    reward_signals: Dict[RewardSignalType, RewardSignalSettings] = attr.ib(
        factory=lambda: {RewardSignalType.EXTRINSIC: RewardSignalSettings()}
    )
    init_path: Optional[str] = None
    keep_checkpoints: int = 5
    even_checkpoints: bool = False
    max_steps: int = 500000
    time_horizon: int = 64
    summary_freq: int = 50000
    threaded: bool = False
    self_play: Optional[SelfPlaySettings] = None
    behavioral_cloning: Optional[BehavioralCloningSettings] = None

    cattr.register_structure_hook_func(
        lambda t: t == Dict[RewardSignalType, RewardSignalSettings],
        RewardSignalSettings.structure,
    )

    @network_settings.validator
    def _check_batch_size_seq_length(self, attribute, value):
        if self.network_settings.memory is not None:
            if (
                self.network_settings.memory.sequence_length
                > self.hyperparameters.batch_size
            ):
                raise TrainerConfigError(
                    "When using memory, sequence length must be less than or equal to batch size."
                )

    @checkpoint_interval.validator
    def _set_checkpoint_interval(self, attribute, value):
        if self.even_checkpoints:
            self.checkpoint_interval = int(self.max_steps / self.keep_checkpoints)

    @staticmethod
    def dict_to_trainerdict(d: Dict, t: type) -> "TrainerSettings.DefaultTrainerDict":
        return TrainerSettings.DefaultTrainerDict(
            cattr.structure(d, Dict[str, TrainerSettings])
        )

    @staticmethod
    def structure(d: Mapping, t: type) -> Any:
        """
        Helper method to structure a TrainerSettings class. Meant to be registered with
        cattr.register_structure_hook() and called with cattr.structure().
        """
        if not isinstance(d, Mapping):
            raise TrainerConfigError(f"Unsupported config {d} for {t.__name__}.")

        d_copy: Dict[str, Any] = {}

        # If default_settings was specified, start from that rather than an
        # empty dict, then layer the behavior-specific config on top.
        if TrainerSettings.default_override is not None:
            d_copy.update(cattr.unstructure(TrainerSettings.default_override))

        deep_update_dict(d_copy, d)

        if "framework" in d_copy:
            logger.warning("The 'framework' option is deprecated and will be ignored.")
            d_copy.pop("framework", None)

        for key, val in d_copy.items():
            if attr.has(type(val)):
                # Don't convert already-structured attrs classes.
                continue
            if key == "hyperparameters":
                if "trainer_type" not in d_copy:
                    raise TrainerConfigError(
                        "Hyperparameters were specified but no trainer_type was given."
                    )
                else:
                    d_copy[key] = check_hyperparam_schedules(
                        val, d_copy["trainer_type"]
                    )
                    try:
                        d_copy[key] = strict_to_cls(
                            d_copy[key], all_trainer_settings[d_copy["trainer_type"]]
                        )
                    except KeyError:
                        raise TrainerConfigError(
                            f"Settings for trainer type {d_copy['trainer_type']} were not found"
                        )
            elif key == "max_steps":
                # In some legacy configurations, max_steps was specified as a float.
                d_copy[key] = int(float(val))
            elif key == "trainer_type":
                if val not in all_trainer_types.keys():
                    raise TrainerConfigError(f"Invalid trainer type {val} was found")
            else:
                d_copy[key] = check_and_structure(key, val, t)
        return t(**d_copy)
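
    # Illustrative sketch (hypothetical values): structuring one behavior entry,
    # as cattr.structure() does when parsing the "behaviors" section.
    #
    #     raw = {
    #         "trainer_type": "ppo",
    #         "hyperparameters": {"batch_size": 256},
    #         "max_steps": 5.0e5,
    #     }
    #     settings = cattr.structure(raw, TrainerSettings)
    #     # settings.max_steps == 500000; an unknown trainer_type raises TrainerConfigError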

    class DefaultTrainerDict(collections.defaultdict):
        def __init__(self, *args):
            # Depending on how this is called, args may have the defaultdict
            # callable at the start of the list or not. In particular, unpickling
            # will pass [TrainerSettings].
            if args and args[0] == TrainerSettings:
                super().__init__(*args)
            else:
                super().__init__(TrainerSettings, *args)
            self._config_specified = True

        def set_config_specified(self, require_config_specified: bool) -> None:
            self._config_specified = require_config_specified

        def __missing__(self, key: Any) -> "TrainerSettings":
            if TrainerSettings.default_override is not None:
                self[key] = copy.deepcopy(TrainerSettings.default_override)
            elif self._config_specified:
                raise TrainerConfigError(
                    f"The behavior name {key} has not been specified in the trainer configuration. "
                    f"Please add an entry in the configuration file for {key}, or set default_settings."
                )
            else:
                logger.warning(
                    f"Behavior name {key} does not match any behaviors specified "
                    f"in the trainer configuration file. A default configuration will be used."
                )
                self[key] = TrainerSettings()
            return self[key]
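

# Illustrative sketch: looking up an unseen behavior name in DefaultTrainerDict
# either falls back to a default TrainerSettings or raises, depending on
# whether a config file was provided (see set_config_specified above).
#
#     behaviors = TrainerSettings.DefaultTrainerDict()
#     behaviors.set_config_specified(False)
#     behaviors["3DBall"]  # logs a warning and returns a default TrainerSettings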


@attr.s(auto_attribs=True)
class CheckpointSettings:
    run_id: str = parser.get_default("run_id")
    initialize_from: Optional[str] = parser.get_default("initialize_from")
    load_model: bool = parser.get_default("load_model")
    resume: bool = parser.get_default("resume")
    force: bool = parser.get_default("force")
    train_model: bool = parser.get_default("train_model")
    inference: bool = parser.get_default("inference")
    results_dir: str = parser.get_default("results_dir")

    @property
    def write_path(self) -> str:
        return os.path.join(self.results_dir, self.run_id)

    @property
    def maybe_init_path(self) -> Optional[str]:
        return (
            os.path.join(self.results_dir, self.initialize_from)
            if self.initialize_from is not None
            else None
        )

    @property
    def run_logs_dir(self) -> str:
        return os.path.join(self.write_path, "run_logs")
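
    # Illustrative sketch (hypothetical values): how the derived paths compose.
    #
    #     cp = CheckpointSettings(run_id="walker_run", results_dir="results")
    #     cp.write_path    # "results/walker_run"
    #     cp.run_logs_dir  # "results/walker_run/run_logs"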

    def prioritize_resume_init(self) -> None:
        """
        Prioritize explicit command-line resume/init over conflicting YAML options.
        If both resume and initialize_from are set in the same place, resume wins.
        """
        _non_default_args = DetectDefault.non_default_args
        if "resume" in _non_default_args:
            if self.initialize_from is not None:
                logger.warning(
                    f"Both 'resume' and 'initialize_from={self.initialize_from}' are set!"
                    f" Current run will be resumed ignoring initialization."
                )
                self.initialize_from = parser.get_default("initialize_from")
        elif "initialize_from" in _non_default_args:
            if self.resume:
                logger.warning(
                    f"Both 'resume' and 'initialize_from={self.initialize_from}' are set!"
                    f" Run {self.run_id} will be initialized from {self.initialize_from}, and 'resume' will be ignored."
                )
                self.resume = parser.get_default("resume")
        elif self.resume and self.initialize_from is not None:
            # Neither option came from the command line; both were set in YAML.
            logger.warning(
                f"Both 'resume' and 'initialize_from={self.initialize_from}' are set in the YAML file!"
                f" Current run will be resumed ignoring initialization."
            )
            self.initialize_from = parser.get_default("initialize_from")


@attr.s(auto_attribs=True)
class EnvironmentSettings:
    env_path: Optional[str] = parser.get_default("env_path")
    env_args: Optional[List[str]] = parser.get_default("env_args")
    base_port: int = parser.get_default("base_port")
    num_envs: int = attr.ib(default=parser.get_default("num_envs"))
    num_areas: int = attr.ib(default=parser.get_default("num_areas"))
    seed: int = parser.get_default("seed")
    max_lifetime_restarts: int = parser.get_default("max_lifetime_restarts")
    restarts_rate_limit_n: int = parser.get_default("restarts_rate_limit_n")
    restarts_rate_limit_period_s: int = parser.get_default(
        "restarts_rate_limit_period_s"
    )

    @num_envs.validator
    def validate_num_envs(self, attribute, value):
        if value > 1 and self.env_path is None:
            raise ValueError("num_envs must be 1 if env_path is not set.")

    @num_areas.validator
    def validate_num_area(self, attribute, value):
        if value <= 0:
            raise ValueError("num_areas must be set to a positive number >= 1.")


@attr.s(auto_attribs=True)
class EngineSettings:
    width: int = parser.get_default("width")
    height: int = parser.get_default("height")
    quality_level: int = parser.get_default("quality_level")
    time_scale: float = parser.get_default("time_scale")
    target_frame_rate: int = parser.get_default("target_frame_rate")
    capture_frame_rate: int = parser.get_default("capture_frame_rate")
    no_graphics: bool = parser.get_default("no_graphics")


@attr.s(auto_attribs=True)
class TorchSettings:
    device: Optional[str] = parser.get_default("device")


@attr.s(auto_attribs=True)
class RunOptions(ExportableSettings):
    default_settings: Optional[TrainerSettings] = None
    behaviors: TrainerSettings.DefaultTrainerDict = attr.ib(
        factory=TrainerSettings.DefaultTrainerDict
    )
    env_settings: EnvironmentSettings = attr.ib(factory=EnvironmentSettings)
    engine_settings: EngineSettings = attr.ib(factory=EngineSettings)
    environment_parameters: Optional[Dict[str, EnvironmentParameterSettings]] = None
    checkpoint_settings: CheckpointSettings = attr.ib(factory=CheckpointSettings)
    torch_settings: TorchSettings = attr.ib(factory=TorchSettings)
    debug: bool = parser.get_default("debug")

    # Custom cattr hooks for structuring/unstructuring the nested settings above.
    cattr.register_structure_hook(EnvironmentSettings, strict_to_cls)
    cattr.register_structure_hook(EngineSettings, strict_to_cls)
    cattr.register_structure_hook(CheckpointSettings, strict_to_cls)
    cattr.register_structure_hook_func(
        lambda t: t == Dict[str, EnvironmentParameterSettings],
        EnvironmentParameterSettings.structure,
    )
    cattr.register_structure_hook(Lesson, strict_to_cls)
    cattr.register_structure_hook(
        ParameterRandomizationSettings, ParameterRandomizationSettings.structure
    )
    cattr.register_unstructure_hook(
        ParameterRandomizationSettings, ParameterRandomizationSettings.unstructure
    )
    cattr.register_structure_hook(TrainerSettings, TrainerSettings.structure)
    cattr.register_structure_hook(
        TrainerSettings.DefaultTrainerDict, TrainerSettings.dict_to_trainerdict
    )
    cattr.register_unstructure_hook(collections.defaultdict, defaultdict_to_dict)

    @staticmethod
    def from_argparse(args: argparse.Namespace) -> "RunOptions":
        """
        Takes an argparse.Namespace as specified in `parse_command_line`, loads input configuration files
        from file paths, and converts to a RunOptions instance.
        :param args: collection of command-line parameters passed to mlagents-learn
        :return: RunOptions representing the passed in arguments, with trainer config, curriculum and sampler
        configs loaded from files.
        """
        argparse_args = vars(args)
        config_path = StoreConfigFile.trainer_config_path

        # Load the YAML configuration, if one was given.
        configured_dict: Dict[str, Any] = {
            "checkpoint_settings": {},
            "env_settings": {},
            "engine_settings": {},
            "torch_settings": {},
        }
        _require_all_behaviors = True
        if config_path is not None:
            configured_dict.update(load_config(config_path))
        else:
            # Without a config file, don't require all behavior names to be specified.
            _require_all_behaviors = False

        # Validate the top-level keys of the YAML file.
        for key in configured_dict.keys():
            if key not in attr.fields_dict(RunOptions):
                raise TrainerConfigError(
                    "The option {} was specified in your YAML file, but is invalid.".format(
                        key
                    )
                )

        # Keep the deprecated load_model flag working by folding it into resume.
        argparse_args["resume"] = argparse_args["resume"] or argparse_args["load_model"]

        # Command-line options override the YAML file, but only if they were
        # explicitly passed (i.e., are not argparse defaults).
        for key, val in argparse_args.items():
            if key in DetectDefault.non_default_args:
                if key in attr.fields_dict(CheckpointSettings):
                    configured_dict["checkpoint_settings"][key] = val
                elif key in attr.fields_dict(EnvironmentSettings):
                    configured_dict["env_settings"][key] = val
                elif key in attr.fields_dict(EngineSettings):
                    configured_dict["engine_settings"][key] = val
                elif key in attr.fields_dict(TorchSettings):
                    configured_dict["torch_settings"][key] = val
                else:  # Top-level RunOptions fields, e.g. debug.
                    configured_dict[key] = val

        final_runoptions = RunOptions.from_dict(configured_dict)
        final_runoptions.checkpoint_settings.prioritize_resume_init()

        if isinstance(final_runoptions.behaviors, TrainerSettings.DefaultTrainerDict):
            # Only require all behavior names to be present if a config file was given.
            final_runoptions.behaviors.set_config_specified(_require_all_behaviors)

        _non_default_args = DetectDefault.non_default_args

        # An explicit --deterministic flag overrides every behavior's network settings.
        if "deterministic" in _non_default_args:
            for behaviour in final_runoptions.behaviors.keys():
                final_runoptions.behaviors[
                    behaviour
                ].network_settings.deterministic = argparse_args["deterministic"]

        return final_runoptions

    @staticmethod
    def from_dict(
        options_dict: Dict[str, Any],
    ) -> "RunOptions":
        # If a default_settings section is present, store it as the class-level
        # override applied to every TrainerSettings that gets structured.
        if (
            "default_settings" in options_dict.keys()
            and options_dict["default_settings"] is not None
        ):
            TrainerSettings.default_override = cattr.structure(
                options_dict["default_settings"], TrainerSettings
            )
        return cattr.structure(options_dict, RunOptions)
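

# Illustrative sketch (hypothetical values): building RunOptions directly from a
# dict shaped like a trainer configuration file.
#
#     opts = RunOptions.from_dict({
#         "behaviors": {"3DBall": {"trainer_type": "ppo", "max_steps": 100000}},
#         "engine_settings": {"time_scale": 20.0},
#     })
#     # opts.behaviors["3DBall"].max_steps == 100000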