sdpkjc commited on
Commit
87e7c60
1 Parent(s): f5b1eee

pushing model

Browse files
README.md ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - Swimmer-v4
4
+ - deep-reinforcement-learning
5
+ - reinforcement-learning
6
+ - custom-implementation
7
+ library_name: cleanrl
8
+ model-index:
9
+ - name: SAC
10
+ results:
11
+ - task:
12
+ type: reinforcement-learning
13
+ name: reinforcement-learning
14
+ dataset:
15
+ name: Swimmer-v4
16
+ type: Swimmer-v4
17
+ metrics:
18
+ - type: mean_reward
19
+ value: 64.82 +/- 24.71
20
+ name: mean_reward
21
+ verified: false
22
+ ---
23
+
24
+ # (CleanRL) **SAC** Agent Playing **Swimmer-v4**
25
+
26
+ This is a trained model of a SAC agent playing Swimmer-v4.
27
+ The model was trained by using [CleanRL](https://github.com/vwxyzjn/cleanrl) and the most up-to-date training code can be
28
+ found [here](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py).
29
+
30
+ ## Get Started
31
+
32
+ To use this model, please install the `cleanrl` package with the following command:
33
+
34
+ ```
35
+ pip install "cleanrl[sac_continuous_action]"
36
+ python -m cleanrl_utils.enjoy --exp-name sac_continuous_action --env-id Swimmer-v4
37
+ ```
38
+
39
+ Please refer to the [documentation](https://docs.cleanrl.dev/get-started/zoo/) for more detail.
40
+
41
+
42
+ ## Command to reproduce the training
43
+
44
+ ```bash
45
+ curl -OL https://huggingface.co/sdpkjc/Swimmer-v4-sac_continuous_action-seed1/raw/main/sac_continuous_action.py
46
+ curl -OL https://huggingface.co/sdpkjc/Swimmer-v4-sac_continuous_action-seed1/raw/main/pyproject.toml
47
+ curl -OL https://huggingface.co/sdpkjc/Swimmer-v4-sac_continuous_action-seed1/raw/main/poetry.lock
48
+ poetry install --all-extras
49
+ python sac_continuous_action.py --save-model --upload-model --hf-entity sdpkjc --env-id Swimmer-v4 --seed 1
50
+ ```
51
+
52
+ # Hyperparameters
53
+ ```python
54
+ {'alpha': 0.2,
55
+ 'autotune': True,
56
+ 'batch_size': 256,
57
+ 'buffer_size': 1000000,
58
+ 'capture_video': False,
59
+ 'cuda': True,
60
+ 'env_id': 'Swimmer-v4',
61
+ 'exp_name': 'sac_continuous_action',
62
+ 'gamma': 0.99,
63
+ 'hf_entity': 'sdpkjc',
64
+ 'learning_starts': 5000.0,
65
+ 'noise_clip': 0.5,
66
+ 'policy_frequency': 2,
67
+ 'policy_lr': 0.0003,
68
+ 'q_lr': 0.001,
69
+ 'save_model': True,
70
+ 'seed': 1,
71
+ 'target_network_frequency': 1,
72
+ 'tau': 0.005,
73
+ 'torch_deterministic': True,
74
+ 'total_timesteps': 1000000,
75
+ 'track': False,
76
+ 'upload_model': True,
77
+ 'wandb_entity': None,
78
+ 'wandb_project_name': 'cleanRL'}
79
+ ```
80
+
events.out.tfevents.1699115910.4090-171.2900642.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b578d6ba1b511b0d43035c6facbee34484e83d8a098992486cb327d155a7772d
3
+ size 5038964
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "cleanrl"
3
+ version = "1.1.0"
4
+ description = "High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features"
5
+ authors = ["Costa Huang <costa.huang@outlook.com>"]
6
+ packages = [
7
+ { include = "cleanrl" },
8
+ { include = "cleanrl_utils" },
9
+ ]
10
+ keywords = ["reinforcement", "machine", "learning", "research"]
11
+ license="MIT"
12
+ readme = "README.md"
13
+
14
+ [tool.poetry.dependencies]
15
+ python = ">=3.7.1,<3.11"
16
+ tensorboard = "^2.10.0"
17
+ wandb = "^0.13.11"
18
+ gym = "0.23.1"
19
+ torch = ">=1.12.1"
20
+ stable-baselines3 = "1.2.0"
21
+ gymnasium = ">=0.28.1"
22
+ moviepy = "^1.0.3"
23
+ pygame = "2.1.0"
24
+ huggingface-hub = "^0.11.1"
25
+ rich = "<12.0"
26
+ tenacity = "^8.2.2"
27
+
28
+ ale-py = {version = "0.7.4", optional = true}
29
+ AutoROM = {extras = ["accept-rom-license"], version = "^0.4.2", optional = true}
30
+ opencv-python = {version = "^4.6.0.66", optional = true}
31
+ procgen = {version = "^0.10.7", optional = true}
32
+ pytest = {version = "^7.1.3", optional = true}
33
+ mujoco = {version = "<=2.3.3", optional = true}
34
+ imageio = {version = "^2.14.1", optional = true}
35
+ free-mujoco-py = {version = "^2.1.6", optional = true}
36
+ mkdocs-material = {version = "^8.4.3", optional = true}
37
+ markdown-include = {version = "^0.7.0", optional = true}
38
+ openrlbenchmark = {version = "^0.1.1b4", optional = true}
39
+ jax = {version = "^0.3.17", optional = true}
40
+ jaxlib = {version = "^0.3.15", optional = true}
41
+ flax = {version = "^0.6.0", optional = true}
42
+ optuna = {version = "^3.0.1", optional = true}
43
+ optuna-dashboard = {version = "^0.7.2", optional = true}
44
+ envpool = {version = "^0.6.4", optional = true}
45
+ PettingZoo = {version = "1.18.1", optional = true}
46
+ SuperSuit = {version = "3.4.0", optional = true}
47
+ multi-agent-ale-py = {version = "0.1.11", optional = true}
48
+ boto3 = {version = "^1.24.70", optional = true}
49
+ awscli = {version = "^1.25.71", optional = true}
50
+ shimmy = {version = ">=1.0.0", extras = ["dm-control"], optional = true}
51
+
52
+ [tool.poetry.group.dev.dependencies]
53
+ pre-commit = "^2.20.0"
54
+
55
+
56
+ [tool.poetry.group.isaacgym]
57
+ optional = true
58
+ [tool.poetry.group.isaacgym.dependencies]
59
+ isaacgymenvs = {git = "https://github.com/vwxyzjn/IsaacGymEnvs.git", rev = "poetry", python = ">=3.7.1,<3.10"}
60
+ isaacgym = {path = "cleanrl/ppo_continuous_action_isaacgym/isaacgym", develop = true}
61
+
62
+
63
+ [build-system]
64
+ requires = ["poetry-core"]
65
+ build-backend = "poetry.core.masonry.api"
66
+
67
+ [tool.poetry.extras]
68
+ atari = ["ale-py", "AutoROM", "opencv-python"]
69
+ procgen = ["procgen"]
70
+ plot = ["pandas", "seaborn"]
71
+ pytest = ["pytest"]
72
+ mujoco = ["mujoco", "imageio"]
73
+ mujoco_py = ["free-mujoco-py"]
74
+ jax = ["jax", "jaxlib", "flax"]
75
+ docs = ["mkdocs-material", "markdown-include", "openrlbenchmark"]
76
+ envpool = ["envpool"]
77
+ optuna = ["optuna", "optuna-dashboard"]
78
+ pettingzoo = ["PettingZoo", "SuperSuit", "multi-agent-ale-py"]
79
+ cloud = ["boto3", "awscli"]
80
+ dm_control = ["shimmy", "mujoco"]
81
+
82
+ # dependencies for algorithm variant (useful when you want to run a specific algorithm)
83
+ dqn = []
84
+ dqn_atari = ["ale-py", "AutoROM", "opencv-python"]
85
+ dqn_jax = ["jax", "jaxlib", "flax"]
86
+ dqn_atari_jax = [
87
+ "ale-py", "AutoROM", "opencv-python", # atari
88
+ "jax", "jaxlib", "flax" # jax
89
+ ]
90
+ c51 = []
91
+ c51_atari = ["ale-py", "AutoROM", "opencv-python"]
92
+ c51_jax = ["jax", "jaxlib", "flax"]
93
+ c51_atari_jax = [
94
+ "ale-py", "AutoROM", "opencv-python", # atari
95
+ "jax", "jaxlib", "flax" # jax
96
+ ]
97
+ ppo_atari_envpool_xla_jax_scan = [
98
+ "ale-py", "AutoROM", "opencv-python", # atari
99
+ "jax", "jaxlib", "flax", # jax
100
+ "envpool", # envpool
101
+ ]
102
+ qdagger_dqn_atari_impalacnn = [
103
+ "ale-py", "AutoROM", "opencv-python"
104
+ ]
105
+ qdagger_dqn_atari_jax_impalacnn = [
106
+ "ale-py", "AutoROM", "opencv-python", # atari
107
+ "jax", "jaxlib", "flax", # jax
108
+ ]
replay.mp4 ADDED
Binary file (837 kB). View file
 
sac_continuous_action.cleanrl_model ADDED
Binary file (834 kB). View file
 
sac_continuous_action.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_continuous_actionpy
2
+ import argparse
3
+ import os
4
+ import random
5
+ import time
6
+ from distutils.util import strtobool
7
+
8
+ import gymnasium as gym
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.optim as optim
14
+ from stable_baselines3.common.buffers import ReplayBuffer
15
+ from torch.utils.tensorboard import SummaryWriter
16
+
17
+
18
+ def parse_args():
19
+ # fmt: off
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
22
+ help="the name of this experiment")
23
+ parser.add_argument("--seed", type=int, default=1,
24
+ help="seed of the experiment")
25
+ parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
26
+ help="if toggled, `torch.backends.cudnn.deterministic=False`")
27
+ parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
28
+ help="if toggled, cuda will be enabled by default")
29
+ parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
30
+ help="if toggled, this experiment will be tracked with Weights and Biases")
31
+ parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
32
+ help="the wandb's project name")
33
+ parser.add_argument("--wandb-entity", type=str, default=None,
34
+ help="the entity (team) of wandb's project")
35
+ parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
36
+ help="whether to capture videos of the agent performances (check out `videos` folder)")
37
+ parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
38
+ help="whether to save model into the `runs/{run_name}` folder")
39
+ parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
40
+ help="whether to upload the saved model to huggingface")
41
+ parser.add_argument("--hf-entity", type=str, default="",
42
+ help="the user or org name of the model repository from the Hugging Face Hub")
43
+
44
+ # Algorithm specific arguments
45
+ parser.add_argument("--env-id", type=str, default="Hopper-v4",
46
+ help="the id of the environment")
47
+ parser.add_argument("--total-timesteps", type=int, default=1000000,
48
+ help="total timesteps of the experiments")
49
+ parser.add_argument("--buffer-size", type=int, default=int(1e6),
50
+ help="the replay memory buffer size")
51
+ parser.add_argument("--gamma", type=float, default=0.99,
52
+ help="the discount factor gamma")
53
+ parser.add_argument("--tau", type=float, default=0.005,
54
+ help="target smoothing coefficient (default: 0.005)")
55
+ parser.add_argument("--batch-size", type=int, default=256,
56
+ help="the batch size of sample from the reply memory")
57
+ parser.add_argument("--learning-starts", type=int, default=5e3,
58
+ help="timestep to start learning")
59
+ parser.add_argument("--policy-lr", type=float, default=3e-4,
60
+ help="the learning rate of the policy network optimizer")
61
+ parser.add_argument("--q-lr", type=float, default=1e-3,
62
+ help="the learning rate of the Q network network optimizer")
63
+ parser.add_argument("--policy-frequency", type=int, default=2,
64
+ help="the frequency of training policy (delayed)")
65
+ parser.add_argument("--target-network-frequency", type=int, default=1, # Denis Yarats' implementation delays this by 2.
66
+ help="the frequency of updates for the target nerworks")
67
+ parser.add_argument("--noise-clip", type=float, default=0.5,
68
+ help="noise clip parameter of the Target Policy Smoothing Regularization")
69
+ parser.add_argument("--alpha", type=float, default=0.2,
70
+ help="Entropy regularization coefficient.")
71
+ parser.add_argument("--autotune", type=lambda x:bool(strtobool(x)), default=True, nargs="?", const=True,
72
+ help="automatic tuning of the entropy coefficient")
73
+ args = parser.parse_args()
74
+ # fmt: on
75
+ return args
76
+
77
+
78
+ def make_env(env_id, seed, idx, capture_video, run_name):
79
+ def thunk():
80
+ if capture_video and idx == 0:
81
+ env = gym.make(env_id, render_mode="rgb_array")
82
+ env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
83
+ else:
84
+ env = gym.make(env_id)
85
+ env = gym.wrappers.RecordEpisodeStatistics(env)
86
+ env.action_space.seed(seed)
87
+ return env
88
+
89
+ return thunk
90
+
91
+
92
+ # ALGO LOGIC: initialize agent here:
93
+ class SoftQNetwork(nn.Module):
94
+ def __init__(self, env):
95
+ super().__init__()
96
+ self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256)
97
+ self.fc2 = nn.Linear(256, 256)
98
+ self.fc3 = nn.Linear(256, 1)
99
+
100
+ def forward(self, x, a):
101
+ x = torch.cat([x, a], 1)
102
+ x = F.relu(self.fc1(x))
103
+ x = F.relu(self.fc2(x))
104
+ x = self.fc3(x)
105
+ return x
106
+
107
+
108
+ LOG_STD_MAX = 2
109
+ LOG_STD_MIN = -5
110
+
111
+
112
+ class Actor(nn.Module):
113
+ def __init__(self, env):
114
+ super().__init__()
115
+ self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod(), 256)
116
+ self.fc2 = nn.Linear(256, 256)
117
+ self.fc_mean = nn.Linear(256, np.prod(env.single_action_space.shape))
118
+ self.fc_logstd = nn.Linear(256, np.prod(env.single_action_space.shape))
119
+ # action rescaling
120
+ self.register_buffer(
121
+ "action_scale", torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32)
122
+ )
123
+ self.register_buffer(
124
+ "action_bias", torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32)
125
+ )
126
+
127
+ def forward(self, x):
128
+ x = F.relu(self.fc1(x))
129
+ x = F.relu(self.fc2(x))
130
+ mean = self.fc_mean(x)
131
+ log_std = self.fc_logstd(x)
132
+ log_std = torch.tanh(log_std)
133
+ log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1) # From SpinUp / Denis Yarats
134
+
135
+ return mean, log_std
136
+
137
+ def get_action(self, x):
138
+ mean, log_std = self(x)
139
+ std = log_std.exp()
140
+ normal = torch.distributions.Normal(mean, std)
141
+ x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
142
+ y_t = torch.tanh(x_t)
143
+ action = y_t * self.action_scale + self.action_bias
144
+ log_prob = normal.log_prob(x_t)
145
+ # Enforcing Action Bound
146
+ log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
147
+ log_prob = log_prob.sum(1, keepdim=True)
148
+ mean = torch.tanh(mean) * self.action_scale + self.action_bias
149
+ return action, log_prob, mean
150
+
151
+
152
+ if __name__ == "__main__":
153
+ import stable_baselines3 as sb3
154
+
155
+ if sb3.__version__ < "2.0":
156
+ raise ValueError(
157
+ """Ongoing migration: run the following command to install the new dependencies:
158
+ poetry run pip install "stable_baselines3==2.0.0a1"
159
+ """
160
+ )
161
+
162
+ args = parse_args()
163
+ run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
164
+ if args.track:
165
+ import wandb
166
+
167
+ wandb.init(
168
+ project=args.wandb_project_name,
169
+ entity=args.wandb_entity,
170
+ sync_tensorboard=True,
171
+ config=vars(args),
172
+ name=run_name,
173
+ monitor_gym=True,
174
+ save_code=True,
175
+ )
176
+ writer = SummaryWriter(f"runs/{run_name}")
177
+ writer.add_text(
178
+ "hyperparameters",
179
+ "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
180
+ )
181
+
182
+ # TRY NOT TO MODIFY: seeding
183
+ random.seed(args.seed)
184
+ np.random.seed(args.seed)
185
+ torch.manual_seed(args.seed)
186
+ torch.backends.cudnn.deterministic = args.torch_deterministic
187
+
188
+ device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
189
+
190
+ # env setup
191
+ envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
192
+ assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"
193
+
194
+ max_action = float(envs.single_action_space.high[0])
195
+
196
+ actor = Actor(envs).to(device)
197
+ qf1 = SoftQNetwork(envs).to(device)
198
+ qf2 = SoftQNetwork(envs).to(device)
199
+ qf1_target = SoftQNetwork(envs).to(device)
200
+ qf2_target = SoftQNetwork(envs).to(device)
201
+ qf1_target.load_state_dict(qf1.state_dict())
202
+ qf2_target.load_state_dict(qf2.state_dict())
203
+ q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr)
204
+ actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr)
205
+
206
+ # Automatic entropy tuning
207
+ if args.autotune:
208
+ target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item()
209
+ log_alpha = torch.zeros(1, requires_grad=True, device=device)
210
+ alpha = log_alpha.exp().item()
211
+ a_optimizer = optim.Adam([log_alpha], lr=args.q_lr)
212
+ else:
213
+ alpha = args.alpha
214
+
215
+ envs.single_observation_space.dtype = np.float32
216
+ rb = ReplayBuffer(
217
+ args.buffer_size,
218
+ envs.single_observation_space,
219
+ envs.single_action_space,
220
+ device,
221
+ handle_timeout_termination=False,
222
+ )
223
+ start_time = time.time()
224
+
225
+ # TRY NOT TO MODIFY: start the game
226
+ obs, _ = envs.reset(seed=args.seed)
227
+ for global_step in range(args.total_timesteps):
228
+ # ALGO LOGIC: put action logic here
229
+ if global_step < args.learning_starts:
230
+ actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
231
+ else:
232
+ actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
233
+ actions = actions.detach().cpu().numpy()
234
+
235
+ # TRY NOT TO MODIFY: execute the game and log data.
236
+ next_obs, rewards, terminations, truncations, infos = envs.step(actions)
237
+
238
+ # TRY NOT TO MODIFY: record rewards for plotting purposes
239
+ if "final_info" in infos:
240
+ for info in infos["final_info"]:
241
+ print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
242
+ writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
243
+ writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
244
+ break
245
+
246
+ # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
247
+ real_next_obs = next_obs.copy()
248
+ for idx, trunc in enumerate(truncations):
249
+ if trunc:
250
+ real_next_obs[idx] = infos["final_observation"][idx]
251
+ rb.add(obs, real_next_obs, actions, rewards, terminations, infos)
252
+
253
+ # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
254
+ obs = next_obs
255
+
256
+ # ALGO LOGIC: training.
257
+ if global_step > args.learning_starts:
258
+ data = rb.sample(args.batch_size)
259
+ with torch.no_grad():
260
+ next_state_actions, next_state_log_pi, _ = actor.get_action(data.next_observations)
261
+ qf1_next_target = qf1_target(data.next_observations, next_state_actions)
262
+ qf2_next_target = qf2_target(data.next_observations, next_state_actions)
263
+ min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
264
+ next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)
265
+
266
+ qf1_a_values = qf1(data.observations, data.actions).view(-1)
267
+ qf2_a_values = qf2(data.observations, data.actions).view(-1)
268
+ qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
269
+ qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
270
+ qf_loss = qf1_loss + qf2_loss
271
+
272
+ # optimize the model
273
+ q_optimizer.zero_grad()
274
+ qf_loss.backward()
275
+ q_optimizer.step()
276
+
277
+ if global_step % args.policy_frequency == 0: # TD 3 Delayed update support
278
+ for _ in range(
279
+ args.policy_frequency
280
+ ): # compensate for the delay by doing 'actor_update_interval' instead of 1
281
+ pi, log_pi, _ = actor.get_action(data.observations)
282
+ qf1_pi = qf1(data.observations, pi)
283
+ qf2_pi = qf2(data.observations, pi)
284
+ min_qf_pi = torch.min(qf1_pi, qf2_pi)
285
+ actor_loss = ((alpha * log_pi) - min_qf_pi).mean()
286
+
287
+ actor_optimizer.zero_grad()
288
+ actor_loss.backward()
289
+ actor_optimizer.step()
290
+
291
+ if args.autotune:
292
+ with torch.no_grad():
293
+ _, log_pi, _ = actor.get_action(data.observations)
294
+ alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean()
295
+
296
+ a_optimizer.zero_grad()
297
+ alpha_loss.backward()
298
+ a_optimizer.step()
299
+ alpha = log_alpha.exp().item()
300
+
301
+ # update the target networks
302
+ if global_step % args.target_network_frequency == 0:
303
+ for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
304
+ target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
305
+ for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
306
+ target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
307
+
308
+ if global_step % 100 == 0:
309
+ writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
310
+ writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
311
+ writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
312
+ writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
313
+ writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
314
+ writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
315
+ writer.add_scalar("losses/alpha", alpha, global_step)
316
+ print("SPS:", int(global_step / (time.time() - start_time)))
317
+ writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
318
+ if args.autotune:
319
+ writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)
320
+
321
+ if args.save_model:
322
+ model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
323
+ torch.save((actor.state_dict(), qf1.state_dict(), qf2.state_dict()), model_path)
324
+ print(f"model saved to {model_path}")
325
+ from cleanrl_utils.evals.sac_eval import evaluate
326
+
327
+ episodic_returns = evaluate(
328
+ model_path,
329
+ make_env,
330
+ args.env_id,
331
+ eval_episodes=10,
332
+ run_name=f"{run_name}-eval",
333
+ Model=(Actor, SoftQNetwork),
334
+ device=device,
335
+ )
336
+ for idx, episodic_return in enumerate(episodic_returns):
337
+ writer.add_scalar("eval/episodic_return", episodic_return, idx)
338
+
339
+ if args.upload_model:
340
+ from cleanrl_utils.huggingface import push_to_hub
341
+
342
+ repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
343
+ repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
344
+ push_to_hub(args, episodic_returns, repo_id, "SAC", f"runs/{run_name}", f"videos/{run_name}-eval")
345
+
346
+ envs.close()
347
+ writer.close()
videos/Swimmer-v4__sac_continuous_action__1__1699115910-eval/rl-video-episode-0.mp4 ADDED
Binary file (838 kB). View file
 
videos/Swimmer-v4__sac_continuous_action__1__1699115910-eval/rl-video-episode-1.mp4 ADDED
Binary file (836 kB). View file
 
videos/Swimmer-v4__sac_continuous_action__1__1699115910-eval/rl-video-episode-8.mp4 ADDED
Binary file (837 kB). View file