File size: 21,797 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 |
from typing import Union, Mapping, List, NamedTuple, Tuple, Callable, Optional, Any, Dict
import copy
from ditk import logging
import random
from functools import lru_cache # in python3.9, we can change to cache
import numpy as np
import torch
import treetensor.torch as ttorch
def get_shape0(data: Union[List, Dict, torch.Tensor, ttorch.Tensor]) -> int:
    """
    Overview:
        Return the first-dimension length (usually the batch size) of ``data``. \
        Lists/tuples and dicts are unwrapped recursively via their first element/value \
        until a torch tensor or treetensor is reached.
    Arguments:
        - data (:obj:`Union[List, Dict, torch.Tensor, ttorch.Tensor]`): Data to be analysed.
    Returns:
        - shape[0] (:obj:`int`): First dimension length of data, usually the batchsize.
    """
    if isinstance(data, (list, tuple)):
        return get_shape0(data[0])
    elif isinstance(data, dict):
        # Recurse into the first value only; all entries are assumed to share shape[0].
        for value in data.values():
            return get_shape0(value)
    elif isinstance(data, torch.Tensor):
        return data.shape[0]
    elif isinstance(data, ttorch.Tensor):

        def first_leaf(shape_tree):
            # ttorch ``.shape`` is a nested mapping; descend until a scalar is found.
            leaf = list(shape_tree.values())[0]
            if np.isscalar(leaf[0]):
                return leaf[0]
            return first_leaf(leaf)

        return first_leaf(data.shape)
    else:
        raise TypeError("Error in getting shape0, not support type: {}".format(data))
def lists_to_dicts(
        data: Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]],
        recursive: bool = False,
) -> Union[Mapping[object, object], NamedTuple]:
    """
    Overview:
        Transform a list of dicts (or namedtuples) into a dict of lists (or a namedtuple \
        of tuples), i.e. transpose the outer container with the inner keys.
    Arguments:
        - data (:obj:`Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]]`): \
            The list of dicts/namedtuples to transform.
        - recursive (:obj:`bool`): Whether to recursively transform dict-valued fields \
            (the ``prev_state`` key is deliberately left untouched).
    Returns:
        - newdata (:obj:`Union[Mapping[object, object], NamedTuple]`): The transposed result.
    Example:
        >>> from ding.utils import *
        >>> lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}])
        {1: [1, 2], 10: [3, 4]}
    """
    if len(data) == 0:
        raise ValueError("empty data")
    first = data[0]
    if isinstance(first, dict):
        if recursive:
            new_data = {}
            for key in first.keys():
                column = [item[key] for item in data]
                if isinstance(first[key], dict) and key != 'prev_state':
                    new_data[key] = lists_to_dicts(column)
                else:
                    new_data[key] = column
        else:
            new_data = {key: [item[key] for item in data] for key in first.keys()}
    elif isinstance(first, tuple) and hasattr(first, '_fields'):  # namedtuple
        new_data = type(first)(*list(zip(*data)))
    else:
        raise TypeError("not support element type: {}".format(type(first)))
    return new_data
def dicts_to_lists(data: Mapping[object, List[object]]) -> List[Mapping[object, object]]:
    """
    Overview:
        Transform a dict of lists into a list of dicts, i.e. the inverse of \
        ``lists_to_dicts`` for equal-length value lists.
    Arguments:
        - data (:obj:`Mapping[object, list]`): The dict of lists to transform.
    Returns:
        - newdata (:obj:`List[Mapping[object, object]]`): The resulting list of dicts.
    Example:
        >>> from ding.utils import *
        >>> dicts_to_lists({1: [1, 2], 10: [3, 4]})
        [{1: 1, 10: 3}, {1: 2, 10: 4}]
    """
    keys = list(data.keys())
    # zip(*values) yields one row per index; pair each row back with the keys.
    return [dict(zip(keys, row)) for row in zip(*data.values())]
def override(cls: type) -> Callable[[
        Callable,
], Callable]:
    """
    Overview:
        Decorator factory documenting that a method overrides one declared in ``cls``. \
        If ``cls`` exposes no attribute with the decorated method's name, a \
        ``NameError`` is raised at decoration time.
    Arguments:
        - cls (:obj:`type`): The superclass expected to provide the overridden method.
    """

    def _verify(method: Callable) -> Callable:
        if method.__name__ not in dir(cls):
            raise NameError("{} does not override any method of {}".format(method, cls))
        return method

    return _verify
def squeeze(data: object) -> object:
    """
    Overview:
        Squeeze a single-element tuple, list or dict down to its only element; \
        multi-element sequences are normalized to a tuple, other inputs pass through.
    Arguments:
        - data (:obj:`object`): Data to be squeezed.
    Example:
        >>> a = (4, )
        >>> a = squeeze(a)
        >>> print(a)
        >>> 4
    """
    if isinstance(data, (tuple, list)):
        return data[0] if len(data) == 1 else tuple(data)
    if isinstance(data, dict) and len(data) == 1:
        return list(data.values())[0]
    return data
default_get_set = set()


def default_get(
        data: dict,
        name: str,
        default_value: Optional[Any] = None,
        default_fn: Optional[Callable] = None,
        judge_fn: Optional[Callable] = None
) -> Any:
    """
    Overview:
        Return ``data[name]`` when the key exists; otherwise fall back to a default \
        produced by ``default_fn`` (preferred) or ``default_value``, optionally \
        validated by ``judge_fn``. The first time each missing key falls back, a \
        warning is logged and the key is recorded in the module-level \
        ``default_get_set`` so the warning is not repeated.
    Arguments:
        - data (:obj:`dict`): Input data dictionary.
        - name (:obj:`str`): Key name to look up.
        - default_value (:obj:`Optional[Any]`): Fallback value used when ``name`` is \
            missing and ``default_fn`` is not provided.
        - default_fn (:obj:`Optional[Callable]`): Zero-argument factory for the fallback \
            value; takes precedence over ``default_value``.
        - judge_fn (:obj:`Optional[Callable]`): Optional predicate the fallback value \
            must satisfy.
    Returns:
        - ret (:obj:`Any`): ``data[name]`` if present, otherwise the (validated) fallback value.
    """
    if name in data:
        return data[name]
    # At least one fallback source must be supplied when the key is missing.
    assert default_value is not None or default_fn is not None
    value = default_fn() if default_fn is not None else default_value
    if judge_fn:
        assert judge_fn(value), "default value({}) is not accepted by judge_fn".format(type(value))
    if name not in default_get_set:
        # Warn only once per key to avoid log spam.
        logging.warning("{} use default value {}".format(name, value))
        default_get_set.add(name)
    return value
def list_split(data: list, step: int) -> List[list]:
"""
Overview:
Split list of data by step.
Arguments:
- data(:obj:`list`): List of data for spliting
- step(:obj:`int`): Number of step for spliting
Returns:
- ret(:obj:`list`): List of splitted data.
- residual(:obj:`list`): Residule list. This value is ``None`` when ``data`` divides ``steps``.
Example:
>>> list_split([1,2,3,4],2)
([[1, 2], [3, 4]], None)
>>> list_split([1,2,3,4],3)
([[1, 2, 3]], [4])
"""
if len(data) < step:
return [], data
ret = []
divide_num = len(data) // step
for i in range(divide_num):
start, end = i * step, (i + 1) * step
ret.append(data[start:end])
if divide_num * step < len(data):
residual = data[divide_num * step:]
else:
residual = None
return ret, residual
def error_wrapper(fn, default_ret, warning_msg=""):
    """
    Overview:
        Wrap ``fn`` so that any exception it raises is caught and ``default_ret`` is \
        returned instead. If ``warning_msg`` is non-empty, a one-time warning including \
        the error is emitted on failure.
    Arguments:
        - fn (:obj:`Callable`): The function to be wrapped.
        - default_ret (:obj:`obj`): The value returned when ``fn`` raises.
        - warning_msg (:obj:`str`): Optional message logged (once) when ``fn`` fails.
    Returns:
        - wrapper (:obj:`Callable`): The wrapped function.
    Examples:
        >>> # Used to check for Fakelink (Refer to utils.linklink_dist_helper.py)
        >>> def get_rank(): # Get the rank of linklink model, return 0 if use FakeLink.
        >>>    if is_fake_link:
        >>>        return 0
        >>>    return error_wrapper(link.get_rank, 0)()
    """

    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            if warning_msg != "":
                # Bug fix: one_time_warning accepts a single message argument, so the
                # details must be concatenated rather than passed separately.
                one_time_warning(warning_msg + "\ndefault_ret = {}\terror = {}".format(default_ret, e))
            return default_ret

    return wrapper
class LimitedSpaceContainer:
    """
    Overview:
        A simple occupancy counter simulating a container with limited capacity: \
        ``cur`` pieces are in use, bounded between ``min_val`` and ``max_val``.
    Interfaces:
        ``__init__``, ``get_residual_space``, ``release_space``
    """

    def __init__(self, min_val: int, max_val: int) -> None:
        """
        Overview:
            Record the capacity bounds and start the occupancy counter at ``min_val``.
        Arguments:
            - min_val (:obj:`int`): Min volume of the container, usually 0.
            - max_val (:obj:`int`): Max volume of the container.
        """
        self.min_val = min_val
        self.max_val = max_val
        assert (max_val >= min_val)
        self.cur = self.min_val

    def get_residual_space(self) -> int:
        """
        Overview:
            Claim every remaining piece at once, setting ``cur`` to ``max_val``.
        Returns:
            - ret (:obj:`int`): Residual space before claiming, i.e. ``max_val - cur``.
        """
        residual = self.max_val - self.cur
        self.cur = self.max_val
        return residual

    def acquire_space(self) -> bool:
        """
        Overview:
            Try to occupy one piece of space.
        Returns:
            - flag (:obj:`bool`): ``True`` if a piece was available and claimed, else ``False``.
        """
        has_room = self.cur < self.max_val
        if has_room:
            self.cur += 1
        return has_room

    def release_space(self) -> None:
        """
        Overview:
            Give back one piece of space; ``cur`` never drops below ``min_val``.
        """
        self.cur = max(self.min_val, self.cur - 1)

    def increase_space(self) -> None:
        """
        Overview:
            Enlarge the container capacity (``max_val``) by one.
        """
        self.max_val += 1

    def decrease_space(self) -> None:
        """
        Overview:
            Shrink the container capacity (``max_val``) by one.
        """
        self.max_val -= 1
def deep_merge_dicts(original: dict, new_dict: dict) -> dict:
    """
    Overview:
        Deeply merge two dicts without mutating either input, by deep-copying \
        ``original`` and applying ``deep_update`` with new keys allowed.
    Arguments:
        - original (:obj:`dict`): Dict 1 (base values); may be ``None``.
        - new_dict (:obj:`dict`): Dict 2 (overriding values); may be ``None``.
    Returns:
        - merged_dict (:obj:`dict`): A new dict with d1 and d2 deeply merged.
    """
    merged = copy.deepcopy(original or {})
    overrides = new_dict or {}
    if overrides:  # skip the update entirely for empty/None input
        deep_update(merged, overrides, True, [])
    return merged
def deep_update(
    original: dict,
    new_dict: dict,
    new_keys_allowed: bool = False,
    whitelist: Optional[List[str]] = None,
    override_all_if_type_changes: Optional[List[str]] = None
):
    """
    Overview:
        Recursively update ``original`` in place with the values from ``new_dict``.
    Arguments:
        - original (:obj:`dict`): Dictionary with default values.
        - new_dict (:obj:`dict`): Dictionary with values to be updated.
        - new_keys_allowed (:obj:`bool`): Whether keys absent from ``original`` may be added.
        - whitelist (:obj:`Optional[List[str]]`): Top-level keys whose dict values may \
            always receive new subkeys.
        - override_all_if_type_changes (:obj:`Optional[List[str]]`): Top-level keys with \
            dict values that are replaced wholesale when their ``"type"`` entry changes.
    .. note::
        If a new key is introduced in ``new_dict`` while ``new_keys_allowed`` is False, \
        a ``RuntimeError`` is raised — except for sub-dicts under a whitelisted key.
    """
    whitelist = whitelist or []
    override_all_if_type_changes = override_all_if_type_changes or []
    for key, new_value in new_dict.items():
        if key not in original and not new_keys_allowed:
            raise RuntimeError("Unknown config parameter `{}`. Base config have: {}.".format(key, original.keys()))
        old_value = original.get(key)
        if not (isinstance(old_value, dict) and isinstance(new_value, dict)):
            # At least one side is not a dict: plain override.
            original[key] = new_value
        elif key in override_all_if_type_changes and \
                "type" in new_value and "type" in old_value and \
                new_value["type"] != old_value["type"]:
            # "type" changed: replace the entire sub-dict.
            original[key] = new_value
        elif key in whitelist:
            # Whitelisted key: new subkeys are allowed below this point.
            deep_update(old_value, new_value, True)
        else:
            deep_update(old_value, new_value, new_keys_allowed)
    return original
def flatten_dict(data: dict, delimiter: str = "/") -> dict:
    """
    Overview:
        Flatten a nested dict, joining nested keys with ``delimiter``. The input is \
        deep-copied first, so the result shares no mutable state with ``data``.
    Arguments:
        - data (:obj:`dict`): Original nested dict.
        - delimiter (str): Delimiter of the keys of the new dict.
    Returns:
        - data (:obj:`dict`): Flattened nested dict.
    Example:
        >>> a
        {'a': {'b': 100}}
        >>> flatten_dict(a)
        {'a/b': 100}
    """
    flat = copy.deepcopy(data)
    # Repeatedly expand one level of nesting until no dict values remain.
    while any(isinstance(value, dict) for value in flat.values()):
        expanded = {}
        nested_keys = []
        for key, value in flat.items():
            if isinstance(value, dict):
                nested_keys.append(key)
                for sub_key, sub_value in value.items():
                    expanded[delimiter.join([key, sub_key])] = sub_value
        flat.update(expanded)
        for key in nested_keys:
            del flat[key]
    return flat
def set_pkg_seed(seed: int, use_cuda: bool = True) -> None:
    """
    Overview:
        Side-effect helper that seeds ``random``, ``numpy`` and ``torch`` (plus CUDA \
        when requested and available) for reproducible experiments. Usually called \
        once in the entry script after seeding the environments.
    Arguments:
        - seed (:obj:`int`): The seed value to set.
        - use_cuda (:obj:`bool`): Whether to also seed the CUDA RNG.
    Examples:
        >>> # ../entry/xxxenv_xxxpolicy_main.py
        >>> ...
        # Set random seed for all package and instance
        >>> collector_env.seed(seed)
        >>> evaluator_env.seed(seed, dynamic_seed=False)
        >>> set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
        >>> ...
        # Set up RL Policy, etc.
        >>> ...
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if use_cuda and torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
@lru_cache(maxsize=None)
def one_time_warning(warning_msg: str) -> None:
    """
    Overview:
        Emit ``warning_msg`` via ``logging.warning`` at most once per distinct \
        message; repeated calls with the same string are absorbed by the cache.
    Arguments:
        - warning_msg (:obj:`str`): Warning message.
    """
    logging.warning(warning_msg)
def split_fn(data, indices, start, end):
    """
    Overview:
        Recursively slice ``data`` by the index window ``indices[start:end]``. \
        Lists and dicts are traversed; strings and ``None`` pass through untouched; \
        everything else is fancy-indexed directly.
    Arguments:
        - data (:obj:`Union[List, Dict, torch.Tensor, ttorch.Tensor]`): Data to be split.
        - indices (:obj:`np.ndarray`): Permutation of sample indices.
        - start (:obj:`int`): Start position inside ``indices``.
        - end (:obj:`int`): End position inside ``indices``.
    """
    if data is None:
        return None
    if isinstance(data, list):
        return [split_fn(item, indices, start, end) for item in data]
    if isinstance(data, dict):
        return {key: split_fn(value, indices, start, end) for key, value in data.items()}
    if isinstance(data, str):
        # String metadata (e.g. buffer ids) is not index-able sample data.
        return data
    return data[indices[start:end]]
def split_data_generator(data: dict, split_size: int, shuffle: bool = True) -> dict:
    """
    Overview:
        Yield mini-batches of size ``split_size`` from ``data``, optionally shuffling \
        sample order first. The final batch is shifted backwards so that every yielded \
        batch contains exactly ``split_size`` samples (the tail may overlap).
    Arguments:
        - data (:obj:`dict`): Batched data to split.
        - split_size (:obj:`int`): Size of each yielded batch.
        - shuffle (:obj:`bool`): Whether to shuffle sample order before splitting.
    """
    assert isinstance(data, dict), type(data)
    lengths = []
    for key, value in data.items():
        if value is None:
            continue
        if key in ['prev_state', 'prev_actor_state', 'prev_critic_state']:
            lengths.append(len(value))
        elif isinstance(value, (list, tuple)):
            if isinstance(value[0], str):
                # Some buffer data contains useless string infos, such as 'buffer_id',
                # which should not be split, so we just skip it.
                continue
            lengths.append(get_shape0(value[0]))
        elif isinstance(value, dict):
            lengths.append(len(value[list(value.keys())[0]]))
        else:
            lengths.append(len(value))
    assert len(lengths) > 0
    # With continuous actions, data['logit'] may be a list of length 2, so lengths
    # are not forced to be identical; the first entry is taken as the batch size.
    total = lengths[0]
    assert split_size >= 1
    indices = np.random.permutation(total) if shuffle else np.arange(total)
    for start in range(0, total, split_size):
        # Shift the last window back so it still spans split_size samples.
        start = min(start, total - split_size)
        yield split_fn(data, indices, start, start + split_size)
class RunningMeanStd(object):
    """
    Overview:
        Maintains running estimates of mean and variance over batches of data, \
        combining batch statistics with the parallel-variance formula.
    Interfaces:
        ``__init__``, ``update``, ``reset``, ``new_shape``
    Properties:
        - ``mean``, ``std``, ``_epsilon``, ``_shape``, ``_mean``, ``_var``, ``_count``
    """

    def __init__(self, epsilon=1e-4, shape=(), device=torch.device('cpu')):
        """
        Overview:
            Store the configuration and initialize the running statistics via ``reset``.
        Arguments:
            - epsilon (:obj:`Float`): Initial pseudo-count; also keeps early ``std`` well-defined.
            - shape (:obj:`np.array`): Shape of the tracked statistics; ``()`` means scalar.
            - device (:obj:`torch.device`): Device on which tensor-valued properties are returned.
        """
        self._epsilon = epsilon
        self._shape = shape
        self._device = device
        self.reset()

    def update(self, x):
        """
        Overview:
            Fold a new batch ``x`` (first axis = batch) into the running mean, \
            variance and count.
        Arguments:
            - ``x``: the batch
        """
        batch_mean, batch_var = np.mean(x, axis=0), np.var(x, axis=0)
        batch_count = x.shape[0]
        total = batch_count + self._count
        delta = batch_mean - self._mean
        # Parallel-variance combination; may be numerically unstable for huge counts.
        combined_m2 = self._var * self._count + batch_var * batch_count \
            + np.square(delta) * self._count * batch_count / total
        self._mean = self._mean + delta * batch_count / total
        self._var = combined_m2 / total
        self._count = total

    def reset(self):
        """
        Overview:
            Reset ``_mean``, ``_var`` and ``_count`` to their initial values.
        """
        if len(self._shape) > 0:
            self._mean = np.zeros(self._shape, 'float32')
            self._var = np.ones(self._shape, 'float32')
        else:
            self._mean, self._var = 0., 1.
        self._count = self._epsilon

    @property
    def mean(self) -> np.ndarray:
        """
        Overview:
            Current running mean: a plain scalar, or a ``torch.FloatTensor`` on ``self._device``.
        """
        if np.isscalar(self._mean):
            return self._mean
        return torch.FloatTensor(self._mean).to(self._device)

    @property
    def std(self) -> np.ndarray:
        """
        Overview:
            Current running std, derived from ``_var`` plus a small constant for stability.
        """
        value = np.sqrt(self._var + 1e-8)
        if np.isscalar(value):
            return value
        return torch.FloatTensor(value).to(self._device)

    @staticmethod
    def new_shape(obs_shape, act_shape, rew_shape):
        """
        Overview:
            Return the observation, action and reward shapes unchanged.
        Arguments:
            obs_shape (:obj:`Any`), act_shape (:obj:`Any`), rew_shape (:obj:`Any`)
        Returns:
            obs_shape (:obj:`Any`), act_shape (:obj:`Any`), rew_shape (:obj:`Any`)
        """
        return obs_shape, act_shape, rew_shape
def make_key_as_identifier(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Overview:
        Make the keys of a dict into legal python identifier strings so that they are
        compatible with some python magic methods such as ``__getattr__``.
    Arguments:
        - data (:obj:`Dict[str, Any]`): The original dict data.
    Return:
        - new_data (:obj:`Dict[str, Any]`): The new dict data with legal identifier keys.
    """

    def _sanitize(key: str) -> str:
        # An identifier cannot start with a digit; dots are replaced by underscores.
        prefixed = '_' + key if key[0].isdigit() else key
        return prefixed.replace('.', '_')

    return {_sanitize(key): value for key, value in data.items()}
def remove_illegal_item(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Overview:
        Drop entries whose values are strings (which are not compatible with Tensor)
        from a dict of info data.
    Arguments:
        - data (:obj:`Dict[str, Any]`): The original dict data.
    Return:
        - new_data (:obj:`Dict[str, Any]`): The new dict data without string-valued items.
    """
    return {key: value for key, value in data.items() if not isinstance(value, str)}
|