Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import copy | |
import logging | |
from collections import OrderedDict | |
from typing import TYPE_CHECKING, Any, Optional, Union | |
import numpy as np | |
from mmengine.utils import ManagerMixin | |
from .history_buffer import HistoryBuffer | |
from .logger import print_log | |
if TYPE_CHECKING: | |
import torch | |
class MessageHub(ManagerMixin): | |
"""Message hub for component interaction. MessageHub is created and | |
accessed in the same way as ManagerMixin. | |
``MessageHub`` will record log information and runtime information. The | |
log information refers to the learning rate, loss, etc. of the model | |
during training phase, which will be stored as ``HistoryBuffer``. The | |
runtime information refers to the iter times, meta information of | |
runner etc., which will be overwritten by next update. | |
Args: | |
name (str): Name of message hub used to get corresponding instance | |
globally. | |
log_scalars (dict, optional): Each key-value pair in the | |
dictionary is the name of the log information such as "loss", "lr", | |
"metric" and their corresponding values. The type of value must be | |
HistoryBuffer. Defaults to None. | |
runtime_info (dict, optional): Each key-value pair in the | |
dictionary is the name of the runtime information and their | |
corresponding values. Defaults to None. | |
resumed_keys (dict, optional): Each key-value pair in the | |
dictionary decides whether the key in :attr:`_log_scalars` and | |
:attr:`_runtime_info` will be serialized. | |
Note: | |
Key in :attr:`_resumed_keys` belongs to :attr:`_log_scalars` or | |
:attr:`_runtime_info`. The corresponding value cannot be set | |
repeatedly. | |
Examples: | |
>>> # create empty `MessageHub`. | |
>>> message_hub1 = MessageHub('name') | |
>>> log_scalars = dict(loss=HistoryBuffer()) | |
>>> runtime_info = dict(task='task') | |
>>> resumed_keys = dict(loss=True) | |
>>> # create `MessageHub` from data. | |
>>> message_hub2 = MessageHub( | |
>>> name='name', | |
>>> log_scalars=log_scalars, | |
>>> runtime_info=runtime_info, | |
>>> resumed_keys=resumed_keys) | |
""" | |
def __init__(self, | |
name: str, | |
log_scalars: Optional[dict] = None, | |
runtime_info: Optional[dict] = None, | |
resumed_keys: Optional[dict] = None): | |
super().__init__(name) | |
self._log_scalars = self._parse_input('log_scalars', log_scalars) | |
self._runtime_info = self._parse_input('runtime_info', runtime_info) | |
self._resumed_keys = self._parse_input('resumed_keys', resumed_keys) | |
for value in self._log_scalars.values(): | |
assert isinstance(value, HistoryBuffer), \ | |
("The type of log_scalars'value must be HistoryBuffer, but " | |
f'got {type(value)}') | |
for key in self._resumed_keys.keys(): | |
assert key in self._log_scalars or key in self._runtime_info, \ | |
('Key in `resumed_keys` must contained in `log_scalars` or ' | |
f'`runtime_info`, but got {key}') | |
def get_current_instance(cls) -> 'MessageHub': | |
"""Get latest created ``MessageHub`` instance. | |
:obj:`MessageHub` can call :meth:`get_current_instance` before any | |
instance has been created, and return a message hub with the instance | |
name "mmengine". | |
Returns: | |
MessageHub: Empty ``MessageHub`` instance. | |
""" | |
if not cls._instance_dict: | |
cls.get_instance('mmengine') | |
return super().get_current_instance() | |
def update_scalar(self, | |
key: str, | |
value: Union[int, float, np.ndarray, 'torch.Tensor'], | |
count: int = 1, | |
resumed: bool = True) -> None: | |
"""Update :attr:_log_scalars. | |
Update ``HistoryBuffer`` in :attr:`_log_scalars`. If corresponding key | |
``HistoryBuffer`` has been created, ``value`` and ``count`` is the | |
argument of ``HistoryBuffer.update``, Otherwise, ``update_scalar`` | |
will create an ``HistoryBuffer`` with value and count via the | |
constructor of ``HistoryBuffer``. | |
Examples: | |
>>> message_hub = MessageHub(name='name') | |
>>> # create loss `HistoryBuffer` with value=1, count=1 | |
>>> message_hub.update_scalar('loss', 1) | |
>>> # update loss `HistoryBuffer` with value | |
>>> message_hub.update_scalar('loss', 3) | |
>>> message_hub.update_scalar('loss', 3, resumed=False) | |
AssertionError: loss used to be true, but got false now. resumed | |
keys cannot be modified repeatedly' | |
Note: | |
The ``resumed`` argument needs to be consistent for the same | |
``key``. | |
Args: | |
key (str): Key of ``HistoryBuffer``. | |
value (torch.Tensor or np.ndarray or int or float): Value of log. | |
count (torch.Tensor or np.ndarray or int or float): Accumulation | |
times of log, defaults to 1. `count` will be used in smooth | |
statistics. | |
resumed (str): Whether the corresponding ``HistoryBuffer`` | |
could be resumed. Defaults to True. | |
""" | |
self._set_resumed_keys(key, resumed) | |
checked_value = self._get_valid_value(value) | |
assert isinstance(count, int), ( | |
f'The type of count must be int. but got {type(count): {count}}') | |
if key in self._log_scalars: | |
self._log_scalars[key].update(checked_value, count) | |
else: | |
self._log_scalars[key] = HistoryBuffer([checked_value], [count]) | |
def update_scalars(self, log_dict: dict, resumed: bool = True) -> None: | |
"""Update :attr:`_log_scalars` with a dict. | |
``update_scalars`` iterates through each pair of log_dict key-value, | |
and calls ``update_scalar``. If type of value is dict, the value should | |
be ``dict(value=xxx) or dict(value=xxx, count=xxx)``. Item in | |
``log_dict`` has the same resume option. | |
Note: | |
The ``resumed`` argument needs to be consistent for the same | |
``log_dict``. | |
Args: | |
log_dict (str): Used for batch updating :attr:`_log_scalars`. | |
resumed (bool): Whether all ``HistoryBuffer`` referred in | |
log_dict should be resumed. Defaults to True. | |
Examples: | |
>>> message_hub = MessageHub.get_instance('mmengine') | |
>>> log_dict = dict(a=1, b=2, c=3) | |
>>> message_hub.update_scalars(log_dict) | |
>>> # The default count of `a`, `b` and `c` is 1. | |
>>> log_dict = dict(a=1, b=2, c=dict(value=1, count=2)) | |
>>> message_hub.update_scalars(log_dict) | |
>>> # The count of `c` is 2. | |
""" | |
assert isinstance(log_dict, dict), ('`log_dict` must be a dict!, ' | |
f'but got {type(log_dict)}') | |
for log_name, log_val in log_dict.items(): | |
if isinstance(log_val, dict): | |
assert 'value' in log_val, \ | |
f'value must be defined in {log_val}' | |
count = self._get_valid_value(log_val.get('count', 1)) | |
value = log_val['value'] | |
else: | |
count = 1 | |
value = log_val | |
assert isinstance(count, | |
int), ('The type of count must be int. but got ' | |
f'{type(count): {count}}') | |
self.update_scalar(log_name, value, count, resumed) | |
def update_info(self, key: str, value: Any, resumed: bool = True) -> None: | |
"""Update runtime information. | |
The key corresponding runtime information will be overwritten each | |
time calling ``update_info``. | |
Note: | |
The ``resumed`` argument needs to be consistent for the same | |
``key``. | |
Examples: | |
>>> message_hub = MessageHub(name='name') | |
>>> message_hub.update_info('iter', 100) | |
Args: | |
key (str): Key of runtime information. | |
value (Any): Value of runtime information. | |
resumed (bool): Whether the corresponding ``HistoryBuffer`` | |
could be resumed. | |
""" | |
self._set_resumed_keys(key, resumed) | |
self._runtime_info[key] = value | |
def pop_info(self, key: str, default: Optional[Any] = None) -> Any: | |
"""Remove runtime information by key. If the key does not exist, this | |
method will return the default value. | |
Args: | |
key (str): Key of runtime information. | |
default (Any, optional): The default returned value for the | |
given key. | |
Returns: | |
Any: The runtime information if the key exists. | |
""" | |
return self._runtime_info.pop(key, default) | |
def update_info_dict(self, info_dict: dict, resumed: bool = True) -> None: | |
"""Update runtime information with dictionary. | |
The key corresponding runtime information will be overwritten each | |
time calling ``update_info``. | |
Note: | |
The ``resumed`` argument needs to be consistent for the same | |
``info_dict``. | |
Examples: | |
>>> message_hub = MessageHub(name='name') | |
>>> message_hub.update_info({'iter': 100}) | |
Args: | |
info_dict (str): Runtime information dictionary. | |
resumed (bool): Whether the corresponding ``HistoryBuffer`` | |
could be resumed. | |
""" | |
assert isinstance(info_dict, dict), ('`log_dict` must be a dict!, ' | |
f'but got {type(info_dict)}') | |
for key, value in info_dict.items(): | |
self.update_info(key, value, resumed=resumed) | |
def _set_resumed_keys(self, key: str, resumed: bool) -> None: | |
"""Set corresponding resumed keys. | |
This method is called by ``update_scalar``, ``update_scalars`` and | |
``update_info`` to set the corresponding key is true or false in | |
:attr:`_resumed_keys`. | |
Args: | |
key (str): Key of :attr:`_log_scalrs` or :attr:`_runtime_info`. | |
resumed (bool): Whether the corresponding ``HistoryBuffer`` | |
could be resumed. | |
""" | |
if key not in self._resumed_keys: | |
self._resumed_keys[key] = resumed | |
else: | |
assert self._resumed_keys[key] == resumed, \ | |
f'{key} used to be {self._resumed_keys[key]}, but got ' \ | |
'{resumed} now. resumed keys cannot be modified repeatedly.' | |
def log_scalars(self) -> OrderedDict: | |
"""Get all ``HistoryBuffer`` instances. | |
Note: | |
Considering the large memory footprint of history buffers in the | |
post-training, :meth:`get_scalar` will return a reference of | |
history buffer rather than a copy. | |
Returns: | |
OrderedDict: All ``HistoryBuffer`` instances. | |
""" | |
return self._log_scalars | |
def runtime_info(self) -> OrderedDict: | |
"""Get all runtime information. | |
Returns: | |
OrderedDict: A copy of all runtime information. | |
""" | |
return self._runtime_info | |
def get_scalar(self, key: str) -> HistoryBuffer: | |
"""Get ``HistoryBuffer`` instance by key. | |
Note: | |
Considering the large memory footprint of history buffers in the | |
post-training, :meth:`get_scalar` will not return a reference of | |
history buffer rather than a copy. | |
Args: | |
key (str): Key of ``HistoryBuffer``. | |
Returns: | |
HistoryBuffer: Corresponding ``HistoryBuffer`` instance if the | |
key exists. | |
""" | |
if key not in self.log_scalars: | |
raise KeyError(f'{key} is not found in Messagehub.log_buffers: ' | |
f'instance name is: {MessageHub.instance_name}') | |
return self.log_scalars[key] | |
def get_info(self, key: str, default: Optional[Any] = None) -> Any: | |
"""Get runtime information by key. If the key does not exist, this | |
method will return default information. | |
Args: | |
key (str): Key of runtime information. | |
default (Any, optional): The default returned value for the | |
given key. | |
Returns: | |
Any: A copy of corresponding runtime information if the key exists. | |
""" | |
if key not in self.runtime_info: | |
return default | |
else: | |
# TODO: There are restrictions on objects that can be saved | |
# return copy.deepcopy(self._runtime_info[key]) | |
return self._runtime_info[key] | |
def _get_valid_value( | |
self, | |
value: Union['torch.Tensor', np.ndarray, np.number, int, float], | |
) -> Union[int, float]: | |
"""Convert value to python built-in type. | |
Args: | |
value (torch.Tensor or np.ndarray or np.number or int or float): | |
value of log. | |
Returns: | |
float or int: python built-in type value. | |
""" | |
if isinstance(value, (np.ndarray, np.number)): | |
assert value.size == 1 | |
value = value.item() | |
elif isinstance(value, (int, float)): | |
value = value | |
else: | |
# check whether value is torch.Tensor but don't want | |
# to import torch in this file | |
assert hasattr(value, 'numel') and value.numel() == 1 | |
value = value.item() | |
return value # type: ignore | |
def state_dict(self) -> dict: | |
"""Returns a dictionary containing log scalars, runtime information and | |
resumed keys, which should be resumed. | |
The returned ``state_dict`` can be loaded by :meth:`load_state_dict`. | |
Returns: | |
dict: A dictionary contains ``log_scalars``, ``runtime_info`` and | |
``resumed_keys``. | |
""" | |
saved_scalars = OrderedDict() | |
saved_info = OrderedDict() | |
for key, value in self._log_scalars.items(): | |
if self._resumed_keys.get(key, False): | |
saved_scalars[key] = copy.deepcopy(value) | |
for key, value in self._runtime_info.items(): | |
if self._resumed_keys.get(key, False): | |
try: | |
saved_info[key] = copy.deepcopy(value) | |
except: # noqa: E722 | |
print_log( | |
f'{key} in message_hub cannot be copied, ' | |
f'just return its reference. ', | |
logger='current', | |
level=logging.WARNING) | |
saved_info[key] = value | |
return dict( | |
log_scalars=saved_scalars, | |
runtime_info=saved_info, | |
resumed_keys=self._resumed_keys) | |
def load_state_dict(self, state_dict: Union['MessageHub', dict]) -> None: | |
"""Loads log scalars, runtime information and resumed keys from | |
``state_dict`` or ``message_hub``. | |
If ``state_dict`` is a dictionary returned by :meth:`state_dict`, it | |
will only make copies of data which should be resumed from the source | |
``message_hub``. | |
If ``state_dict`` is a ``message_hub`` instance, it will make copies of | |
all data from the source message_hub. We suggest to load data from | |
``dict`` rather than a ``MessageHub`` instance. | |
Args: | |
state_dict (dict or MessageHub): A dictionary contains key | |
``log_scalars`` ``runtime_info`` and ``resumed_keys``, or a | |
MessageHub instance. | |
""" | |
if isinstance(state_dict, dict): | |
for key in ('log_scalars', 'runtime_info', 'resumed_keys'): | |
assert key in state_dict, ( | |
'The loaded `state_dict` of `MessageHub` must contain ' | |
f'key: `{key}`') | |
# The old `MessageHub` could save non-HistoryBuffer `log_scalars`, | |
# therefore the loaded `log_scalars` needs to be filtered. | |
for key, value in state_dict['log_scalars'].items(): | |
if not isinstance(value, HistoryBuffer): | |
print_log( | |
f'{key} in message_hub is not HistoryBuffer, ' | |
f'just skip resuming it.', | |
logger='current', | |
level=logging.WARNING) | |
continue | |
self.log_scalars[key] = value | |
for key, value in state_dict['runtime_info'].items(): | |
try: | |
self._runtime_info[key] = copy.deepcopy(value) | |
except: # noqa: E722 | |
print_log( | |
f'{key} in message_hub cannot be copied, ' | |
f'just return its reference.', | |
logger='current', | |
level=logging.WARNING) | |
self._runtime_info[key] = value | |
for key, value in state_dict['resumed_keys'].items(): | |
if key not in set(self.log_scalars.keys()) | \ | |
set(self._runtime_info.keys()): | |
print_log( | |
f'resumed key: {key} is not defined in message_hub, ' | |
f'just skip resuming this key.', | |
logger='current', | |
level=logging.WARNING) | |
continue | |
elif not value: | |
print_log( | |
f'Although resumed key: {key} is False, {key} ' | |
'will still be loaded this time. This key will ' | |
'not be saved by the next calling of ' | |
'`MessageHub.state_dict()`', | |
logger='current', | |
level=logging.WARNING) | |
self._resumed_keys[key] = value | |
# Since some checkpoints saved serialized `message_hub` instance, | |
# `load_state_dict` support loading `message_hub` instance for | |
# compatibility | |
else: | |
self._log_scalars = copy.deepcopy(state_dict._log_scalars) | |
self._runtime_info = copy.deepcopy(state_dict._runtime_info) | |
self._resumed_keys = copy.deepcopy(state_dict._resumed_keys) | |
def _parse_input(self, name: str, value: Any) -> OrderedDict: | |
"""Parse input value. | |
Args: | |
name (str): name of input value. | |
value (Any): Input value. | |
Returns: | |
dict: Parsed input value. | |
""" | |
if value is None: | |
return OrderedDict() | |
elif isinstance(value, dict): | |
return OrderedDict(value) | |
else: | |
raise TypeError(f'{name} should be a dict or `None`, but ' | |
f'got {type(name)}') | |