# DI-engine/dizoo/atari/entry/atari_dqn_main.py
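"""Serial DQN training entry for Atari Pong, built on DI-engine.

Wires up the standard serial pipeline (subprocess env managers, a DQN policy,
a learner, a sample collector, an interaction evaluator, and a replay buffer),
then alternates collection and training until the evaluator signals stop.
"""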
import os
from functools import partial

from tensorboardX import SummaryWriter
from easydict import EasyDict

from ding.config import compile_config
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
from ding.envs import SyncSubprocessEnvManager
from ding.policy import DQNPolicy
from ding.model import DQN
from ding.utils import set_pkg_seed
from ding.rl_utils import get_epsilon_greedy_fn
from dizoo.atari.envs import AtariEnv
from dizoo.atari.config.serial.pong.pong_dqn_config import pong_dqn_config


def main(cfg, seed=0, max_iterations=int(1e10)):
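    # Merge the user config with the default configs declared by each
    # component class used in this pipeline, and save the compiled result.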
    cfg = compile_config(
        cfg,
        SyncSubprocessEnvManager,
        DQNPolicy,
        BaseLearner,
        SampleSerialCollector,
        InteractionSerialEvaluator,
        AdvancedReplayBuffer,
        save_cfg=True
    )
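    # One env config per environment instance; both collection and evaluation
    # envs run in subprocesses via SyncSubprocessEnvManager.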
    collector_env_cfg = AtariEnv.create_collector_env_cfg(cfg.env)
    evaluator_env_cfg = AtariEnv.create_evaluator_env_cfg(cfg.env)
    collector_env = SyncSubprocessEnvManager(
        env_fn=[partial(AtariEnv, cfg=c) for c in collector_env_cfg], cfg=cfg.env.manager
    )
    evaluator_env = SyncSubprocessEnvManager(
        env_fn=[partial(AtariEnv, cfg=c) for c in evaluator_env_cfg], cfg=cfg.env.manager
    )
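    # Seed environments and packages; the evaluator keeps a fixed seed
    # (dynamic_seed=False) so evaluation episodes are reproducible.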
    collector_env.seed(seed)
    evaluator_env.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
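    # Build the Q-network and policy, then the serial pipeline components,
    # all sharing a single TensorBoard logger.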
    model = DQN(**cfg.policy.model)
    policy = DQNPolicy(cfg.policy, model=model)
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = SampleSerialCollector(
        cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    replay_buffer = AdvancedReplayBuffer(
        cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name, instance_name='replay_buffer'
    )
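    # Epsilon-greedy exploration schedule, annealed by collected env steps.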
    eps_cfg = cfg.policy.other.eps
    epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
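    # Main loop: interleave periodic evaluation, data collection, and learning.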
    while learner.train_iter < max_iterations:
        # Evaluate at the configured frequency; stop once the target reward is reached.
        if evaluator.should_eval(learner.train_iter):
            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        # Collect new transitions with the current exploration rate and store them.
        eps = epsilon_greedy(collector.envstep)
        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
        replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
        # Train for a fixed number of gradient steps per collection round.
        for _ in range(cfg.policy.learn.update_per_collect):
            batch_size = learner.policy.get_attribute('batch_size')
            train_data = replay_buffer.sample(batch_size, learner.train_iter)
            if train_data is not None:
                learner.train(train_data, collector.envstep)


if __name__ == "__main__":
    main(EasyDict(pong_dqn_config))
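    # A quick CPU-only run can be launched by overriding the config first
    # (hypothetical tweak, not part of the original entry):
    #   cfg = EasyDict(pong_dqn_config)
    #   cfg.policy.cuda = False
    #   main(cfg, seed=0)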