zjowowen commited on
Commit
ad95118
1 Parent(s): ffddae6

Upload policy_config.json with huggingface_hub

Browse files
Files changed (1) hide show
  1. policy_config.json +149 -0
policy_config.json ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "project": "LunarLanderContinuous-v2-QGPO-VPSDE",
4
+ "device": "cuda",
5
+ "wandb": {
6
+ "project": "IQL-LunarLanderContinuous-v2-QGPO-VPSDE"
7
+ },
8
+ "simulator": {
9
+ "type": "GymEnvSimulator",
10
+ "args": {
11
+ "env_id": "LunarLanderContinuous-v2"
12
+ }
13
+ },
14
+ "model": {
15
+ "QGPOPolicy": {
16
+ "device": "cuda",
17
+ "critic": {
18
+ "device": "cuda",
19
+ "q_alpha": 1.0,
20
+ "DoubleQNetwork": {
21
+ "backbone": {
22
+ "type": "ConcatenateMLP",
23
+ "args": {
24
+ "hidden_sizes": [
25
+ 10,
26
+ 256,
27
+ 256
28
+ ],
29
+ "output_size": 1,
30
+ "activation": "relu"
31
+ }
32
+ }
33
+ }
34
+ },
35
+ "diffusion_model": {
36
+ "device": "cuda",
37
+ "x_size": 2,
38
+ "alpha": 1.0,
39
+ "solver": {
40
+ "type": "DPMSolver",
41
+ "args": {
42
+ "order": 2,
43
+ "device": "cuda",
44
+ "steps": 17
45
+ }
46
+ },
47
+ "path": {
48
+ "type": "linear_vp_sde",
49
+ "beta_0": 0.1,
50
+ "beta_1": 20.0
51
+ },
52
+ "reverse_path": {
53
+ "type": "linear_vp_sde",
54
+ "beta_0": 0.1,
55
+ "beta_1": 20.0
56
+ },
57
+ "model": {
58
+ "type": "noise_function",
59
+ "args": {
60
+ "t_encoder": {
61
+ "type": "GaussianFourierProjectionTimeEncoder",
62
+ "args": {
63
+ "embed_dim": 32,
64
+ "scale": 30.0
65
+ }
66
+ },
67
+ "backbone": {
68
+ "type": "TemporalSpatialResidualNet",
69
+ "args": {
70
+ "hidden_sizes": [
71
+ 512,
72
+ 256,
73
+ 128
74
+ ],
75
+ "output_dim": 2,
76
+ "t_dim": 32,
77
+ "condition_dim": 8,
78
+ "condition_hidden_dim": 32,
79
+ "t_condition_hidden_dim": 128
80
+ }
81
+ }
82
+ }
83
+ },
84
+ "energy_guidance": {
85
+ "t_encoder": {
86
+ "type": "GaussianFourierProjectionTimeEncoder",
87
+ "args": {
88
+ "embed_dim": 32,
89
+ "scale": 30.0
90
+ }
91
+ },
92
+ "backbone": {
93
+ "type": "ConcatenateMLP",
94
+ "args": {
95
+ "hidden_sizes": [
96
+ 42,
97
+ 256,
98
+ 256
99
+ ],
100
+ "output_size": 1,
101
+ "activation": "silu"
102
+ }
103
+ }
104
+ }
105
+ }
106
+ }
107
+ },
108
+ "parameter": {
109
+ "behaviour_policy": {
110
+ "batch_size": 1024,
111
+ "learning_rate": 0.0001,
112
+ "epochs": 500
113
+ },
114
+ "action_augment_num": 16,
115
+ "fake_data_t_span": null,
116
+ "energy_guided_policy": {
117
+ "batch_size": 256
118
+ },
119
+ "critic": {
120
+ "stop_training_epochs": 500,
121
+ "learning_rate": 0.0001,
122
+ "discount_factor": 0.99,
123
+ "update_momentum": 0.005
124
+ },
125
+ "energy_guidance": {
126
+ "epochs": 1000,
127
+ "learning_rate": 0.0001
128
+ },
129
+ "evaluation": {
130
+ "evaluation_interval": 50,
131
+ "guidance_scale": [
132
+ 0.0,
133
+ 1.0,
134
+ 2.0
135
+ ]
136
+ },
137
+ "checkpoint_path": "./LunarLanderContinuous-v2-QGPO"
138
+ }
139
+ },
140
+ "deploy": {
141
+ "device": "cuda",
142
+ "env": {
143
+ "env_id": "LunarLanderContinuous-v2",
144
+ "seed": 0
145
+ },
146
+ "num_deploy_steps": 1000,
147
+ "t_span": null
148
+ }
149
+ }