|
import torch |
|
|
|
class ParameterisedPolicy(torch.nn.Module):
    """Gaussian policy network for a REINFORCE agent.

    The network maps an observation (gym state) vector to the mean ``mu`` and
    standard deviation ``sigma`` of an independent Normal distribution per
    action dimension. ``act(observation)`` samples an action from it.

    Args:
        obs_len: Length of the observation (state) vector.
        act_space_len: Length of the action vector.
        hidden_1: Width of the first hidden layer.
        hidden_2: Width of the second hidden layer.
    """

    # Added to ELU(x), whose output is always > -1 (default alpha=1), so the
    # resulting sigma is strictly greater than 1e-6 and the Normal is valid.
    _SIGMA_OFFSET = 1.000001

    def __init__(self, obs_len=8, act_space_len=2, hidden_1=256, hidden_2=128):
        super().__init__()
        self.obs_len = obs_len
        self.act_space_len = act_space_len

        # Shared trunk. Attribute names are kept stable so saved state_dicts
        # remain loadable.
        self.lin_1 = torch.nn.Linear(self.obs_len, hidden_1)
        self.rel_1 = torch.nn.ReLU()
        self.lin_2 = torch.nn.Linear(hidden_1, hidden_2)
        self.rel_2 = torch.nn.ReLU()

        # Two heads: lin_3 produces the mean, lin_4 the pre-activation of the
        # standard deviation (made positive by ELU + offset in forward()).
        self.lin_3 = torch.nn.Linear(hidden_2, self.act_space_len)
        self.lin_4 = torch.nn.Linear(hidden_2, self.act_space_len)
        self.elu = torch.nn.ELU()

    def forward(self, x):
        """Compute the Normal distribution parameters for ``x``.

        Args:
            x: Float tensor whose last dimension is ``obs_len``; any leading
                (batch) dimensions are carried through by the Linear layers.

        Returns:
            Tuple ``(mu, sigma)``, each with last dimension ``act_space_len``;
            ``sigma`` is strictly positive.
        """
        hidden = self.rel_1(self.lin_1(x))
        hidden = self.rel_2(self.lin_2(hidden))

        mu = self.lin_3(hidden)
        sigma = self.elu(self.lin_4(hidden)) + self._SIGMA_OFFSET
        return mu, sigma

    def act(self, observation):
        """Sample an action for a single observation.

        Args:
            observation: Observation vector of length ``obs_len`` (list,
                NumPy array or tensor).

        Returns:
            NumPy ``float32`` array with one sampled value per action dimension.
        """
        # Put the input on the same device as the weights so the policy keeps
        # working after ``.to(device)``. ``as_tensor`` also accepts tensors
        # without the copy-construct warning that ``torch.tensor`` emits.
        device = next(self.parameters()).device
        obs = torch.as_tensor(observation, dtype=torch.float32, device=device)

        # Acting needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            mus, sigmas = self.forward(obs)
            action = torch.distributions.normal.Normal(mus, sigmas).sample()

        # .cpu() is required before .numpy() when the model lives on a GPU.
        return action.cpu().numpy()