from typing import Optional, Union, Tuple
import time
import pickle
import threading
from multiprocessing import Process, Event

from ditk import logging
from easydict import EasyDict

from ding.worker import create_comm_learner, create_comm_collector, Coordinator
from ding.config import read_config_with_system, compile_config_parallel
from ding.utils import set_pkg_seed


def parallel_pipeline(
        input_cfg: Union[str, Tuple[dict, dict, dict]],
        seed: int,
        enable_total_log: Optional[bool] = False,
        disable_flask_log: Optional[bool] = True,
) -> None:
    r"""
    Overview:
        Parallel pipeline entry.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict, dict]]`): Config file path, or a \
            ``(main_cfg, create_cfg, system_cfg)`` tuple.
        - seed (:obj:`int`): Random seed.
        - enable_total_log (:obj:`Optional[bool]`): Whether to enable the total DI-engine system log.
        - disable_flask_log (:obj:`Optional[bool]`): Whether to disable the flask log.
    """
    # Disable some parts of the DI-engine log.
    if not enable_total_log:
        coordinator_log = logging.getLogger('coordinator_logger')
        coordinator_log.disabled = True
    # Disable flask logger.
    if disable_flask_log:
        log = logging.getLogger('werkzeug')
        log.disabled = True
    # Parallel job launch.
    if isinstance(input_cfg, str):
        main_cfg, create_cfg, system_cfg = read_config_with_system(input_cfg)
    elif isinstance(input_cfg, (tuple, list)):
        main_cfg, create_cfg, system_cfg = input_cfg
    else:
        raise TypeError("invalid config type: {}".format(input_cfg))
    config = compile_config_parallel(main_cfg, create_cfg=create_cfg, system_cfg=system_cfg, seed=seed)
    learner_handle = []
    collector_handle = []
    for k, v in config.system.items():
        if 'learner' in k:
            learner_handle.append(launch_learner(config.seed, v))
        elif 'collector' in k:
            collector_handle.append(launch_collector(config.seed, v))
    launch_coordinator(config.seed, config, learner_handle=learner_handle, collector_handle=collector_handle)
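
# A minimal usage sketch for ``parallel_pipeline`` (assumptions: this module is importable as ``ding.entry``, and
# ``./cartpole_dqn_parallel_config.py`` is a hypothetical parallel-mode config file defining the
# ``(main_cfg, create_cfg, system_cfg)`` triple). It is kept as comments so that importing this module has no
# side effects:
#
#     from ding.entry import parallel_pipeline
#     parallel_pipeline('./cartpole_dqn_parallel_config.py', seed=0)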

# The following functions are used to launch different components (learner, learner aggregator, collector,
# coordinator). Argument ``config`` is the dict-type config. If it is None, then ``filename`` and ``name`` must be
# passed, so that they can be used to read the corresponding config from file.


def run_learner(config, seed, start_learner_event, close_learner_event):
    set_pkg_seed(seed)
    log = logging.getLogger('werkzeug')
    log.disabled = True
    learner = create_comm_learner(config)
    learner.start()
    start_learner_event.set()
    # Wait until the coordinator notifies this learner to close.
    close_learner_event.wait()
    learner.close()


def launch_learner(
        seed: int, config: Optional[dict] = None, filename: Optional[str] = None, name: Optional[str] = None
) -> tuple:
    if config is None:
        with open(filename, 'rb') as f:
            config = pickle.load(f)[name]
    start_learner_event = Event()
    close_learner_event = Event()
    learner_process = Process(
        target=run_learner,
        args=(config, seed, start_learner_event, close_learner_event),
        name='learner_entry_process'
    )
    learner_process.start()
    return learner_process, start_learner_event, close_learner_event


def run_collector(config, seed, start_collector_event, close_collector_event):
    set_pkg_seed(seed)
    log = logging.getLogger('werkzeug')
    log.disabled = True
    collector = create_comm_collector(config)
    collector.start()
    start_collector_event.set()
    # Wait until the coordinator notifies this collector to close.
    close_collector_event.wait()
    collector.close()


def launch_collector(
        seed: int, config: Optional[dict] = None, filename: Optional[str] = None, name: Optional[str] = None
) -> tuple:
    if config is None:
        with open(filename, 'rb') as f:
            config = pickle.load(f)[name]
    start_collector_event = Event()
    close_collector_event = Event()
    collector_process = Process(
        target=run_collector,
        args=(config, seed, start_collector_event, close_collector_event),
        name='collector_entry_process'
    )
    collector_process.start()
    return collector_process, start_collector_event, close_collector_event


def launch_coordinator(
        seed: int,
        config: Optional[EasyDict] = None,
        filename: Optional[str] = None,
        learner_handle: Optional[list] = None,
        collector_handle: Optional[list] = None
) -> None:
    set_pkg_seed(seed)
    if config is None:
        with open(filename, 'rb') as f:
            config = pickle.load(f)
    coordinator = Coordinator(config)
    # Wait until all the learners and collectors have started.
    for _, start_event, _ in learner_handle:
        start_event.wait()
    for _, start_event, _ in collector_handle:
        start_event.wait()
    coordinator.start()
    system_shutdown_event = threading.Event()

    # Monitor thread: Coordinator will keep running until its ``system_shutdown_flag`` is set to True.
    def shutdown_monitor():
        while True:
            time.sleep(3)
            if coordinator.system_shutdown_flag:
                coordinator.close()
                for _, _, close_event in learner_handle:
                    close_event.set()
                for _, _, close_event in collector_handle:
                    close_event.set()
                system_shutdown_event.set()
                break

    shutdown_monitor_thread = threading.Thread(target=shutdown_monitor, args=(), daemon=True, name='shutdown_monitor')
    shutdown_monitor_thread.start()
    system_shutdown_event.wait()
    print(
        "[DI-engine parallel pipeline] Your RL agent has converged, you can refer to 'log' and 'tensorboard' "
        "for details"
    )
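
# A minimal sketch of launching components by hand instead of via ``parallel_pipeline``, using the file-based path of
# the ``launch_*`` functions (assumptions: ``system_config.pkl`` is a hypothetical pickled dict keyed by component
# name, and ``main_config.pkl`` a hypothetical pickled whole config, matching how ``launch_learner``/
# ``launch_collector`` index the unpickled dict by ``name`` while ``launch_coordinator`` uses it whole). It is kept as
# comments so that importing this module has no side effects:
#
#     learner_handle = [launch_learner(0, filename='system_config.pkl', name='learner0')]
#     collector_handle = [launch_collector(0, filename='system_config.pkl', name='collector0')]
#     launch_coordinator(0, filename='main_config.pkl', learner_handle=learner_handle, collector_handle=collector_handle)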