# AnnaMats's picture
# Second Push
# 05c9ac2
import datetime
from typing import Dict, NamedTuple, List, Any, Optional, Callable, Set
import cloudpickle
import enum
import time
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.exception import (
UnityCommunicationException,
UnityTimeOutException,
UnityEnvironmentException,
UnityCommunicatorStoppedException,
)
from multiprocessing import Process, Pipe, Queue
from multiprocessing.connection import Connection
from queue import Empty as EmptyQueueException
from mlagents_envs.base_env import BaseEnv, BehaviorName, BehaviorSpec
from mlagents_envs import logging_util
from mlagents.trainers.env_manager import EnvManager, EnvironmentStep, AllStepResult
from mlagents.trainers.settings import TrainerSettings
from mlagents_envs.timers import (
TimerNode,
timed,
hierarchical_timer,
reset_timers,
get_timer_root,
)
from mlagents.trainers.settings import ParameterRandomizationSettings, RunOptions
from mlagents.trainers.action_info import ActionInfo
from mlagents_envs.side_channel.environment_parameters_channel import (
EnvironmentParametersChannel,
)
from mlagents_envs.side_channel.engine_configuration_channel import (
EngineConfigurationChannel,
EngineConfig,
)
from mlagents_envs.side_channel.stats_side_channel import (
EnvironmentStats,
StatsSideChannel,
)
from mlagents.trainers.training_analytics_side_channel import (
TrainingAnalyticsSideChannel,
)
from mlagents_envs.side_channel.side_channel import SideChannel
# Module logger; the worker subprocess function below logs through it as well.
logger = logging_util.get_logger(__name__)
# Seconds SubprocessEnvManager.close() waits for every worker to report CLOSED
# before forcibly terminating the workers that have not.
WORKER_SHUTDOWN_TIMEOUT_S = 10
class EnvironmentCommand(enum.Enum):
    """
    Message kinds exchanged between SubprocessEnvManager and its worker processes.

    Requests travel manager -> worker over each worker's Pipe. Responses travel
    worker -> manager either over that Pipe or over the shared step queue.
    """

    # Request: apply actions and advance the env. Response goes to the step queue.
    STEP = 1
    # Request/response over the Pipe: the environment's behavior specs.
    BEHAVIOR_SPECS = 2
    # Request only: apply environment parameter settings. No response is sent.
    ENVIRONMENT_PARAMETERS = 3
    # Request/response over the Pipe: reset the env and return the initial steps.
    RESET = 4
    # Request only: leave the command loop and shut the worker down.
    CLOSE = 5
    # Response: the worker hit an exception; the payload is that exception.
    ENV_EXITED = 6
    # Response on the step queue: the worker has finished its cleanup.
    CLOSED = 7
    # Request only: forward a training-started event to the analytics channel.
    TRAINING_STARTED = 8
class EnvironmentRequest(NamedTuple):
    """Command sent from the manager to one worker over that worker's Pipe."""

    cmd: EnvironmentCommand
    # Command-specific data, e.g. the per-behavior ActionInfo dict for STEP;
    # omitted (None) for commands such as CLOSE.
    payload: Any = None
class EnvironmentResponse(NamedTuple):
    """Message sent from a worker to the manager, over its Pipe or the shared step queue."""

    cmd: EnvironmentCommand
    # Identifies the sender, which the manager needs for messages on the shared step queue.
    worker_id: int
    # Command-specific data: a StepResponse for STEP, the exception for ENV_EXITED, None for CLOSED.
    payload: Any
class StepResponse(NamedTuple):
    """Payload of a STEP EnvironmentResponse."""

    # Current steps of every behavior in the environment, keyed by behavior name.
    all_step_result: AllStepResult
    # The worker's timer tree; the manager merges it into its own "workers" timer.
    timer_root: Optional[TimerNode]
    # Stats drained from the worker's StatsSideChannel for this step.
    environment_stats: EnvironmentStats
class UnityEnvWorker:
    """
    Parent-side handle to one environment worker subprocess.

    Holds the subprocess, the parent end of its command Pipe, and the per-worker
    bookkeeping SubprocessEnvManager keeps between steps.
    """

    def __init__(self, process: Process, worker_id: int, conn: Connection):
        self.process = process
        self.worker_id = worker_id
        self.conn = conn
        # Most recent step produced by this worker; starts as an empty step.
        self.previous_step: EnvironmentStep = EnvironmentStep.empty(worker_id)
        # Actions sent with the most recent STEP command, keyed by behavior name.
        self.previous_all_action_info: Dict[str, ActionInfo] = {}
        # True while a STEP has been sent and its response has not been consumed yet.
        self.waiting = False
        # Set by the manager once this worker has reported CLOSED.
        self.closed = False

    def send(self, cmd: EnvironmentCommand, payload: Any = None) -> None:
        """
        Send a command to the worker process.

        :param cmd: The command to send.
        :param payload: Command-specific data.
        :raises UnityCommunicationException: If the pipe to the worker is broken.
        """
        try:
            req = EnvironmentRequest(cmd, payload)
            self.conn.send(req)
        except (BrokenPipeError, EOFError) as ex:
            # Chain the pipe error so the root cause stays visible in tracebacks.
            raise UnityCommunicationException(
                "UnityEnvironment worker: send failed."
            ) from ex

    def recv(self) -> EnvironmentResponse:
        """
        Block until the worker sends a response over the pipe, and return it.

        :raises Exception: The exception the worker reported, if it responded with ENV_EXITED.
        :raises UnityCommunicationException: If the pipe to the worker is broken.
        """
        try:
            response: EnvironmentResponse = self.conn.recv()
            if response.cmd == EnvironmentCommand.ENV_EXITED:
                # Re-raise the worker's exception in this process. This stays inside the
                # try so a reported BrokenPipeError/EOFError is still surfaced as a
                # UnityCommunicationException, as before.
                env_exception: Exception = response.payload
                raise env_exception
            return response
        except (BrokenPipeError, EOFError) as ex:
            raise UnityCommunicationException(
                "UnityEnvironment worker: recv failed."
            ) from ex

    def request_close(self):
        """Ask the worker to shut down; a broken pipe (worker already gone) is only logged."""
        try:
            self.conn.send(EnvironmentRequest(EnvironmentCommand.CLOSE))
        except (BrokenPipeError, EOFError):
            logger.debug(
                f"UnityEnvWorker {self.worker_id} got exception trying to close."
            )
def worker(
    parent_conn: Connection,
    step_queue: Queue,
    pickled_env_factory: bytes,
    worker_id: int,
    run_options: RunOptions,
    log_level: int = logging_util.INFO,
) -> None:
    """
    Entry point of an environment worker subprocess.

    Builds the side channels, creates the environment from the pickled factory, then
    serves EnvironmentRequests received on ``parent_conn`` until CLOSE is requested
    or an exception occurs.

    STEP results are put on ``step_queue``. RESET and BEHAVIOR_SPECS results are sent
    back over ``parent_conn``. On an exception, ENV_EXITED carrying the exception is
    put on ``step_queue`` and sent over ``parent_conn``. In every case, a final CLOSED
    message is put on ``step_queue`` once cleanup is done.

    :param parent_conn: This worker's end of the command Pipe.
    :param step_queue: Queue shared with the manager for STEP/ENV_EXITED/CLOSED messages.
    :param pickled_env_factory: cloudpickle-serialized factory called as
        ``env_factory(worker_id, side_channels)``.
    :param worker_id: Index of this worker. Only worker 0 creates an analytics channel.
    :param run_options: Run options. Engine settings are forwarded to the environment.
    :param log_level: Log level to apply inside this process.
    """
    env_factory: Callable[
        [int, List[SideChannel]], UnityEnvironment
    ] = cloudpickle.loads(pickled_env_factory)
    env_parameters = EnvironmentParametersChannel()

    engine_config = EngineConfig(
        width=run_options.engine_settings.width,
        height=run_options.engine_settings.height,
        quality_level=run_options.engine_settings.quality_level,
        time_scale=run_options.engine_settings.time_scale,
        target_frame_rate=run_options.engine_settings.target_frame_rate,
        capture_frame_rate=run_options.engine_settings.capture_frame_rate,
    )
    engine_configuration_channel = EngineConfigurationChannel()
    engine_configuration_channel.set_configuration(engine_config)

    stats_channel = StatsSideChannel()
    training_analytics_channel: Optional[TrainingAnalyticsSideChannel] = None
    if worker_id == 0:
        # Presumably only one worker reports analytics so events are sent once per run,
        # not once per environment -- TODO confirm.
        training_analytics_channel = TrainingAnalyticsSideChannel()
    env: Optional[UnityEnvironment] = None
    # Set log level. On some platforms, the logger isn't common with the
    # main process, so we need to set it again.
    logging_util.set_log_level(log_level)

    def _send_response(cmd_name: EnvironmentCommand, payload: Any) -> None:
        """Reply to the manager over this worker's Pipe."""
        parent_conn.send(EnvironmentResponse(cmd_name, worker_id, payload))

    def _generate_all_results() -> AllStepResult:
        """Collect the current steps of every behavior in the environment."""
        assert env is not None  # Only called after the environment has been created.
        all_step_result: AllStepResult = {}
        for brain_name in env.behavior_specs:
            all_step_result[brain_name] = env.get_steps(brain_name)
        return all_step_result

    def _report_exit(ex: BaseException) -> None:
        """Tell the manager, on both channels, that this worker is exiting because of ``ex``."""
        step_queue.put(
            EnvironmentResponse(EnvironmentCommand.ENV_EXITED, worker_id, ex)
        )
        _send_response(EnvironmentCommand.ENV_EXITED, ex)

    try:
        side_channels = [env_parameters, engine_configuration_channel, stats_channel]
        if training_analytics_channel is not None:
            side_channels.append(training_analytics_channel)

        env = env_factory(worker_id, side_channels)
        if (
            not env.academy_capabilities
            or not env.academy_capabilities.trainingAnalytics
        ):
            # Make sure we don't try to send training analytics if the environment doesn't know how to process
            # them. This wouldn't be catastrophic, but would result in unknown SideChannel UUIDs being used.
            training_analytics_channel = None
        if training_analytics_channel:
            training_analytics_channel.environment_initialized(run_options)

        while True:
            req: EnvironmentRequest = parent_conn.recv()
            if req.cmd == EnvironmentCommand.STEP:
                all_action_info = req.payload
                for brain_name, action_info in all_action_info.items():
                    if len(action_info.agent_ids) > 0:
                        env.set_actions(brain_name, action_info.env_action)
                env.step()
                all_step_result = _generate_all_results()
                # The timers in this process are independent from all the processes and the "main" process
                # So after we send back the root timer, we can safely clear them.
                # Note that we could randomly return timers a fraction of the time if we wanted to reduce
                # the data transferred.
                # TODO get gauges from the workers and merge them in the main process too.
                env_stats = stats_channel.get_and_reset_stats()
                step_response = StepResponse(
                    all_step_result, get_timer_root(), env_stats
                )
                step_queue.put(
                    EnvironmentResponse(
                        EnvironmentCommand.STEP, worker_id, step_response
                    )
                )
                reset_timers()
            elif req.cmd == EnvironmentCommand.BEHAVIOR_SPECS:
                _send_response(EnvironmentCommand.BEHAVIOR_SPECS, env.behavior_specs)
            elif req.cmd == EnvironmentCommand.ENVIRONMENT_PARAMETERS:
                # The manager sends None when set_env_parameters()/reset() are called
                # with their default config; treat that as "no parameters" rather than
                # crashing the worker on None.items().
                for k, v in (req.payload or {}).items():
                    # Values that are not randomization settings are ignored.
                    if isinstance(v, ParameterRandomizationSettings):
                        v.apply(k, env_parameters)
            elif req.cmd == EnvironmentCommand.TRAINING_STARTED:
                behavior_name, trainer_config = req.payload
                if training_analytics_channel:
                    training_analytics_channel.training_started(
                        behavior_name, trainer_config
                    )
            elif req.cmd == EnvironmentCommand.RESET:
                env.reset()
                all_step_result = _generate_all_results()
                _send_response(EnvironmentCommand.RESET, all_step_result)
            elif req.cmd == EnvironmentCommand.CLOSE:
                break
            # Any other command is ignored.
    except (
        KeyboardInterrupt,
        UnityCommunicationException,
        UnityTimeOutException,
        UnityEnvironmentException,
        UnityCommunicatorStoppedException,
    ) as ex:
        logger.debug(f"UnityEnvironment worker {worker_id}: environment stopping.")
        _report_exit(ex)
    except Exception as ex:
        logger.exception(
            f"UnityEnvironment worker {worker_id}: environment raised an unexpected exception."
        )
        _report_exit(ex)
    finally:
        logger.debug(f"UnityEnvironment worker {worker_id} closing.")
        if env is not None:
            try:
                env.close()
            except Exception:
                # The manager waits for the CLOSED message below (in close() and while
                # draining after a crash), so a failing env.close() must not skip it.
                logger.exception(
                    f"UnityEnvironment worker {worker_id}: error while closing the environment."
                )
        logger.debug(f"UnityEnvironment worker {worker_id} done.")
        parent_conn.close()
        step_queue.put(EnvironmentResponse(EnvironmentCommand.CLOSED, worker_id, None))
        step_queue.close()
class SubprocessEnvManager(EnvManager):
    """
    EnvManager that runs every environment in its own ``worker`` subprocess.

    Commands are sent to each worker over a dedicated Pipe. STEP results, ENV_EXITED
    notifications and CLOSED notices come back on one shared ``step_queue``, so steps
    can be consumed from whichever worker finishes first. A worker that exits with a
    recoverable Unity exception is restarted, subject to the limits in
    ``run_options.env_settings``.
    """

    def __init__(
        self,
        env_factory: Callable[[int, List[SideChannel]], BaseEnv],
        run_options: RunOptions,
        n_env: int = 1,
    ):
        """
        Start ``n_env`` worker subprocesses.

        :param env_factory: Called in each subprocess as ``env_factory(worker_id, side_channels)``.
        :param run_options: Run options forwarded to every worker.
        :param n_env: Number of environment workers to start.
        """
        super().__init__()
        self.env_workers: List[UnityEnvWorker] = []
        self.step_queue: Queue = Queue()
        # Number of workers that have not reported CLOSED yet; decremented in close().
        self.workers_alive = 0
        self.env_factory = env_factory
        self.run_options = run_options
        # Last value passed to set_env_parameters(); passed back to reset() after restarts.
        self.env_parameters: Optional[Dict] = None
        # Each worker is correlated with a list of times they restarted within the last time period.
        self.recent_restart_timestamps: List[List[datetime.datetime]] = [
            [] for _ in range(n_env)
        ]
        # Total number of restarts of each worker since this manager was created.
        self.restart_counts: List[int] = [0] * n_env
        for worker_idx in range(n_env):
            self.env_workers.append(
                self.create_worker(
                    worker_idx, self.step_queue, env_factory, run_options
                )
            )
            self.workers_alive += 1

    @staticmethod
    def create_worker(
        worker_id: int,
        step_queue: Queue,
        env_factory: Callable[[int, List[SideChannel]], BaseEnv],
        run_options: RunOptions,
    ) -> UnityEnvWorker:
        """Start a subprocess running ``worker`` and return the parent-side handle to it."""
        parent_conn, child_conn = Pipe()
        # Need to use cloudpickle for the env factory function since function objects aren't picklable
        # on Windows as of Python 3.6.
        pickled_env_factory = cloudpickle.dumps(env_factory)
        child_process = Process(
            target=worker,
            args=(
                child_conn,
                step_queue,
                pickled_env_factory,
                worker_id,
                run_options,
                logger.level,
            ),
        )
        child_process.start()
        return UnityEnvWorker(child_process, worker_id, parent_conn)

    def _queue_steps(self) -> None:
        """Send a STEP command to every worker that is not already processing one."""
        for env_worker in self.env_workers:
            if not env_worker.waiting:
                env_action_info = self._take_step(env_worker.previous_step)
                # Kept so _postprocess_steps can attach these actions to the resulting step.
                env_worker.previous_all_action_info = env_action_info
                env_worker.send(EnvironmentCommand.STEP, env_action_info)
                env_worker.waiting = True

    def _restart_failed_workers(self, first_failure: EnvironmentResponse) -> None:
        """
        Restart every failed worker, then reset all environments.

        :param first_failure: The ENV_EXITED response that triggered the restart. Any
            other kind of response makes this a no-op.
        :raises Exception: A failed worker's exception, if it is not recoverable or the
            worker has no restart quota left.
        """
        if first_failure.cmd != EnvironmentCommand.ENV_EXITED:
            return
        # Drain the step queue to make sure all workers are paused and we have found all concurrent errors.
        # Pausing all training is needed since we need to reset all pending training steps as they could be corrupted.
        other_failures: Dict[int, Exception] = self._drain_step_queue()
        # TODO: Once we use python 3.9 switch to using the | operator to combine dicts.
        failures: Dict[int, Exception] = {
            first_failure.worker_id: first_failure.payload,
            **other_failures,
        }
        for worker_id, ex in failures.items():
            self._assert_worker_can_restart(worker_id, ex)
            logger.warning(f"Restarting worker[{worker_id}] after '{ex}'")
            self.recent_restart_timestamps[worker_id].append(datetime.datetime.now())
            self.restart_counts[worker_id] += 1
            self.env_workers[worker_id] = self.create_worker(
                worker_id, self.step_queue, self.env_factory, self.run_options
            )
        # The restarts were successful, clear all the existing training trajectories so we don't use corrupted or
        # outdated data.
        self.reset(self.env_parameters)

    def _drain_step_queue(self) -> Dict[int, Exception]:
        """
        Drains all steps out of the step queue and returns all exceptions from crashed workers.
        This will effectively pause all workers so that they won't do anything until _queue_steps is called.

        :return: Mapping of worker_id to the exception each crashed worker reported.
        :raises TimeoutError: If some workers have still not responded after one minute.
        """
        all_failures: Dict[int, Exception] = {}
        workers_still_pending = {w.worker_id for w in self.env_workers if w.waiting}
        deadline = datetime.datetime.now() + datetime.timedelta(minutes=1)
        while workers_still_pending and deadline > datetime.datetime.now():
            try:
                while True:
                    step: EnvironmentResponse = self.step_queue.get_nowait()
                    if step.cmd == EnvironmentCommand.ENV_EXITED:
                        # A crashed worker still puts CLOSED on the queue during its
                        # cleanup; keep waiting for that message.
                        workers_still_pending.add(step.worker_id)
                        all_failures[step.worker_id] = step.payload
                    else:
                        # discard() rather than remove(): a message from a worker we were
                        # not waiting on must not crash the drain with a KeyError.
                        workers_still_pending.discard(step.worker_id)
                        self.env_workers[step.worker_id].waiting = False
            except EmptyQueueException:
                pass
        # Only fail if workers are actually still outstanding. The deadline may have
        # passed while the last responses were being drained.
        if workers_still_pending:
            raise TimeoutError(f"Workers {workers_still_pending} stuck in waiting state")
        return all_failures

    def _assert_worker_can_restart(self, worker_id: int, exception: Exception) -> None:
        """
        Checks if we can recover from an exception from a worker.

        Returns normally only if ``exception`` is one of the recoverable Unity
        exceptions and the worker still has restart quota. Otherwise ``exception``
        itself is re-raised; when the quota is exhausted, an error is logged first.
        """
        recoverable = (
            UnityCommunicationException,
            UnityTimeOutException,
            UnityEnvironmentException,
            UnityCommunicatorStoppedException,
        )
        if isinstance(exception, recoverable):
            if self._worker_has_restart_quota(worker_id):
                return
            logger.error(
                f"Worker {worker_id} exceeded the allowed number of restarts."
            )
        raise exception

    def _worker_has_restart_quota(self, worker_id: int) -> bool:
        """
        Return True if the worker is within both the lifetime and rate-limited restart
        limits. A limit of -1 means unlimited.
        """
        self._drop_old_restart_timestamps(worker_id)
        max_lifetime_restarts = self.run_options.env_settings.max_lifetime_restarts
        max_limit_check = (
            max_lifetime_restarts == -1
            or self.restart_counts[worker_id] < max_lifetime_restarts
        )

        rate_limit_n = self.run_options.env_settings.restarts_rate_limit_n
        rate_limit_check = (
            rate_limit_n == -1
            or len(self.recent_restart_timestamps[worker_id]) < rate_limit_n
        )

        return rate_limit_check and max_limit_check

    def _drop_old_restart_timestamps(self, worker_id: int) -> None:
        """
        Drops environment restart timestamps that are outside of the current window.
        """
        # Compute the window start once instead of once per timestamp.
        cutoff = datetime.datetime.now() - datetime.timedelta(
            seconds=self.run_options.env_settings.restarts_rate_limit_period_s
        )
        self.recent_restart_timestamps[worker_id] = [
            t for t in self.recent_restart_timestamps[worker_id] if t > cutoff
        ]

    def _step(self) -> List[EnvironmentStep]:
        """
        Advance the environments and return the steps that have completed.

        Sends STEP to every idle worker, then polls the step queue until at least one
        response has been collected. If any worker reports ENV_EXITED, the failed
        workers are restarted (which resets all environments) and collection starts over.
        """
        # Queue steps for any workers which aren't in the "waiting" state.
        self._queue_steps()

        worker_steps: List[EnvironmentResponse] = []
        step_workers: Set[int] = set()
        # Poll the step queue for completed steps from environment workers until we retrieve
        # 1 or more, which we will then return as StepInfos
        while len(worker_steps) < 1:
            try:
                while True:
                    step: EnvironmentResponse = self.step_queue.get_nowait()
                    if step.cmd == EnvironmentCommand.ENV_EXITED:
                        # If even one env exits try to restart all envs that failed.
                        self._restart_failed_workers(step)
                        # Clear state and restart this function.
                        worker_steps.clear()
                        step_workers.clear()
                        self._queue_steps()
                    elif step.worker_id not in step_workers:
                        self.env_workers[step.worker_id].waiting = False
                        worker_steps.append(step)
                        step_workers.add(step.worker_id)
            except EmptyQueueException:
                pass

        step_infos = self._postprocess_steps(worker_steps)
        return step_infos

    def _reset_env(self, config: Optional[Dict] = None) -> List[EnvironmentStep]:
        """
        Reset every environment and return the initial step of each worker.

        :param config: Environment parameters, sent to the workers before the reset.
        """
        # Consume any in-flight STEP responses so no stale step is used after the reset.
        while any(ew.waiting for ew in self.env_workers):
            # EAFP: Queue.empty() is documented as unreliable for multiprocessing queues.
            try:
                step = self.step_queue.get_nowait()
            except EmptyQueueException:
                continue
            self.env_workers[step.worker_id].waiting = False
        # Send config to environment
        self.set_env_parameters(config)
        # First enqueue reset commands for all workers so that they reset in parallel
        for ew in self.env_workers:
            ew.send(EnvironmentCommand.RESET, config)
        # Next (synchronously) collect the reset observations from each worker in sequence
        for ew in self.env_workers:
            ew.previous_step = EnvironmentStep(ew.recv().payload, ew.worker_id, {}, {})
        return [ew.previous_step for ew in self.env_workers]

    def set_env_parameters(self, config: Optional[Dict] = None) -> None:
        """
        Sends environment parameter settings to C# via the
        EnvironmentParametersChannel for each worker.
        :param config: Dict of environment parameter keys and values
        """
        self.env_parameters = config
        for ew in self.env_workers:
            ew.send(EnvironmentCommand.ENVIRONMENT_PARAMETERS, config)

    def on_training_started(
        self, behavior_name: str, trainer_settings: TrainerSettings
    ) -> None:
        """
        Handle training starting for a new behavior type by forwarding it to every
        worker. A worker passes it on to its training analytics channel, if it has one.
        :param behavior_name: Name of the behavior whose training started.
        :param trainer_settings: Trainer settings used for that behavior.
        """
        for ew in self.env_workers:
            ew.send(
                EnvironmentCommand.TRAINING_STARTED, (behavior_name, trainer_settings)
            )

    @property
    def training_behaviors(self) -> Dict[BehaviorName, BehaviorSpec]:
        """Behavior specs from all workers, queried synchronously and merged into one dict."""
        result: Dict[BehaviorName, BehaviorSpec] = {}
        for worker in self.env_workers:
            worker.send(EnvironmentCommand.BEHAVIOR_SPECS)
            result.update(worker.recv().payload)
        return result

    def close(self) -> None:
        """
        Ask every worker to close and wait up to WORKER_SHUTDOWN_TIMEOUT_S seconds for
        each to report CLOSED. Workers that have not reported by then are terminated.
        """
        logger.debug("SubprocessEnvManager closing.")
        for env_worker in self.env_workers:
            env_worker.request_close()
        # Pull messages out of the queue until every worker has CLOSED or we time out.
        deadline = time.time() + WORKER_SHUTDOWN_TIMEOUT_S
        while self.workers_alive > 0 and time.time() < deadline:
            try:
                step: EnvironmentResponse = self.step_queue.get_nowait()
                env_worker = self.env_workers[step.worker_id]
                if step.cmd == EnvironmentCommand.CLOSED and not env_worker.closed:
                    env_worker.closed = True
                    self.workers_alive -= 1
                # Discard all other messages.
            except EmptyQueueException:
                pass
        self.step_queue.close()
        # Sanity check to kill zombie workers and report an issue if they occur.
        if self.workers_alive > 0:
            logger.error("SubprocessEnvManager had workers that didn't signal shutdown")
            for env_worker in self.env_workers:
                if not env_worker.closed and env_worker.process.is_alive():
                    env_worker.process.terminate()
                    logger.error(
                        "A SubprocessEnvManager worker did not shut down correctly so it was forcefully terminated."
                    )
        self.step_queue.join_thread()

    def _postprocess_steps(
        self, env_steps: List[EnvironmentResponse]
    ) -> List[EnvironmentStep]:
        """
        Convert STEP responses into EnvironmentSteps, record each as its worker's
        previous step, and merge the workers' timer trees into the "workers" timer.
        """
        step_infos = []
        timer_nodes = []
        for step in env_steps:
            payload: StepResponse = step.payload
            env_worker = self.env_workers[step.worker_id]
            new_step = EnvironmentStep(
                payload.all_step_result,
                step.worker_id,
                env_worker.previous_all_action_info,
                payload.environment_stats,
            )
            step_infos.append(new_step)
            env_worker.previous_step = new_step

            if payload.timer_root:
                timer_nodes.append(payload.timer_root)

        if timer_nodes:
            with hierarchical_timer("workers") as main_timer_node:
                for worker_timer_node in timer_nodes:
                    main_timer_node.merge(
                        worker_timer_node, root_name="worker_root", is_parallel=True
                    )

        return step_infos

    @timed
    def _take_step(self, last_step: EnvironmentStep) -> Dict[BehaviorName, ActionInfo]:
        """
        Query the registered policies for actions for the behaviors in ``last_step``.
        Behaviors without a policy are skipped.
        """
        all_action_info: Dict[str, ActionInfo] = {}
        for brain_name, step_tuple in last_step.current_all_step_result.items():
            if brain_name in self.policies:
                all_action_info[brain_name] = self.policies[brain_name].get_action(
                    step_tuple[0], last_step.worker_id
                )
        return all_action_info