# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import logging
import os.path as osp
from typing import Optional

from mmengine.fileio import dump
from mmengine.logging import print_log

from . import root
from .default_scope import DefaultScope
from .registry import Registry


def traverse_registry_tree(registry: Registry, verbose: bool = True) -> list:
    """Traverse the whole registry tree from any given node, and collect
    information of all registered modules in this registry tree.

    The traversal always starts from ``registry.root``, so passing any node
    of a tree yields statistics for the entire tree, visited depth-first
    (a node is recorded before its children).

    Args:
        registry (Registry): a registry node in the registry tree.
        verbose (bool): Whether to print log. Defaults to True.

    Returns:
        list: Statistic results of all modules in each node of the registry
        tree. Each element is a dict with the keys ``num_modules`` and
        ``scope``, plus one key per source package path (the registered
        class's ``__module__`` with its last component dropped and the dots
        replaced by ``/``) mapping to the names registered from it.
    """
    root_registry = registry.root
    modules_info = []

    def _dfs_registry(_registry):
        # Children that are not Registry instances are skipped together
        # with their whole subtree.
        if not isinstance(_registry, Registry):
            return

        num_modules = len(_registry.module_dict)
        scope = _registry.scope
        registry_info = dict(num_modules=num_modules, scope=scope)
        for name, registered_class in _registry.module_dict.items():
            # e.g. 'mmengine.hooks.logger_hook' -> 'mmengine/hooks'.
            # NOTE(review): a package path equal to 'num_modules' or 'scope'
            # would collide with the metadata keys above.
            folder = '/'.join(registered_class.__module__.split('.')[:-1])
            registry_info.setdefault(folder, []).append(name)

        if verbose:
            print_log(
                f"Find {num_modules} modules in {scope}'s "
                f"'{_registry.name}' registry ",
                logger='current')
        modules_info.append(registry_info)

        for child in _registry.children.values():
            _dfs_registry(child)

    _dfs_registry(root_registry)
    return modules_info


def count_registered_modules(save_path: Optional[str] = None,
                             verbose: bool = True) -> dict:
    """Scan all modules in MMEngine's root and child registries and dump to
    json.

    Args:
        save_path (str, optional): Directory in which the results are saved
            as ``modules_statistic_results.json``. Nothing is written when
            it is None. Defaults to None.
        verbose (bool): Whether to print log. Defaults to True.

    Returns:
        dict: Statistic results of all registered modules, with the keys
        ``scan_date`` (formatted as ``'%Y-%m-%d %H:%M:%S'``) and
        ``registries`` (attribute name in ``root`` -> the list returned by
        :func:`traverse_registry_tree` for that registry).
    """
    # Importing these subpackages triggers the registration of their modules.
    import mmengine.dataset  # noqa: F401
    import mmengine.evaluator  # noqa: F401
    import mmengine.hooks  # noqa: F401
    import mmengine.model  # noqa: F401
    import mmengine.optim  # noqa: F401
    import mmengine.runner  # noqa: F401
    import mmengine.visualization  # noqa: F401

    registries_info = {}
    # traverse all registries in MMEngine
    for item in dir(root):
        if item.startswith('__'):
            continue
        registry = getattr(root, item)
        if isinstance(registry, Registry):
            registries_info[item] = traverse_registry_tree(registry, verbose)

    scan_data = dict(
        scan_date=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        registries=registries_info)
    if verbose:
        print_log(
            f'Finish registry analysis, got: {scan_data}', logger='current')
    if save_path is not None:
        json_path = osp.join(save_path, 'modules_statistic_results.json')
        dump(scan_data, json_path, indent=2)
        print_log(f'Result has been saved to {json_path}', logger='current')
    return scan_data


def init_default_scope(scope: str) -> None:
    """Initialize the given default scope.

    If ``DefaultScope.get_current_instance()`` returns None, or no instance
    named ``scope`` has been created, ``DefaultScope.get_instance`` is called
    with name ``scope`` and ``scope_name=scope``. Otherwise, if the current
    instance's ``scope_name`` differs from ``scope``, a warning is logged and
    ``get_instance`` is called under a timestamped name with
    ``scope_name=scope``. If the current scope already matches, nothing
    happens.

    Args:
        scope (str): The name of the default scope.
    """
    current_scope = DefaultScope.get_current_instance()
    never_created = (current_scope is None
                     or not DefaultScope.check_instance_created(scope))
    if never_created:
        DefaultScope.get_instance(scope, scope_name=scope)
        return

    if current_scope.scope_name != scope:
        print_log(
            'The current default scope '
            f'"{current_scope.scope_name}" is not "{scope}", '
            '`init_default_scope` will force set the current '
            f'default scope to "{scope}".',
            logger='current',
            level=logging.WARNING)
        # An instance named ``scope`` already exists, so a unique,
        # timestamped name is used to avoid a name conflict.
        new_instance_name = f'{scope}-{datetime.datetime.now()}'
        DefaultScope.get_instance(new_instance_name, scope_name=scope)