|
from typing import Optional, Union, Tuple |
|
import time |
|
import pickle |
|
from ditk import logging |
|
from multiprocessing import Process, Event |
|
import threading |
|
from easydict import EasyDict |
|
|
|
from ding.worker import create_comm_learner, create_comm_collector, Coordinator |
|
from ding.config import read_config_with_system, compile_config_parallel |
|
from ding.utils import set_pkg_seed |
|
|
|
|
|
def parallel_pipeline(
        input_cfg: Union[str, Tuple[dict, dict, dict]],
        seed: int,
        enable_total_log: Optional[bool] = False,
        disable_flask_log: Optional[bool] = True,
) -> None:
    r"""
    Overview:
        Parallel pipeline entry. Compile the parallel config, spawn one process for every learner and every
        collector entry found in ``config.system``, then run the coordinator in this process until it signals
        shutdown.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict, dict]]`): Config file path, or an already-loaded \
            ``(main_cfg, create_cfg, system_cfg)`` tuple (a list is accepted too).
        - seed (:obj:`int`): Random seed.
        - enable_total_log (:obj:`Optional[bool]`): Whether to enable the total DI-engine system log; when False \
            the ``coordinator_logger`` logger is disabled.
        - disable_flask_log (:obj:`Optional[bool]`): Whether to disable the flask (``werkzeug``) log.
    Raises:
        - TypeError: If ``input_cfg`` is neither a str nor a tuple/list.
    """
    if not enable_total_log:
        logging.getLogger('coordinator_logger').disabled = True
    if disable_flask_log:
        logging.getLogger('werkzeug').disabled = True

    if isinstance(input_cfg, str):
        main_cfg, create_cfg, system_cfg = read_config_with_system(input_cfg)
    elif isinstance(input_cfg, (tuple, list)):
        main_cfg, create_cfg, system_cfg = input_cfg
    else:
        raise TypeError("invalid config type: {}".format(input_cfg))
    config = compile_config_parallel(main_cfg, create_cfg=create_cfg, system_cfg=system_cfg, seed=seed)

    learner_handle = []
    collector_handle = []
    try:
        for k, v in config.system.items():
            # Workers are dispatched by substring of the system key; 'learner' is tested before 'collector'.
            if 'learner' in k:
                learner_handle.append(launch_learner(config.seed, v))
            elif 'collector' in k:
                collector_handle.append(launch_collector(config.seed, v))
        launch_coordinator(config.seed, config, learner_handle=learner_handle, collector_handle=collector_handle)
    except BaseException:
        # If spawning a later worker or running the coordinator fails (including Ctrl-C), release every worker
        # spawned so far. Otherwise each one blocks forever on its close event, and the interpreter then hangs at
        # exit while joining these non-daemon processes. Setting an already-set event is harmless.
        for _, _, close_event in learner_handle + collector_handle:
            close_event.set()
        raise
|
|
|
|
|
|
|
|
|
|
|
def run_learner(config, seed, start_learner_event, close_learner_event):
    r"""
    Overview:
        Child-process entry for a learner. It seeds the packages, silences the ``werkzeug`` logger, then creates
        and starts the comm learner. Once started it sets ``start_learner_event``, blocks until
        ``close_learner_event`` is set, and finally closes the learner.
    Arguments:
        - config: Learner config, passed to ``create_comm_learner``.
        - seed (:obj:`int`): Random seed, passed to ``set_pkg_seed``.
        - start_learner_event: Event set by this process once the learner has started.
        - close_learner_event: Event set by the parent to request that the learner be closed.
    """
    set_pkg_seed(seed)
    logging.getLogger('werkzeug').disabled = True
    learner = create_comm_learner(config)
    learner.start()
    try:
        start_learner_event.set()
        close_learner_event.wait()
    finally:
        # Close the learner even if the wait is interrupted (e.g. a KeyboardInterrupt delivered to this child).
        learner.close()
|
|
|
|
|
def launch_learner(
        seed: int, config: Optional[dict] = None, filename: Optional[str] = None, name: Optional[str] = None
) -> tuple:
    r"""
    Overview:
        Spawn a process running ``run_learner`` and return its handle.
    Arguments:
        - seed (:obj:`int`): Random seed forwarded to the child.
        - config (:obj:`Optional[dict]`): Learner config. If None, it is read from ``filename``.
        - filename (:obj:`Optional[str]`): Path to a pickled mapping of configs; used only when ``config`` is None.
        - name (:obj:`Optional[str]`): Key selecting this learner's config inside the pickled mapping.
    Returns:
        - handle (:obj:`tuple`): ``(process, start_event, close_event)``. The child sets ``start_event`` once the \
            learner has started; setting ``close_event`` asks the child to close the learner.
    Raises:
        - TypeError: If neither ``config`` nor ``filename`` is given.
    """
    if config is None:
        if filename is None:
            raise TypeError("launch_learner requires either `config` or `filename`")
        # NOTE(security): pickle.load can execute arbitrary code; only load config files from a trusted source.
        with open(filename, 'rb') as f:
            config = pickle.load(f)[name]
    start_learner_event = Event()
    close_learner_event = Event()

    learner_process = Process(
        target=run_learner, args=(config, seed, start_learner_event, close_learner_event), name='learner_entry_process'
    )
    learner_process.start()
    return learner_process, start_learner_event, close_learner_event
|
|
|
|
|
def run_collector(config, seed, start_collector_event, close_collector_event):
    r"""
    Overview:
        Child-process entry for a collector. It seeds the packages, silences the ``werkzeug`` logger, then creates
        and starts the comm collector. Once started it sets ``start_collector_event``, blocks until
        ``close_collector_event`` is set, and finally closes the collector.
    Arguments:
        - config: Collector config, passed to ``create_comm_collector``.
        - seed (:obj:`int`): Random seed, passed to ``set_pkg_seed``.
        - start_collector_event: Event set by this process once the collector has started.
        - close_collector_event: Event set by the parent to request that the collector be closed.
    """
    set_pkg_seed(seed)
    logging.getLogger('werkzeug').disabled = True
    collector = create_comm_collector(config)
    collector.start()
    try:
        start_collector_event.set()
        close_collector_event.wait()
    finally:
        # Close the collector even if the wait is interrupted (e.g. a KeyboardInterrupt delivered to this child).
        collector.close()
|
|
|
|
|
def launch_collector(
        seed: int, config: Optional[dict] = None, filename: Optional[str] = None, name: Optional[str] = None
) -> tuple:
    r"""
    Overview:
        Spawn a process running ``run_collector`` and return its handle.
    Arguments:
        - seed (:obj:`int`): Random seed forwarded to the child.
        - config (:obj:`Optional[dict]`): Collector config. If None, it is read from ``filename``.
        - filename (:obj:`Optional[str]`): Path to a pickled mapping of configs; used only when ``config`` is None.
        - name (:obj:`Optional[str]`): Key selecting this collector's config inside the pickled mapping.
    Returns:
        - handle (:obj:`tuple`): ``(process, start_event, close_event)``. The child sets ``start_event`` once the \
            collector has started; setting ``close_event`` asks the child to close the collector.
    Raises:
        - TypeError: If neither ``config`` nor ``filename`` is given.
    """
    if config is None:
        if filename is None:
            raise TypeError("launch_collector requires either `config` or `filename`")
        # NOTE(security): pickle.load can execute arbitrary code; only load config files from a trusted source.
        with open(filename, 'rb') as f:
            config = pickle.load(f)[name]
    start_collector_event = Event()
    close_collector_event = Event()

    collector_process = Process(
        target=run_collector,
        args=(config, seed, start_collector_event, close_collector_event),
        name='collector_entry_process'
    )
    collector_process.start()
    return collector_process, start_collector_event, close_collector_event
|
|
|
|
|
def launch_coordinator(
        seed: int,
        config: Optional[EasyDict] = None,
        filename: Optional[str] = None,
        learner_handle: Optional[list] = None,
        collector_handle: Optional[list] = None
) -> None:
    r"""
    Overview:
        Run the coordinator in the current process. It first waits until every learner and collector worker
        reports that it has started, then starts the coordinator. A background thread polls the coordinator's
        ``system_shutdown_flag`` every 3 seconds. Once the flag is seen, the thread closes the coordinator and
        signals every worker to close. This function then prints a completion message and joins the worker
        processes.
    Arguments:
        - seed (:obj:`int`): Random seed, passed to ``set_pkg_seed``.
        - config (:obj:`Optional[EasyDict]`): Compiled parallel config. If None, it is unpickled from ``filename``.
        - filename (:obj:`Optional[str]`): Path to a pickled config; used only when ``config`` is None.
        - learner_handle (:obj:`Optional[list]`): ``(process, start_event, close_event)`` tuples, as returned by \
            ``launch_learner``. None means no learners.
        - collector_handle (:obj:`Optional[list]`): ``(process, start_event, close_event)`` tuples, as returned \
            by ``launch_collector``. None means no collectors.
    """
    set_pkg_seed(seed)
    # Learners come first so start/close ordering matches the original learner-then-collector order.
    worker_handles = list(learner_handle or []) + list(collector_handle or [])
    if config is None:
        # NOTE(security): pickle.load can execute arbitrary code; only load config files from a trusted source.
        with open(filename, 'rb') as f:
            config = pickle.load(f)
    coordinator = Coordinator(config)
    # Start the coordinator only after every worker is up.
    # NOTE(review): there is no timeout here, so a worker that dies before setting its start event blocks this forever.
    for _, start_event, _ in worker_handles:
        start_event.wait()
    coordinator.start()
    system_shutdown_event = threading.Event()

    def shutdown_monitor():
        while True:
            time.sleep(3)
            if coordinator.system_shutdown_flag:
                break
        try:
            coordinator.close()
        finally:
            # Release the workers and the main thread even if coordinator.close() raises; otherwise both wait forever.
            for _, _, close_event in worker_handles:
                close_event.set()
            system_shutdown_event.set()

    shutdown_monitor_thread = threading.Thread(target=shutdown_monitor, args=(), daemon=True, name='shutdown_monitor')
    shutdown_monitor_thread.start()
    system_shutdown_event.wait()
    print(
        "[DI-engine parallel pipeline]Your RL agent is converged, you can refer to 'log' and 'tensorboard' for details"
    )
    # Reap the worker processes now that they have been told to close.
    for process, _, _ in worker_handles:
        process.join()
|
|