# Source: gomoku / DI-engine / dizoo/metadrive/config/metadrive_onppo_config.py
# Committed by zjowowen, commit 079c32c ("init space"), 4.27 kB.
from easydict import EasyDict
from functools import partial
from tensorboardX import SummaryWriter
import metadrive
import gym
from ding.envs import BaseEnvManager, SyncSubprocessEnvManager
from ding.config import compile_config
from ding.model.template import ContinuousQAC, VAC
from ding.policy import PPOPolicy
from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, BaseLearner
from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper
# Experiment configuration for on-policy PPO on MetaDrive.
# `env` configures the environments and their managers; `policy` configures the
# PPO model, learner, collector and evaluator.
metadrive_basic_config = {
    'exp_name': 'metadrive_onppo_seed0',
    'env': {
        # Settings handed to MetaDrivePPOOriginEnv through wrapped_env().
        'metadrive': {
            'use_render': False,
            'traffic_density': 0.10,  # Density of vehicles occupying the roads, range in [0, 1]
            'map': 'XSOS',  # Int or string: a shorthand way to fill map_config
            'horizon': 4000,  # Max step number
            'driving_reward': 1.0,  # Reward encouraging the agent to move forward
            'speed_reward': 0.1,  # Reward encouraging the agent to drive at a high speed
            'use_lateral_reward': False,  # Reward for lane keeping
            'out_of_road_penalty': 40.0,  # Penalty discouraging driving out of the road
            'crash_vehicle_penalty': 40.0,  # Penalty discouraging collisions with vehicles
            'decision_repeat': 20,  # Reciprocal of decision frequency
            'out_of_route_done': True,  # Presumably ends the episode when leaving the route (per key name) - confirm
        },
        # Options for the env managers; main() passes this as their `cfg`.
        'manager': {
            'shared_memory': False,
            'max_retry': 2,
            'context': 'spawn',
        },
        'n_evaluator_episode': 16,
        'stop_value': 255,
        'collector_env_num': 8,
        'evaluator_env_num': 8,
    },
    'policy': {
        'cuda': True,
        'action_space': 'continuous',
        # Keyword arguments for the VAC model built in main().
        'model': {
            'obs_shape': [5, 84, 84],
            'action_shape': 2,
            'action_space': 'continuous',
            'bound_type': 'tanh',
            'encoder_hidden_size_list': [128, 128, 64],
        },
        'learn': {
            'epoch_per_collect': 10,
            'batch_size': 64,
            'learning_rate': 3e-4,
            'entropy_weight': 0.001,
            'value_weight': 0.5,
            'clip_ratio': 0.02,
            'adv_norm': False,
            'value_norm': True,
            'grad_clip_value': 10,
        },
        # n_sample is the number of samples requested from the collector per iteration in main().
        'collect': {'n_sample': 3000},
        'eval': {'evaluator': {'eval_freq': 1000}},
    },
}
# EasyDict view of metadrive_basic_config (attribute-style access such as
# cfg.env.metadrive); this is the object the __main__ guard passes to main().
main_config = EasyDict(metadrive_basic_config)
def wrapped_env(env_cfg, wrapper_cfg=None):
    """Create one MetaDrive PPO environment wrapped in ``DriveEnvWrapper``.

    Args:
        env_cfg: Settings for ``MetaDrivePPOOriginEnv`` (``main`` passes
            ``cfg.env.metadrive``).
        wrapper_cfg: Optional config forwarded unchanged to ``DriveEnvWrapper``;
            defaults to ``None``.

    Returns:
        The ``DriveEnvWrapper`` instance around the freshly built environment.
    """
    base_env = MetaDrivePPOOriginEnv(env_cfg)
    return DriveEnvWrapper(base_env, wrapper_cfg)
def _make_env_manager(cfg, env_num):
    """Build a SyncSubprocessEnvManager holding ``env_num`` wrapped MetaDrive envs."""
    return SyncSubprocessEnvManager(
        env_fn=[partial(wrapped_env, cfg.env.metadrive) for _ in range(env_num)],
        cfg=cfg.env.manager,
    )


def main(cfg):
    """Train an on-policy PPO agent on MetaDrive until the evaluator signals stop.

    Args:
        cfg: EasyDict experiment config (e.g. ``main_config``). It is run through
            ``compile_config`` before anything else reads it.

    The collector, evaluator and learner are closed in a ``finally`` clause, so
    cleanup also happens when training raises or is interrupted (e.g. Ctrl-C),
    not only after a normal stop.
    """
    cfg = compile_config(
        cfg, SyncSubprocessEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
    )

    # Collector and evaluator envs are built identically; only their count differs.
    collector_env = _make_env_manager(cfg, cfg.env.collector_env_num)
    evaluator_env = _make_env_manager(cfg, cfg.env.evaluator_env_num)

    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)

    # One TensorBoard writer shared by learner, collector and evaluator.
    tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = SampleSerialCollector(
        cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )

    try:
        learner.call_hook('before_run')
        while True:
            # Evaluate on the evaluator's schedule; leave the loop once eval()
            # returns a truthy stop flag (its second return value is unused here).
            if evaluator.should_eval(learner.train_iter):
                stop, _ = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
                if stop:
                    break
            # Sample a fresh batch from the environments, then train on it.
            new_data = collector.collect(cfg.policy.collect.n_sample, train_iter=learner.train_iter)
            learner.train(new_data, collector.envstep)
        # Only reached on a normal stop, matching the original control flow.
        learner.call_hook('after_run')
    finally:
        collector.close()
        evaluator.close()
        learner.close()
if __name__ == '__main__':
    # Train only when run as a script, never on import. With the manager's
    # context='spawn' (presumably the multiprocessing start method for env
    # subprocesses - confirm), this guard also keeps re-imported child
    # processes from starting training again.
    main(cfg=main_config)