File size: 4,010 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
from typing import Optional, List
from gym import utils
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from easydict import EasyDict
from itertools import product
import gym
import copy
import numpy as np
import matplotlib.pyplot as plt
from ding.utils.default_helper import deep_merge_dicts
class AAA:
    """
    Overview:
        Minimal placeholder object exposing a single integer attribute ``x``.
    """

    def __init__(self) -> None:
        # Every new instance starts with ``x`` set to zero.
        self.x: int = 0
def deep_update(
    original: dict,
    new_dict: dict,
    new_keys_allowed: bool = False,
    whitelist: Optional[List[str]] = None,
    override_all_if_type_changes: Optional[List[str]] = None
):
    """
    Overview:
        Recursively merge ``new_dict`` into ``original`` in place and return ``original``.
    .. note::
        A key of ``new_dict`` that is missing from ``original`` raises an error unless
        ``new_keys_allowed`` is truthy. For top-level keys listed in ``whitelist``, the
        nested dict may gain new sub-keys regardless of ``new_keys_allowed``.
    Arguments:
        - original (:obj:`dict`): Dictionary with default values; it is modified in place.
        - new_dict (:obj:`dict`): Dictionary whose values are written into ``original``.
        - new_keys_allowed (:obj:`bool`): Whether keys absent from ``original`` may be added.
        - whitelist (Optional[List[str]]): Top-level keys whose dict values may receive new
          sub-keys. Only consulted at the top level; it is not forwarded to recursive calls.
        - override_all_if_type_changes (Optional[List[str]]): Top-level keys with dict values
          whose whole value is replaced (not merged) when both old and new dicts have a
          ``"type"`` entry and those entries differ. Not forwarded to recursive calls.
    Returns:
        - original (:obj:`dict`): The same ``original`` object after the update.
    Raises:
        - RuntimeError: If a new key is found and new keys are not allowed at that level.
    """
    allow_subkeys = whitelist if whitelist else []
    replace_on_type_change = override_all_if_type_changes if override_all_if_type_changes else []

    for key, incoming in new_dict.items():
        if key not in original and not new_keys_allowed:
            raise RuntimeError("Unknown config parameter `{}`. Base config have: {}.".format(key, original.keys()))

        current = original.get(key)
        if not (isinstance(current, dict) and isinstance(incoming, dict)):
            # At least one side is not a dict: the new value replaces the old one wholesale.
            original[key] = incoming
            continue

        # Both sides are dicts. A changed "type" on an opted-in key means the old
        # sub-config no longer applies, so swap the whole value instead of merging.
        type_changed = (
            key in replace_on_type_change
            and "type" in incoming
            and "type" in current
            and incoming["type"] != current["type"]
        )
        if type_changed:
            original[key] = incoming
        else:
            # Whitelisted keys may always gain sub-keys; others inherit the caller's setting.
            deep_update(current, incoming, True if key in allow_subkeys else new_keys_allowed)

    return original
class BaseDriveEnv(gym.Env, utils.EzPickle):
    """
    Overview:
        Base class for driving environments. On construction it merges the user config over
        the class default config, and it declares the gym-style interface (``step``, ``reset``,
        ``close``, ``seed``, ``__repr__``) that subclasses implement.
    .. note::
        This class does not declare ``ABCMeta`` itself, so the ``@abstractmethod`` markers are
        only enforced if some base in the MRO supplies that metaclass.
        NOTE(review): confirm against the installed gym version. Otherwise, calling an
        unimplemented method raises ``NotImplementedError`` from the bodies below.
    """

    # Class-level default config; subclasses override it. ``default_config`` copies it,
    # so instances never mutate this shared dict through their own ``_cfg``.
    config = dict()

    @abstractmethod
    def __init__(self, cfg: Dict, **kwargs) -> None:
        """
        Overview:
            Store the environment config. A config without a ``cfg_type`` key is treated as a
            partial user config and merged over ``default_config()``. A config that already
            carries ``cfg_type`` is used as-is.
        Arguments:
            - cfg (:obj:`Dict`): Environment config (full or partial).
            - kwargs: Extra keyword arguments. They are recorded for pickling and otherwise unused here.
        """
        if 'cfg_type' not in cfg:
            self._cfg = self.__class__.default_config()
            self._cfg = deep_merge_dicts(self._cfg, cfg)
        else:
            self._cfg = cfg
        # EzPickle rebuilds the object on unpickling as ``type(self)(*args, **kwargs)`` from the
        # arguments recorded here. Recording nothing made unpickling call the constructor
        # without the required ``cfg``. Subclasses whose constructors take extra positional
        # parameters should call ``utils.EzPickle.__init__`` again with their full arguments.
        utils.EzPickle.__init__(self, cfg, **kwargs)

    @abstractmethod
    def step(self, action: Any) -> Any:
        """
        Run one step of the environment and return the observation dict.
        """
        raise NotImplementedError

    @abstractmethod
    def reset(self, *args, **kwargs) -> Any:
        """
        Reset current environment.
        """
        raise NotImplementedError

    @abstractmethod
    def close(self) -> None:
        """
        Release all resources in environment and close.
        """
        raise NotImplementedError

    @abstractmethod
    def seed(self, seed: int) -> None:
        """
        Set random seed.
        """
        raise NotImplementedError

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Overview:
            Build a fresh copy of the class default config, tagged with
            ``cfg_type = '<ClassName>Config'``.
        Returns:
            - cfg (:obj:`EasyDict`): Deep copy of ``cls.config`` plus the ``cfg_type`` entry.
        """
        cfg = EasyDict(cls.config)
        cfg.cfg_type = cls.__name__ + 'Config'
        # Deep copy so callers can freely mutate the result without touching ``cls.config``.
        return copy.deepcopy(cfg)

    @abstractmethod
    def __repr__(self) -> str:
        raise NotImplementedError
|