File size: 2,553 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from easydict import EasyDict

from ding.entry import serial_pipeline_dreamer

# Global device switch shared by both policy.cuda and world_model.cuda below:
# False = train on CPU, True = train on GPU.
cuda = False

# Main Dreamer experiment config for the dmc2gym cartpole-balance task.
# The env / policy / world-model sections are built separately and composed
# into a single EasyDict below for readability.
_env_cfg = dict(
    env_id='dmc2gym_cartpole_balance',
    domain_name='cartpole',
    task_name='balance',
    frame_skip=1,
    warp_frame=True,
    scale=True,
    clip_rewards=False,
    action_repeat=2,
    frame_stack=1,
    from_pixels=True,  # pixel observations — matches obs_shape=(3, 64, 64) below
    resize=64,
    collector_env_num=1,
    evaluator_env_num=1,
    n_evaluator_episode=1,
    stop_value=1000,
)
_policy_cfg = dict(
    cuda=cuda,
    # it is better to put random_collect_size in policy.other
    random_collect_size=2500,
    model=dict(
        obs_shape=(3, 64, 64),
        action_shape=1,
        actor_dist='normal',
    ),
    learn=dict(
        lambda_=0.95,
        learning_rate=3e-5,
        batch_size=16,
        batch_length=64,
        imag_sample=True,
        discount=0.997,
        reward_EMA=True,
    ),
    collect=dict(
        n_sample=1,
        unroll_len=1,
        action_size=1,  # has to be specified
        collect_dyn_sample=True,
    ),
    command=dict(),
    eval=dict(evaluator=dict(eval_freq=5000, )),
    other=dict(
        # environment buffer
        replay_buffer=dict(replay_buffer_size=500000, periodic_thruput_seconds=60),
    ),
)
_world_model_cfg = dict(
    pretrain=100,
    train_freq=2,
    cuda=cuda,
    model=dict(
        state_size=(3, 64, 64),  # has to be specified
        action_size=1,  # has to be specified
        reward_size=1,
        batch_size=16,
    ),
)
cartpole_balance_dreamer_config = EasyDict(
    dict(
        exp_name='dmc2gym_cartpole_balance_dreamer',
        env=_env_cfg,
        policy=_policy_cfg,
        world_model=_world_model_cfg,
    )
)

# Component registry: tells the DI-engine factory which env, env manager,
# policy, replay buffer and world-model classes to instantiate.
cartpole_balance_create_config = EasyDict(
    dict(
        env=dict(
            type='dmc2gym',
            import_names=['dizoo.dmc2gym.envs.dmc2gym_env'],
        ),
        env_manager=dict(type='base'),
        policy=dict(
            type='dreamer',
            import_names=['ding.policy.mbpolicy.dreamer'],
        ),
        replay_buffer=dict(type='sequence', ),
        world_model=dict(
            type='dreamer',
            import_names=['ding.world_model.dreamer'],
        ),
    )
)

if __name__ == '__main__':
    # Entry point: run the serial Dreamer training pipeline with the two
    # configs defined above, stopping after at most one million env steps.
    full_config = (cartpole_balance_dreamer_config, cartpole_balance_create_config)
    serial_pipeline_dreamer(full_config, seed=0, max_env_step=1000000)