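"""
Serial entry points for model-based RL (MBRL) in DI-engine (``ding/entry/serial_entry_mbrl.py``):
a dyna-style pipeline (``serial_pipeline_dyna``), a dreamer-style pipeline (``serial_pipeline_dream``),
a DreamerV3 pipeline (``serial_pipeline_dreamer``), and the setup helpers they share.
"""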
from typing import Union, Optional, List, Any, Tuple
import torch
import os
from functools import partial
from tensorboardX import SummaryWriter
from copy import deepcopy
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
    get_buffer_cls, create_serial_collector
from ding.world_model import WorldModel
from ding.worker import IBuffer
from ding.envs import get_vec_env_setting, create_env_manager
from ding.config import read_config, compile_config
from ding.utils import set_pkg_seed, deep_merge_dicts
from ding.policy import create_policy
from ding.world_model import create_world_model
from ding.entry.utils import random_collect


def mbrl_entry_setup(
input_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
) -> Tuple:
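    """
    Overview:
        Build the components shared by the model-based RL pipelines in this module: compiled config, \
        policy, world model, environment replay buffer, learner, collector, evaluator, commander and \
        TensorBoard logger.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Either the path to the config file \
            (``str`` type) or the config itself (``Tuple[dict, dict]`` type, i.e. [user_config, create_cfg]).
        - seed (:obj:`int`): Random seed.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
    Returns:
        - components (:obj:`Tuple`): ``(cfg, policy, world_model, env_buffer, learner, collector, \
            collector_env, evaluator, commander, tb_logger)``, as unpacked by the pipeline entries below.
    """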
if isinstance(input_cfg, str):
cfg, create_cfg = read_config(input_cfg)
else:
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type = create_cfg.policy.type + '_command'
env_fn = None if env_setting is None else env_setting[0]
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
if env_setting is None:
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
else:
env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
# create logger
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
# create world model
world_model = create_world_model(cfg.world_model, env_fn(cfg.env), tb_logger)
# create policy
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
# create worker
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = create_serial_collector(
cfg.policy.collect.collector,
env=collector_env,
policy=policy.collect_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name
)
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
env_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
commander = BaseSerialCommander(
cfg.policy.other.commander, learner, collector, evaluator, env_buffer, policy.command_mode
)
return (
cfg,
policy,
world_model,
env_buffer,
learner,
collector,
collector_env,
evaluator,
commander,
tb_logger,
    )


def create_img_buffer(
cfg: dict, input_cfg: Union[str, Tuple[dict, dict]], world_model: WorldModel, tb_logger: 'SummaryWriter'
) -> IBuffer: # noqa
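    """
    Overview:
        Create the imagination buffer that stores world-model generated (imagined) transitions for \
        dyna-style training. For the ``elastic`` buffer type, the buffer size follows the world model's \
        ``buffer_size_scheduler``.
    Arguments:
        - cfg (:obj:`dict`): Compiled config.
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): The original input config (path or tuple), \
            used to read the ``imagination_buffer`` create config.
        - world_model (:obj:`WorldModel`): World model instance.
        - tb_logger (:obj:`SummaryWriter`): TensorBoard logger.
    Returns:
        - img_buffer (:obj:`IBuffer`): Imagination buffer.
    """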
if isinstance(input_cfg, str):
_, create_cfg = read_config(input_cfg)
else:
_, create_cfg = input_cfg
img_buffer_cfg = cfg.world_model.other.imagination_buffer
img_buffer_cfg.update(create_cfg.imagination_buffer)
buffer_cls = get_buffer_cls(img_buffer_cfg)
cfg.world_model.other.imagination_buffer.update(deep_merge_dicts(buffer_cls.default_config(), img_buffer_cfg))
if img_buffer_cfg.type == 'elastic':
img_buffer_cfg.set_buffer_size = world_model.buffer_size_scheduler
img_buffer = create_buffer(cfg.world_model.other.imagination_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
    return img_buffer


def serial_pipeline_dyna(
input_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
) -> 'Policy': # noqa
"""
Overview:
Serial pipeline entry for dyna-style model-based RL.
Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Either the path to the config file \
            (``str`` type) or the config itself (``Tuple[dict, dict]`` type, i.e. [user_config, create_cfg]).
- seed (:obj:`int`): Random seed.
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
``BaseEnv`` subclass, collector env config, and evaluator env config.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
Returns:
- policy (:obj:`Policy`): Converged policy.
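    Examples:
        Illustrative sketch only: ``dyna_config`` and ``dyna_create_config`` stand for a user-provided \
            (config, create_config) pair of a dyna-style algorithm such as MBPO.

        >>> policy = serial_pipeline_dyna((dyna_config, dyna_create_config), seed=0, max_env_step=int(1e5))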
"""
cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
mbrl_entry_setup(input_cfg, seed, env_setting, model)
img_buffer = create_img_buffer(cfg, input_cfg, world_model, tb_logger)
learner.call_hook('before_run')
if cfg.policy.get('random_collect_size', 0) > 0:
random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)
while True:
collect_kwargs = commander.step()
# eval the policy
if evaluator.should_eval(collector.envstep):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
# fill environment buffer
data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
env_buffer.push(data, cur_collector_envstep=collector.envstep)
        # evaluate & train the world model, then fill the imagination buffer
if world_model.should_eval(collector.envstep):
world_model.eval(env_buffer, collector.envstep, learner.train_iter)
if world_model.should_train(collector.envstep):
world_model.train(env_buffer, collector.envstep, learner.train_iter)
world_model.fill_img_buffer(
policy.collect_mode, env_buffer, img_buffer, collector.envstep, learner.train_iter
)
        # policy training: sample a mixed batch of real (env_buffer) and imagined (img_buffer)
        # transitions via the world model and update the policy
        for _ in range(cfg.policy.learn.update_per_collect):
            batch_size = learner.policy.get_attribute('batch_size')
            train_data = world_model.sample(env_buffer, img_buffer, batch_size, learner.train_iter)
            learner.train(train_data, collector.envstep)
if cfg.policy.on_policy:
# On-policy algorithm must clear the replay buffer.
env_buffer.clear()
img_buffer.clear()
if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break
learner.call_hook('after_run')
    return policy


def serial_pipeline_dream(
input_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
) -> 'Policy': # noqa
"""
Overview:
Serial pipeline entry for dreamer-style model-based RL.
Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Either the path to the config file \
            (``str`` type) or the config itself (``Tuple[dict, dict]`` type, i.e. [user_config, create_cfg]).
- seed (:obj:`int`): Random seed.
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
``BaseEnv`` subclass, collector env config, and evaluator env config.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
Returns:
- policy (:obj:`Policy`): Converged policy.
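    Examples:
        Illustrative sketch only: ``dream_config`` and ``dream_create_config`` stand for a user-provided \
            (config, create_config) pair of a dreamer-style algorithm.

        >>> policy = serial_pipeline_dream((dream_config, dream_create_config), seed=0, max_env_step=int(1e5))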
"""
cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
mbrl_entry_setup(input_cfg, seed, env_setting, model)
learner.call_hook('before_run')
if cfg.policy.get('random_collect_size', 0) > 0:
random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)
while True:
collect_kwargs = commander.step()
# eval the policy
if evaluator.should_eval(collector.envstep):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break
# fill environment buffer
data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
env_buffer.push(data, cur_collector_envstep=collector.envstep)
        # evaluate & train the world model
if world_model.should_eval(collector.envstep):
world_model.eval(env_buffer, collector.envstep, learner.train_iter)
if world_model.should_train(collector.envstep):
world_model.train(env_buffer, collector.envstep, learner.train_iter)
update_per_collect = cfg.policy.learn.update_per_collect // world_model.rollout_length_scheduler(
collector.envstep
)
update_per_collect = max(1, update_per_collect)
        for _ in range(update_per_collect):
            batch_size = learner.policy.get_attribute('batch_size')
            train_data = env_buffer.sample(batch_size, learner.train_iter)
            # dreamer-style: train the policy on purely on-policy imagined rollouts;
            # the current envstep determines the rollout length via the scheduler above
            learner.train(
                train_data, collector.envstep, policy_kwargs=dict(world_model=world_model, envstep=collector.envstep)
            )
if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break
learner.call_hook('after_run')
    return policy


def serial_pipeline_dreamer(
input_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
) -> 'Policy': # noqa
"""
Overview:
        Serial pipeline entry for DreamerV3.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Either the path to the config file \
            (``str`` type) or the config itself (``Tuple[dict, dict]`` type, i.e. [user_config, create_cfg]).
- seed (:obj:`int`): Random seed.
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
``BaseEnv`` subclass, collector env config, and evaluator env config.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
Returns:
- policy (:obj:`Policy`): Converged policy.
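    Examples:
        Illustrative sketch only: ``dreamer_config`` and ``dreamer_create_config`` stand for a \
            user-provided (config, create_config) pair for DreamerV3.

        >>> policy = serial_pipeline_dreamer((dreamer_config, dreamer_create_config), seed=0, max_env_step=int(1e6))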
"""
cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
mbrl_entry_setup(input_cfg, seed, env_setting, model)
learner.call_hook('before_run')
# prefill environment buffer
if cfg.policy.get('random_collect_size', 0) > 0:
        # each collected sample here spans ``unroll_len`` env steps, so convert
        # random_collect_size from env steps to the number of samples
        cfg.policy.random_collect_size = cfg.policy.random_collect_size // cfg.policy.collect.unroll_len
random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)
while True:
collect_kwargs = commander.step()
# eval the policy
if evaluator.should_eval(collector.envstep):
stop, reward = evaluator.eval(
learner.save_checkpoint,
learner.train_iter,
collector.envstep,
policy_kwargs=dict(world_model=world_model)
)
if stop:
break
        # train the world model and, in the same loop, the policy on imagined rollouts
        # (``cfg.world_model.pretrain`` steps when pretraining, otherwise 0 or 1 steps per collection)
        steps = (
            cfg.world_model.pretrain
            if world_model.should_pretrain() else int(world_model.should_train(collector.envstep))
        )
for _ in range(steps):
batch_size = learner.policy.get_attribute('batch_size')
batch_length = cfg.policy.learn.batch_length
post, context = world_model.train(
env_buffer, collector.envstep, learner.train_iter, batch_size, batch_length
)
            # posterior latent states from world-model training serve as start states for imagined rollouts
            start = post
learner.train(
start, collector.envstep, policy_kwargs=dict(world_model=world_model, envstep=collector.envstep)
)
# fill environment buffer
data = collector.collect(
train_iter=learner.train_iter,
policy_kwargs=dict(world_model=world_model, envstep=collector.envstep, **collect_kwargs)
)
env_buffer.push(data, cur_collector_envstep=collector.envstep)
if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break
learner.call_hook('after_run')
return policy