|
from typing import Union, Tuple, List, Dict |
|
from easydict import EasyDict |
|
|
|
import random |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
|
|
from ding.utils import SequenceType, REWARD_MODEL_REGISTRY |
|
from ding.model import FCEncoder, ConvEncoder |
|
from ding.torch_utils import one_hot |
|
from .base_reward_model import BaseRewardModel |
|
|
|
|
|
def collect_states(iterator: list) -> Tuple[list, list, list]: |
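    """
    Overview:
        Unpack a list of transition dicts into separate lists of observations, next observations and actions.
        Each item is assumed to carry the keys ``obs``, ``next_obs`` and ``action`` (the usual transition
        format produced by the collector).
    Arguments:
        - iterator (:obj:`list`): List of transition dicts.
    Returns:
        - states, next_states, actions (:obj:`Tuple[list, list, list]`): The unpacked lists, in that order.
    Examples (illustrative):
        >>> data = [{'obs': torch.zeros(4), 'next_obs': torch.ones(4), 'action': torch.tensor([1])}]
        >>> states, next_states, actions = collect_states(data)
    """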
|
states = [] |
|
next_states = [] |
|
actions = [] |
|
for item in iterator: |
|
state = item['obs'] |
|
next_state = item['next_obs'] |
|
action = item['action'] |
|
states.append(state) |
|
next_states.append(next_state) |
|
actions.append(action) |
|
return states, next_states, actions |
|
|
|
|
|
class ICMNetwork(nn.Module): |
|
""" |
|
Intrinsic Curiosity Model (ICM Module) |
|
Implementation of: |
|
[1] Curiosity-driven Exploration by Self-supervised Prediction |
|
Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017. |
|
https://arxiv.org/pdf/1705.05363.pdf |
|
[2] Code implementation reference: |
|
https://github.com/pathak22/noreward-rl |
|
https://github.com/jcwleo/curiosity-driven-exploration-pytorch |
|
|
|
1) Embedding observations into a latent space |
|
2) Predicting the action logit given two consecutive embedded observations |
|
    3) Predicting the next embedded observation, given the embedded former observation and action
|
""" |
|
|
|
def __init__(self, obs_shape: Union[int, SequenceType], hidden_size_list: SequenceType, action_shape: int) -> None: |
|
super(ICMNetwork, self).__init__() |
|
if isinstance(obs_shape, int) or len(obs_shape) == 1: |
|
self.feature = FCEncoder(obs_shape, hidden_size_list) |
|
elif len(obs_shape) == 3: |
|
self.feature = ConvEncoder(obs_shape, hidden_size_list) |
|
else: |
|
raise KeyError( |
|
"not support obs_shape for pre-defined encoder: {}, please customize your own ICM model". |
|
format(obs_shape) |
|
) |
|
self.action_shape = action_shape |
|
feature_output = hidden_size_list[-1] |
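        # Inverse model: maps the concatenated embeddings of (s_t, s_t+1) to an action logit.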
|
self.inverse_net = nn.Sequential(nn.Linear(feature_output * 2, 512), nn.ReLU(), nn.Linear(512, action_shape)) |
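        # Forward model residual blocks: consumed in pairs inside ``forward`` (4 residual units in total)
        # to refine the predicted next-state embedding conditioned on the action.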
|
self.residual = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.Linear(action_shape + 512, 512), |
|
nn.LeakyReLU(), |
|
nn.Linear(512, 512), |
|
) for _ in range(8) |
|
] |
|
) |
|
self.forward_net_1 = nn.Sequential(nn.Linear(action_shape + feature_output, 512), nn.LeakyReLU()) |
|
self.forward_net_2 = nn.Linear(action_shape + 512, feature_output) |
|
|
|
def forward(self, state: torch.Tensor, next_state: torch.Tensor, |
|
action_long: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
r""" |
|
Overview: |
|
            Use the observation, next observation and action to compute the ICM module outputs,
            i.e. the embedded features and the action logit from which the forward and inverse losses are built.
|
Arguments: |
|
- state (:obj:`torch.Tensor`): |
|
The current state batch |
|
- next_state (:obj:`torch.Tensor`): |
|
The next state batch |
|
- action_long (:obj:`torch.Tensor`): |
|
The action batch |
|
Returns: |
|
- real_next_state_feature (:obj:`torch.Tensor`): |
|
Run with the encoder. Return the real next_state's embedded feature. |
|
- pred_next_state_feature (:obj:`torch.Tensor`): |
|
Run with the encoder and residual network. Return the predicted next_state's embedded feature. |
|
- pred_action_logit (:obj:`torch.Tensor`): |
|
                Run with the encoder and inverse network. Return the predicted action logit.
|
Shapes: |
|
            - state (:obj:`torch.Tensor`): :math:`(B, N)`, where B is the batch size and N is ``obs_shape``
|
            - next_state (:obj:`torch.Tensor`): :math:`(B, N)`, where B is the batch size and N is ``obs_shape``
|
            - action_long (:obj:`torch.Tensor`): :math:`(B)`, where B is the batch size
|
- real_next_state_feature (:obj:`torch.Tensor`): :math:`(B, M)`, where B is the batch size |
|
and M is embedded feature size |
|
- pred_next_state_feature (:obj:`torch.Tensor`): :math:`(B, M)`, where B is the batch size |
|
and M is embedded feature size |
|
- pred_action_logit (:obj:`torch.Tensor`): :math:`(B, A)`, where B is the batch size |
|
              and A is the ``action_shape``
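        Examples:
            >>> # Illustrative sketch only; the sizes below are arbitrary.
            >>> model = ICMNetwork(obs_shape=4, hidden_size_list=[32, 64], action_shape=5)
            >>> state = torch.randn(8, 4)
            >>> next_state = torch.randn(8, 4)
            >>> action = torch.randint(0, 5, (8, ))
            >>> real_feat, pred_feat, action_logit = model(state, next_state, action)
            >>> assert real_feat.shape == pred_feat.shape == (8, 64)
            >>> assert action_logit.shape == (8, 5)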
|
""" |
|
action = one_hot(action_long, num=self.action_shape) |
|
encode_state = self.feature(state) |
|
encode_next_state = self.feature(next_state) |
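        # Inverse model: predict the action logit from the two consecutive state embeddings.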
|
|
|
concat_state = torch.cat((encode_state, encode_next_state), 1) |
|
pred_action_logit = self.inverse_net(concat_state) |
|
|
|
|
|
|
|
pred_next_state_feature_orig = torch.cat((encode_state, action), 1) |
|
pred_next_state_feature_orig = self.forward_net_1(pred_next_state_feature_orig) |
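        # Residual refinement: each of the 4 units applies a pair of the 8 residual blocks,
        # re-injecting the action at every step.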
|
|
|
|
|
for i in range(4): |
|
pred_next_state_feature = self.residual[i * 2](torch.cat((pred_next_state_feature_orig, action), 1)) |
|
pred_next_state_feature_orig = self.residual[i * 2 + 1]( |
|
torch.cat((pred_next_state_feature, action), 1) |
|
) + pred_next_state_feature_orig |
|
pred_next_state_feature = self.forward_net_2(torch.cat((pred_next_state_feature_orig, action), 1)) |
|
real_next_state_feature = encode_next_state |
|
return real_next_state_feature, pred_next_state_feature, pred_action_logit |
|
|
|
|
|
@REWARD_MODEL_REGISTRY.register('icm') |
|
class ICMRewardModel(BaseRewardModel): |
|
""" |
|
Overview: |
|
The ICM reward model class (https://arxiv.org/pdf/1705.05363.pdf) |
|
Interface: |
|
``estimate``, ``train``, ``collect_data``, ``clear_data``, \ |
|
``__init__``, ``_train``, ``load_state_dict``, ``state_dict`` |
|
Config: |
|
        == ====================  ========   =============  ====================================  =======================
        ID Symbol                Type       Default Value  Description                           Other(Shape)
        == ====================  ========   =============  ====================================  =======================
        1  ``type``              str        icm            | Reward model register name,         |
                                                           | refer to registry                   |
                                                           | ``REWARD_MODEL_REGISTRY``           |
        2  | ``intrinsic_``      str        add            | the intrinsic reward type           | including add, new
           | ``reward_type``                               |                                     | , or assign
        3  | ``learning_rate``   float      0.001          | The step size of gradient descent   |
        4  | ``obs_shape``       Tuple(     6              | the observation shape               |
                                 [int,
                                 list])
        5  | ``action_shape``    int        7              | the action space shape              |
        6  | ``batch_size``      int        64             | Training batch size                 |
        7  | ``hidden``          list       [64, 64,       | the MLP layer shape                 |
           | ``_size_list``      (int)      128]           |                                     |
        8  | ``update_per_``     int        100            | Number of updates per collect       |
           | ``collect``                                   |                                     |
        9  | ``reverse_scale``   float      1              | the weight of the inverse loss      |
                                                           | against the forward loss            |
        10 | ``intrinsic_``      float      0.003          | the weight of intrinsic reward      | r = w*r_i + r_e
           | ``reward_weight``
        11 | ``extrinsic_``      bool       True           | Whether to normalize                |
           | ``reward_norm``                               | extrinsic reward                    |
        12 | ``extrinsic_``      int        1              | the upper bound of the reward       |
           | ``reward_norm_max``                           | normalization                       |
        13 | ``clear_buffer``    int        100            | clear buffer per fixed iters        | make sure replay
           | ``_per_iters``                                |                                     | buffer's data count
                                                           |                                     | isn't too few.
                                                           |                                     | (code work in entry)
        == ====================  ========   =============  ====================================  =======================
|
""" |
|
config = dict( |
|
|
|
type='icm', |
|
|
|
intrinsic_reward_type='add', |
|
|
|
learning_rate=1e-3, |
|
|
|
obs_shape=6, |
|
|
|
action_shape=7, |
|
|
|
batch_size=64, |
|
|
|
hidden_size_list=[64, 64, 128], |
|
|
|
|
|
|
|
update_per_collect=100, |
|
|
|
reverse_scale=1, |
|
|
|
|
|
intrinsic_reward_weight=0.003, |
|
|
|
|
|
extrinsic_reward_norm=True, |
|
|
|
extrinsic_reward_norm_max=1, |
|
|
|
clear_buffer_per_iters=100, |
|
) |
|
|
|
def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None: |
|
super(ICMRewardModel, self).__init__() |
|
self.cfg = config |
|
assert device == "cpu" or device.startswith("cuda") |
|
self.device = device |
|
self.tb_logger = tb_logger |
|
self.reward_model = ICMNetwork(config.obs_shape, config.hidden_size_list, config.action_shape) |
|
self.reward_model.to(self.device) |
|
self.intrinsic_reward_type = config.intrinsic_reward_type |
|
assert self.intrinsic_reward_type in ['add', 'new', 'assign'] |
|
self.train_data = [] |
|
self.train_states = [] |
|
self.train_next_states = [] |
|
self.train_actions = [] |
|
self.opt = optim.Adam(self.reward_model.parameters(), config.learning_rate) |
|
self.ce = nn.CrossEntropyLoss(reduction="mean") |
|
self.forward_mse = nn.MSELoss(reduction='none') |
|
self.reverse_scale = config.reverse_scale |
|
self.res = nn.Softmax(dim=-1) |
|
self.estimate_cnt_icm = 0 |
|
self.train_cnt_icm = 0 |
|
|
|
def _train(self) -> None: |
|
self.train_cnt_icm += 1 |
|
        train_data_list = list(range(len(self.train_states)))
|
train_data_index = random.sample(train_data_list, self.cfg.batch_size) |
|
data_states: list = [self.train_states[i] for i in train_data_index] |
|
data_states: torch.Tensor = torch.stack(data_states).to(self.device) |
|
data_next_states: list = [self.train_next_states[i] for i in train_data_index] |
|
data_next_states: torch.Tensor = torch.stack(data_next_states).to(self.device) |
|
data_actions: list = [self.train_actions[i] for i in train_data_index] |
|
data_actions: torch.Tensor = torch.cat(data_actions).to(self.device) |
|
|
|
real_next_state_feature, pred_next_state_feature, pred_action_logit = self.reward_model( |
|
data_states, data_next_states, data_actions |
|
) |
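        # Inverse loss: cross-entropy between the predicted action logit and the taken action;
        # forward loss: MSE between the predicted and the real (detached) next-state embedding.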
|
inverse_loss = self.ce(pred_action_logit, data_actions.long()) |
|
forward_loss = self.forward_mse(pred_next_state_feature, real_next_state_feature.detach()).mean() |
|
self.tb_logger.add_scalar('icm_reward/forward_loss', forward_loss, self.train_cnt_icm) |
|
self.tb_logger.add_scalar('icm_reward/inverse_loss', inverse_loss, self.train_cnt_icm) |
|
action = torch.argmax(self.res(pred_action_logit), -1) |
|
accuracy = torch.sum(action == data_actions.squeeze(-1)).item() / data_actions.shape[0] |
|
self.tb_logger.add_scalar('icm_reward/action_accuracy', accuracy, self.train_cnt_icm) |
|
        loss = self.reverse_scale * inverse_loss + forward_loss

        self.tb_logger.add_scalar('icm_reward/total_loss', loss, self.train_cnt_icm)
|
self.opt.zero_grad() |
|
loss.backward() |
|
self.opt.step() |
|
|
|
def train(self) -> None: |
|
for _ in range(self.cfg.update_per_collect): |
|
self._train() |
|
|
|
def estimate(self, data: list) -> List[Dict]: |
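        """
        Overview:
            Compute the ICM intrinsic reward for a batch of transitions and combine it with the extrinsic
            reward according to ``intrinsic_reward_type`` (``add``, ``new`` or ``assign``). The raw curiosity
            signal is the forward-model prediction error, min-max normalized to ``[0, 1]``. The reward entries
            of the input data are deep-copied first, so the rewards stored in the replay buffer are not
            modified in place.
        Arguments:
            - data (:obj:`list`): List of transition dicts with ``obs``, ``next_obs``, ``action`` and ``reward``.
        Returns:
            - train_data_augmented (:obj:`List[Dict]`): The copied transitions with rewards rewritten.
        """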
|
|
|
|
|
train_data_augmented = self.reward_deepcopy(data) |
|
states, next_states, actions = collect_states(train_data_augmented) |
|
states = torch.stack(states).to(self.device) |
|
next_states = torch.stack(next_states).to(self.device) |
|
actions = torch.cat(actions).to(self.device) |
|
with torch.no_grad(): |
|
real_next_state_feature, pred_next_state_feature, _ = self.reward_model(states, next_states, actions) |
|
raw_icm_reward = self.forward_mse(real_next_state_feature, pred_next_state_feature).mean(dim=1) |
|
self.estimate_cnt_icm += 1 |
|
self.tb_logger.add_scalar('icm_reward/raw_icm_reward_max', raw_icm_reward.max(), self.estimate_cnt_icm) |
|
self.tb_logger.add_scalar('icm_reward/raw_icm_reward_mean', raw_icm_reward.mean(), self.estimate_cnt_icm) |
|
self.tb_logger.add_scalar('icm_reward/raw_icm_reward_min', raw_icm_reward.min(), self.estimate_cnt_icm) |
|
self.tb_logger.add_scalar('icm_reward/raw_icm_reward_std', raw_icm_reward.std(), self.estimate_cnt_icm) |
|
icm_reward = (raw_icm_reward - raw_icm_reward.min()) / (raw_icm_reward.max() - raw_icm_reward.min() + 1e-8) |
|
self.tb_logger.add_scalar('icm_reward/icm_reward_max', icm_reward.max(), self.estimate_cnt_icm) |
|
self.tb_logger.add_scalar('icm_reward/icm_reward_mean', icm_reward.mean(), self.estimate_cnt_icm) |
|
self.tb_logger.add_scalar('icm_reward/icm_reward_min', icm_reward.min(), self.estimate_cnt_icm) |
|
self.tb_logger.add_scalar('icm_reward/icm_reward_std', icm_reward.std(), self.estimate_cnt_icm) |
|
|
icm_reward = icm_reward.to(self.device) |
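        # Combine the intrinsic and extrinsic rewards according to ``intrinsic_reward_type``:
        #   add    -> reward = (normalized) extrinsic reward + intrinsic_reward_weight * icm reward
        #   new    -> keep the extrinsic reward, store the icm reward under a separate 'intrinsic_reward' key
        #   assign -> replace the extrinsic reward with the icm reward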
|
for item, icm_rew in zip(train_data_augmented, icm_reward): |
|
if self.intrinsic_reward_type == 'add': |
|
if self.cfg.extrinsic_reward_norm: |
|
item['reward'] = item[ |
|
'reward'] / self.cfg.extrinsic_reward_norm_max + icm_rew * self.cfg.intrinsic_reward_weight |
|
else: |
|
item['reward'] = item['reward'] + icm_rew * self.cfg.intrinsic_reward_weight |
|
elif self.intrinsic_reward_type == 'new': |
|
item['intrinsic_reward'] = icm_rew |
|
if self.cfg.extrinsic_reward_norm: |
|
item['reward'] = item['reward'] / self.cfg.extrinsic_reward_norm_max |
|
elif self.intrinsic_reward_type == 'assign': |
|
item['reward'] = icm_rew |
|
|
|
return train_data_augmented |
|
|
|
def collect_data(self, data: list) -> None: |
|
self.train_data.extend(collect_states(data)) |
|
states, next_states, actions = collect_states(data) |
|
self.train_states.extend(states) |
|
self.train_next_states.extend(next_states) |
|
self.train_actions.extend(actions) |
|
|
|
def clear_data(self) -> None: |
|
self.train_data.clear() |
|
self.train_states.clear() |
|
self.train_next_states.clear() |
|
self.train_actions.clear() |
|
|
|
def state_dict(self) -> Dict: |
|
return self.reward_model.state_dict() |
|
|
|
def load_state_dict(self, _state_dict: Dict) -> None: |
|
self.reward_model.load_state_dict(_state_dict) |
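

# Usage sketch (illustrative only, not executed): how this reward model is typically wired up.
# The transition format below is an assumption based on the default config (obs_shape=6, action_shape=7).
#
#   cfg = EasyDict(ICMRewardModel.config)
#   icm = ICMRewardModel(cfg, 'cpu', tb_logger)   # tb_logger: a tensorboard SummaryWriter
#   icm.collect_data(transitions)                 # transitions: list of dicts with 'obs',
#                                                 # 'next_obs', 'action' and 'reward' tensors
#   icm.train()                                   # runs cfg.update_per_collect gradient steps
#   transitions = icm.estimate(transitions)       # rewards rewritten in a reward-deep-copied list
#   icm.clear_data()                              # call periodically (cfg.clear_buffer_per_iters)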
|
|