from tensorboardX import SummaryWriter
from ditk import logging
import os
import numpy as np
from ding.model.template.qac import ContinuousQAC
from ding.policy import SACPolicy
from ding.envs import BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import data_pusher, StepCollector, interaction_evaluator, \
    CkptSaver, OffPolicyLearner, termination_checker
from ding.utils import set_pkg_seed
from dizoo.dmc2gym.envs.dmc2gym_env import DMC2GymEnv
from dizoo.dmc2gym.config.dmc2gym_sac_pixel_config import main_config, create_config
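

# Train SAC from pixel observations on a DeepMind Control Suite task (wrapped by
# dmc2gym) with DI-engine's task/middleware pipeline, logging extra evaluation,
# collection and training metrics to TensorBoard.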
def main():
    logging.getLogger().setLevel(logging.INFO)
    main_config.exp_name = 'dmc2gym_sac_pixel_seed0'
    main_config.policy.cuda = True
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    num_seed = 1
    for seed_i in range(num_seed):
        tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'seed' + str(seed_i)))
        with task.start(async_mode=False, ctx=OnlineRLContext()):
            collector_env = BaseEnvManagerV2(
                env_fn=[lambda: DMC2GymEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
            )
            evaluator_env = BaseEnvManagerV2(
                env_fn=[lambda: DMC2GymEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
            )

            set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

            model = ContinuousQAC(**cfg.policy.model)
            logging.info(model)
            buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
            policy = SACPolicy(cfg.policy, model=model)
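
            # Custom middleware: log evaluation reward, mean collected reward and the
            # average environment steps per collected episode to TensorBoard.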
            def _add_scalar(ctx):
                if ctx.eval_value != -np.inf:
                    tb_logger.add_scalar('evaluator_step/reward', ctx.eval_value, global_step=ctx.env_step)
                collector_rewards = [traj['reward'] for traj in ctx.trajectories]
                collector_mean_reward = sum(collector_rewards) / len(collector_rewards)
                # collector_max_reward = max(collector_rewards)
                # collector_min_reward = min(collector_rewards)
                tb_logger.add_scalar('collector_step/mean_reward', collector_mean_reward, global_step=ctx.env_step)
                # tb_logger.add_scalar('collector_step/max_reward', collector_max_reward, global_step=ctx.env_step)
                # tb_logger.add_scalar('collector_step/min_reward', collector_min_reward, global_step=ctx.env_step)
                tb_logger.add_scalar(
                    'collector_step/avg_env_step_per_episode',
                    ctx.env_step / ctx.env_episode,
                    global_step=ctx.env_step
                )
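
            # Custom middleware: average learning rates and losses over the latest batch
            # of train outputs and log them to TensorBoard.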
            def _add_train_scalar(ctx):
                len_train = len(ctx.train_output)
                cur_lr_q_avg = sum(output['cur_lr_q'] for output in ctx.train_output) / len_train
                cur_lr_p_avg = sum(output['cur_lr_p'] for output in ctx.train_output) / len_train
                critic_loss_avg = sum(output['critic_loss'] for output in ctx.train_output) / len_train
                policy_loss_avg = sum(output['policy_loss'] for output in ctx.train_output) / len_train
                total_loss_avg = sum(output['total_loss'] for output in ctx.train_output) / len_train
                tb_logger.add_scalar('learner_step/cur_lr_q_avg', cur_lr_q_avg, global_step=ctx.env_step)
                tb_logger.add_scalar('learner_step/cur_lr_p_avg', cur_lr_p_avg, global_step=ctx.env_step)
                tb_logger.add_scalar('learner_step/critic_loss_avg', critic_loss_avg, global_step=ctx.env_step)
                tb_logger.add_scalar('learner_step/policy_loss_avg', policy_loss_avg, global_step=ctx.env_step)
                tb_logger.add_scalar('learner_step/total_loss_avg', total_loss_avg, global_step=ctx.env_step)
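
            # Assemble the pipeline: evaluate, collect (with an initial random-collect
            # phase), log collection metrics, push data into the replay buffer, train,
            # log training metrics, save checkpoints every 1e5 training iterations, and
            # stop after 5e6 environment steps.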
            task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
            task.use(
                StepCollector(
                    cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size
                )
            )
            task.use(_add_scalar)
            task.use(data_pusher(cfg, buffer_))
            task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
            task.use(_add_train_scalar)
            task.use(CkptSaver(policy, cfg.exp_name, train_freq=int(1e5)))
            task.use(termination_checker(max_env_step=int(5e6)))
            task.run()


if __name__ == "__main__":
    main()