import os
from functools import partial
from typing import Optional, Union, List
import numpy as np
import torch
from ding.bonus.common import TrainingReturn, EvalReturn
from ding.config import save_config_py, compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import set_pkg_seed, get_rank
from ding.worker import BaseLearner
from ditk import logging
from easydict import EasyDict
from tensorboardX import SummaryWriter
from lzero.agent.config.muzero import supported_env_cfg
from lzero.entry.utils import log_buffer_memory_usage, random_collect
from lzero.mcts import MuZeroGameBuffer
from lzero.policy import visit_count_temperature
from lzero.policy.muzero import MuZeroPolicy
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroCollector as Collector
from lzero.worker import MuZeroEvaluator as Evaluator
class MuZeroAgent:
"""
Overview:
        Agent class for the MuZero algorithm, which includes methods for training, deployment, and batch evaluation.
Interfaces:
__init__, train, deploy, batch_evaluate
Properties:
best
.. note::
This agent class is tailored for use with the HuggingFace Model Zoo for LightZero
(e.g. https://huggingface.co/OpenDILabCommunity/CartPole-v0-MuZero),
and provides methods such as "train" and "deploy".
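    Examples:
        An illustrative usage sketch, assuming ``CartPole-v0`` is one of the entries in ``supported_env_list``:

        >>> agent = MuZeroAgent(env_id='CartPole-v0', exp_name='CartPole-v0-MuZero')
        >>> agent.train(step=int(1e5))
        >>> agent.deploy(enable_save_replay=True)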
"""
supported_env_list = list(supported_env_cfg.keys())
def __init__(
            self,
            env_id: Optional[str] = None,
            seed: int = 0,
            exp_name: Optional[str] = None,
            model: Optional[torch.nn.Module] = None,
            cfg: Optional[Union[EasyDict, dict]] = None,
            policy_state_dict: Optional[str] = None,
) -> None:
"""
Overview:
Initialize the MuZeroAgent instance with environment parameters, model, and configuration.
Arguments:
- env_id (:obj:`str`): Identifier for the environment to be used, registered in gym.
- seed (:obj:`int`): Random seed for reproducibility. Defaults to 0.
- exp_name (:obj:`Optional[str]`): Name for the experiment. Defaults to None.
- model (:obj:`Optional[torch.nn.Module]`): PyTorch module to be used as the model. If None, a default model is created. Defaults to None.
- cfg (:obj:`Optional[Union[EasyDict, dict]]`): Configuration for the agent. If None, default configuration will be used. Defaults to None.
- policy_state_dict (:obj:`Optional[str]`): Path to a pre-trained model state dictionary. If provided, state dict will be loaded. Defaults to None.
.. note::
- If `env_id` is not specified, it must be included in `cfg`.
- The `supported_env_list` contains all the environment IDs that are supported by this agent.
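        Examples:
            An illustrative sketch; the checkpoint path below is a hypothetical placeholder:

            >>> agent = MuZeroAgent(env_id='CartPole-v0', seed=0)
            >>> # or resume from a pre-trained checkpoint (hypothetical path):
            >>> agent = MuZeroAgent(env_id='CartPole-v0', policy_state_dict='./exp/ckpt/ckpt_best.pth.tar')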
"""
        # Either an env_id or a full cfg must be provided; detailed env_id checks are performed below.
        assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
if cfg is not None and not isinstance(cfg, EasyDict):
cfg = EasyDict(cfg)
if env_id is not None:
assert env_id in MuZeroAgent.supported_env_list, "Please use supported envs: {}".format(
MuZeroAgent.supported_env_list
)
if cfg is None:
cfg = supported_env_cfg[env_id]
else:
assert cfg.main_config.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
else:
assert hasattr(cfg.main_config.env, "env_id"), "Please specify env_id in cfg."
assert cfg.main_config.env.env_id in MuZeroAgent.supported_env_list, "Please use supported envs: {}".format(
MuZeroAgent.supported_env_list
)
default_policy_config = EasyDict({"policy": MuZeroPolicy.default_config()})
default_policy_config.policy.update(cfg.main_config.policy)
cfg.main_config.policy = default_policy_config.policy
if exp_name is not None:
cfg.main_config.exp_name = exp_name
self.origin_cfg = cfg
self.cfg = compile_config(
cfg.main_config, seed=seed, env=None, auto=True, policy=MuZeroPolicy, create_cfg=cfg.create_config
)
self.exp_name = self.cfg.exp_name
logging.getLogger().setLevel(logging.INFO)
self.seed = seed
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
if not os.path.exists(self.exp_name):
os.makedirs(self.exp_name)
save_config_py(cfg, os.path.join(self.exp_name, 'policy_config.py'))
if model is None:
if self.cfg.policy.model.model_type == 'mlp':
from lzero.model.muzero_model_mlp import MuZeroModelMLP
model = MuZeroModelMLP(**self.cfg.policy.model)
elif self.cfg.policy.model.model_type == 'conv':
from lzero.model.muzero_model import MuZeroModel
model = MuZeroModel(**self.cfg.policy.model)
else:
raise NotImplementedError
if self.cfg.policy.cuda and torch.cuda.is_available():
self.cfg.policy.device = 'cuda'
else:
self.cfg.policy.device = 'cpu'
self.policy = create_policy(self.cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
        if policy_state_dict is not None:
            # ``policy_state_dict`` is documented as a checkpoint path, so load it from disk when a path is given.
            if isinstance(policy_state_dict, str):
                policy_state_dict = torch.load(policy_state_dict, map_location='cpu')
            self.policy.learn_mode.load_state_dict(policy_state_dict)
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
self.env_fn, self.collector_env_cfg, self.evaluator_env_cfg = get_vec_env_setting(self.cfg.env)
def train(
self,
step: int = int(1e7),
) -> TrainingReturn:
"""
Overview:
Train the agent through interactions with the environment.
Arguments:
- step (:obj:`int`): Total number of environment steps to train for. Defaults to 10 million (1e7).
Returns:
- A `TrainingReturn` object containing training information, such as logs and potentially a URL to a training dashboard.
.. note::
The method involves interacting with the environment, collecting experience, and optimizing the model.
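        Examples:
            An illustrative sketch; the step budget below is an arbitrary small value chosen for demonstration:

            >>> agent = MuZeroAgent(env_id='CartPole-v0')
            >>> agent.train(step=int(2e5))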
"""
collector_env = create_env_manager(
self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.collector_env_cfg]
)
evaluator_env = create_env_manager(
self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg]
)
collector_env.seed(self.cfg.seed)
evaluator_env.seed(self.cfg.seed, dynamic_seed=False)
set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda)
# Create worker components: learner, collector, evaluator, replay buffer, commander.
        tb_logger = SummaryWriter(
            os.path.join('./{}/log/'.format(self.cfg.exp_name), 'serial')
        ) if get_rank() == 0 else None
learner = BaseLearner(
self.cfg.policy.learn.learner, self.policy.learn_mode, tb_logger, exp_name=self.cfg.exp_name
)
# ==============================================================
# MCTS+RL algorithms related core code
# ==============================================================
policy_config = self.cfg.policy
batch_size = policy_config.batch_size
# specific game buffer for MCTS+RL algorithms
replay_buffer = MuZeroGameBuffer(policy_config)
collector = Collector(
env=collector_env,
policy=self.policy.collect_mode,
tb_logger=tb_logger,
exp_name=self.cfg.exp_name,
policy_config=policy_config
)
evaluator = Evaluator(
eval_freq=self.cfg.policy.eval_freq,
n_evaluator_episode=self.cfg.env.n_evaluator_episode,
stop_value=self.cfg.env.stop_value,
env=evaluator_env,
policy=self.policy.eval_mode,
tb_logger=tb_logger,
exp_name=self.cfg.exp_name,
policy_config=policy_config
)
# ==============================================================
# Main loop
# ==============================================================
# Learner's before_run hook.
learner.call_hook('before_run')
if self.cfg.policy.update_per_collect is not None:
update_per_collect = self.cfg.policy.update_per_collect
# The purpose of collecting random data before training:
# Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely.
# Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms.
if self.cfg.policy.random_collect_episode_num > 0:
random_collect(self.cfg.policy, self.policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
while True:
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
collect_kwargs = {}
# set temperature for visit count distributions according to the train_iter,
# please refer to Appendix D in MuZero paper for details.
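            # NOTE: when ``manual_temperature_decay`` is enabled, the temperature is typically stepped down
            # from 1.0 towards a smaller final value as ``train_iter`` approaches
            # ``threshold_training_steps_for_final_temperature``; otherwise ``fixed_temperature_value`` is used.
            # See ``visit_count_temperature`` for the exact schedule.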
collect_kwargs['temperature'] = visit_count_temperature(
policy_config.manual_temperature_decay,
policy_config.fixed_temperature_value,
policy_config.threshold_training_steps_for_final_temperature,
trained_steps=learner.train_iter
)
if policy_config.eps.eps_greedy_exploration_in_collect:
epsilon_greedy_fn = get_epsilon_greedy_fn(
start=policy_config.eps.start,
end=policy_config.eps.end,
decay=policy_config.eps.decay,
type_=policy_config.eps.type
)
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
else:
collect_kwargs['epsilon'] = 0.0
# Evaluate policy performance.
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if self.cfg.policy.update_per_collect is None:
                # If update_per_collect is None, set it to the number of collected transitions
                # multiplied by ``model_update_ratio``.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
            # Save the new game segments returned by the collector into the replay buffer.
            replay_buffer.push_game_segments(new_data)
            # Remove the oldest data if the replay buffer is full.
            replay_buffer.remove_oldest_data_to_fit()
# Learn policy from collected data.
for i in range(update_per_collect):
# Learner will train ``update_per_collect`` times in one iteration.
if replay_buffer.get_num_of_transitions() > batch_size:
train_data = replay_buffer.sample(batch_size, self.policy)
else:
                    logging.warning(
                        f'The data in replay_buffer is not sufficient to sample a mini-batch: '
                        f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. '
                        f'Continuing to collect data ...'
                    )
break
# The core train steps for MCTS+RL algorithms.
log_vars = learner.train(train_data, collector.envstep)
if self.cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
if collector.envstep >= step:
break
# Learner's after_run hook.
learner.call_hook('after_run')
return TrainingReturn(wandb_url=None)
def deploy(
            self,
            enable_save_replay: bool = False,
            concatenate_all_replay: bool = False,
            replay_save_path: Optional[str] = None,
            seed: Optional[Union[int, List[int]]] = None,
            debug: bool = False
) -> EvalReturn:
"""
Overview:
            Deploy the agent for evaluation in the environment, with optional replay saving. The mean and
            standard deviation of the episode return are reported.
            If `enable_save_replay` is True, replay videos are saved in the specified `replay_save_path`.
Arguments:
- enable_save_replay (:obj:`bool`): Flag to enable saving of replay footage. Defaults to False.
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one file. Defaults to False.
- replay_save_path (:obj:`Optional[str]`): Directory path to save replay videos. Defaults to None, which sets a default path.
- seed (:obj:`Optional[Union[int, List[int]]]`): Seed or list of seeds for environment reproducibility. Defaults to None.
            - debug (:obj:`bool`): Whether to enable debug mode. Defaults to False.
Returns:
- An `EvalReturn` object containing evaluation metrics such as mean and standard deviation of returns.
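        Examples:
            An illustrative sketch; the replay directory and seed list below are assumed example values:

            >>> agent = MuZeroAgent(env_id='CartPole-v0')
            >>> result = agent.deploy(enable_save_replay=True, replay_save_path='./videos', seed=[0, 1, 2])
            >>> print(result.eval_value, result.eval_value_std)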
"""
        deploy_configs = [self.evaluator_env_cfg[0]]
        if isinstance(seed, int):
            seed_list = [seed]
        elif seed:
            seed_list = seed
        else:
            seed_list = [0]
reward_list = []
if enable_save_replay:
replay_save_path = replay_save_path if replay_save_path is not None else os.path.join(
self.exp_name, 'videos'
)
            deploy_configs[0]['replay_path'] = replay_save_path
for seed in seed_list:
            evaluator_env = create_env_manager(self.cfg.env.manager, [partial(self.env_fn, cfg=deploy_configs[0])])
evaluator_env.seed(seed if seed is not None else self.cfg.seed, dynamic_seed=False)
set_pkg_seed(seed if seed is not None else self.cfg.seed, use_cuda=self.cfg.policy.cuda)
# ==============================================================
# MCTS+RL algorithms related core code
# ==============================================================
policy_config = self.cfg.policy
evaluator = Evaluator(
eval_freq=self.cfg.policy.eval_freq,
n_evaluator_episode=1,
stop_value=self.cfg.env.stop_value,
env=evaluator_env,
policy=self.policy.eval_mode,
exp_name=self.cfg.exp_name,
policy_config=policy_config
)
# ==============================================================
# Main loop
# ==============================================================
stop, reward = evaluator.eval()
reward_list.extend(reward['eval_episode_return'])
if enable_save_replay:
files = os.listdir(replay_save_path)
files = [file for file in files if file.endswith('0.mp4')]
files.sort()
if concatenate_all_replay:
# create a file named 'files.txt' to store the names of all mp4 files
with open(os.path.join(replay_save_path, 'files.txt'), 'w') as f:
for file in files:
f.write("file '{}'\n".format(file))
# combine all the mp4 files into one mp4 file, rename it as 'deploy.mp4'
os.system(
'ffmpeg -f concat -safe 0 -i {} -c copy {}/deploy.mp4'.format(
os.path.join(replay_save_path, 'files.txt'), replay_save_path
)
)
return EvalReturn(eval_value=np.mean(reward_list), eval_value_std=np.std(reward_list))
def batch_evaluate(
            self,
            n_evaluator_episode: Optional[int] = None,
) -> EvalReturn:
"""
Overview:
Perform a batch evaluation of the agent over a specified number of episodes: ``n_evaluator_episode``.
Arguments:
- n_evaluator_episode (:obj:`Optional[int]`): Number of episodes to run the evaluation.
If None, uses default value from configuration. Defaults to None.
Returns:
- An `EvalReturn` object with evaluation results such as mean and standard deviation of returns.
.. note::
This method evaluates the agent's performance across multiple episodes to gauge its effectiveness.
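        Examples:
            An illustrative sketch; the episode count below is an arbitrary example value:

            >>> agent = MuZeroAgent(env_id='CartPole-v0')
            >>> result = agent.batch_evaluate(n_evaluator_episode=8)
            >>> print(result.eval_value, result.eval_value_std)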
"""
evaluator_env = create_env_manager(
self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg]
)
evaluator_env.seed(self.cfg.seed, dynamic_seed=False)
set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda)
# ==============================================================
# MCTS+RL algorithms related core code
# ==============================================================
policy_config = self.cfg.policy
evaluator = Evaluator(
eval_freq=self.cfg.policy.eval_freq,
n_evaluator_episode=self.cfg.env.n_evaluator_episode
if n_evaluator_episode is None else n_evaluator_episode,
stop_value=self.cfg.env.stop_value,
env=evaluator_env,
policy=self.policy.eval_mode,
exp_name=self.cfg.exp_name,
policy_config=policy_config
)
# ==============================================================
# Main loop
# ==============================================================
stop, reward = evaluator.eval()
return EvalReturn(
eval_value=np.mean(reward['eval_episode_return']), eval_value_std=np.std(reward['eval_episode_return'])
)
@property
def best(self):
"""
Overview:
Provides access to the best model according to evaluation metrics.
Returns:
- The agent with the best model loaded.
.. note::
The best model is saved in the path `./exp_name/ckpt/ckpt_best.pth.tar`.
When this property is accessed, the agent instance will load the best model state.
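        Examples:
            An illustrative sketch; it assumes a prior call to ``train`` has already produced ``ckpt_best.pth.tar``:

            >>> agent = MuZeroAgent(env_id='CartPole-v0')
            >>> agent.train(step=int(1e5))
            >>> agent.best.deploy(enable_save_replay=True)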
"""
best_model_file_path = os.path.join(self.checkpoint_save_dir, "ckpt_best.pth.tar")
# Load best model if it exists
if os.path.exists(best_model_file_path):
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
self.policy.learn_mode.load_state_dict(policy_state_dict)
return self