Create Agent_class.py
Agent_class.py +49 -0
Agent_class.py
ADDED
@@ -0,0 +1,49 @@
import torch


class ParameterisedPolicy(torch.nn.Module):
    """
    REINFORCE RL agent class. ParameterisedPolicy.act(observation) returns an
    action sampled from a Gaussian policy.

    observation   - a gym state vector
    obs_len       - length of the state vector
    act_space_len - length of the action vector
    """

    def __init__(self, obs_len=8, act_space_len=2):
        super().__init__()
        self.deterministic = False
        self.continuous = True
        self.obs_len = obs_len
        self.act_space_len = act_space_len

        # Shared feature trunk.
        self.lin_1 = torch.nn.Linear(self.obs_len, 256)
        self.rel_1 = torch.nn.ReLU()

        self.lin_2 = torch.nn.Linear(256, 128)
        self.rel_2 = torch.nn.ReLU()

        # Mean head.
        self.lin_3 = torch.nn.Linear(128, self.act_space_len)

        # Standard-deviation head.
        self.lin_4 = torch.nn.Linear(128, self.act_space_len)
        self.elu = torch.nn.ELU()

    def forward(self, x):
        x = self.lin_1(x)
        x = self.rel_1(x)

        x = self.lin_2(x)
        x = self.rel_2(x)

        mu = self.lin_3(x)

        # ELU has range (-1, inf), so the offset keeps sigma strictly positive.
        x = self.lin_4(x)
        sigma = self.elu(x) + 1.000001

        return mu, sigma

    def act(self, observation):
        mus, sigmas = self.forward(torch.tensor(observation, dtype=torch.float32))
        m = torch.distributions.normal.Normal(mus, sigmas)
        action = m.sample().detach().numpy()

        return action