import copy
import random
from typing import Union, Tuple, List, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.model import FCEncoder, ConvEncoder
from ding.reward_model.base_reward_model import BaseRewardModel
from ding.torch_utils.data_helper import to_tensor
from ding.utils import RunningMeanStd
from ding.utils import SequenceType, REWARD_MODEL_REGISTRY
from easydict import EasyDict


class RNDNetwork(nn.Module):
    """
    Overview:
        The target and predictor networks used by the RND reward model (https://arxiv.org/abs/1810.12894v1).
        The target network is randomly initialized and frozen; the predictor is trained to match its output,
        and the prediction error serves as the intrinsic (novelty) reward.
    """

    def __init__(self, obs_shape: Union[int, SequenceType], hidden_size_list: SequenceType) -> None:
        super(RNDNetwork, self).__init__()
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.target = FCEncoder(obs_shape, hidden_size_list)
            self.predictor = FCEncoder(obs_shape, hidden_size_list)
        elif len(obs_shape) == 3:
            self.target = ConvEncoder(obs_shape, hidden_size_list)
            self.predictor = ConvEncoder(obs_shape, hidden_size_list)
        else:
            raise KeyError(
                "obs_shape {} is not supported by the pre-defined encoders, "
                "please customize your own RND model".format(obs_shape)
            )
        # The target network is never trained; only the predictor receives gradients.
        for param in self.target.parameters():
            param.requires_grad = False

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        predict_feature = self.predictor(obs)
        with torch.no_grad():
            target_feature = self.target(obs)
        return predict_feature, target_feature


class RNDNetworkRepr(nn.Module):
    """
    Overview:
        The RND networks (https://arxiv.org/abs/1810.12894v1) combined with a representation network:
        the predictor operates on the latent state produced by the representation network, while the
        frozen target operates on the raw observation.
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            latent_shape: Union[int, SequenceType],
            hidden_size_list: SequenceType,
            representation_network: nn.Module
    ) -> None:
        super(RNDNetworkRepr, self).__init__()
        self.representation_network = representation_network
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.target = FCEncoder(obs_shape, hidden_size_list)
            self.predictor = FCEncoder(latent_shape, hidden_size_list)
        elif len(obs_shape) == 3:
            self.target = ConvEncoder(obs_shape, hidden_size_list)
            self.predictor = ConvEncoder(latent_shape, hidden_size_list)
        else:
            raise KeyError(
                "obs_shape {} is not supported by the pre-defined encoders, "
                "please customize your own RND model".format(obs_shape)
            )
        # The target network is never trained; only the predictor receives gradients.
        for param in self.target.parameters():
            param.requires_grad = False

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        predict_feature = self.predictor(self.representation_network(obs))
        with torch.no_grad():
            target_feature = self.target(obs)
        return predict_feature, target_feature


@REWARD_MODEL_REGISTRY.register('rnd_muzero')
class RNDRewardModel(BaseRewardModel):
    """
    Overview:
        The RND reward model class (https://arxiv.org/abs/1810.12894v1) modified for MuZero-style algorithms.
        The intrinsic reward is the prediction error of a trainable predictor network with respect to a
        fixed, randomly initialized target network.
    Interface:
        ``__init__``, ``estimate``, ``train_with_data``, ``collect_data``, ``clear_old_data``, \
        ``state_dict``, ``load_state_dict``, ``clear_data``, ``train``
    Config:
        == ==================== ===== ============= ======================================= =======================
        ID Symbol               Type  Default Value Description                             Other(Shape)
        == ==================== ===== ============= ======================================= =======================
        1  ``type``             str   rnd           | Reward model register name, refer     |
                                                    | to registry ``REWARD_MODEL_REGISTRY`` |
        2  | ``intrinsic_``     str   add           | The intrinsic reward type             | including add, new,
           | ``reward_type``                        |                                       | or assign
        3  | ``learning_rate``  float 0.001         | The step size of gradient descent     |
        4  | ``batch_size``     int   64            | Training batch size                   |
        5  | ``hidden``         list  [64, 64,      | The MLP layer shape                   |
           | ``_size_list``     (int) 128]          |                                       |
        6  | ``update_per_``    int   100           | Number of updates per collect         |
           | ``collect``                            |                                       |
        7  | ``input_norm``     bool  True          | Observation normalization             |
        8  | ``input_norm_``    int   -1            | Min clip value for obs normalization  |
           | ``clamp_min``                          |                                       |
        9  | ``input_norm_``    int   1             | Max clip value for obs normalization  |
           | ``clamp_max``                          |                                       |
        10 | ``intrinsic_``     float 0.01          | The weight of the intrinsic reward    | r = w * r_i + r_e
           | ``reward_weight``                      |                                       |
        11 | ``extrinsic_``     bool  True          | Whether to normalize extrinsic reward |
           | ``reward_norm``                        |                                       |
        12 | ``extrinsic_``     int   1             | The upper bound of the reward         |
           | ``reward_norm_max``                    | normalization                         |
        == ==================== ===== ============= ======================================= =======================
    """
    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='rnd',
        # (str) The intrinsic reward type, including add, new, or assign.
        intrinsic_reward_type='add',
        # (float) The step size of gradient descent.
        learning_rate=1e-3,
        # (int) Training batch size.
        batch_size=64,
        # (list(int)) Hidden sizes of the RND encoders: MLP layers if the observation is 1-dimensional,
        # conv layers plus a final dense layer if it is 3-dimensional.
        hidden_size_list=[64, 64, 128],
        # (int) How many update steps to run after each data collection.
        update_per_collect=100,
        # (bool) Observation normalization: transform the input to zero mean and unit variance.
        input_norm=True,
        # (int) Min clip value for observation normalization.
        input_norm_clamp_min=-1,
        # (int) Max clip value for observation normalization.
        input_norm_clamp_max=1,
        # (float) The weight of the intrinsic reward: r = intrinsic_reward_weight * r_i + r_e.
        intrinsic_reward_weight=0.01,
        # (bool) Whether to normalize the extrinsic reward to [0, extrinsic_reward_norm_max].
        extrinsic_reward_norm=True,
        # (int) The upper bound of the extrinsic reward normalization.
        extrinsic_reward_norm_max=1,
    )
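    # NOTE: besides the defaults above, ``__init__`` below also reads ``input_type``, ``obs_shape``,
    # ``latent_state_dim``, ``rnd_buffer_size`` and ``weight_decay`` from the config, so a user-side
    # config must provide them. A minimal illustrative sketch (the values here are arbitrary toy
    # values, not values used anywhere in this repository):
    #
    #   rnd_cfg = EasyDict(dict(
    #       input_type='obs',        # one of 'obs', 'latent_state', 'obs_latent_state'
    #       obs_shape=8,
    #       latent_state_dim=64,
    #       rnd_buffer_size=10000,
    #       weight_decay=1e-4,
    #       **RNDRewardModel.config,
    #   ))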
    def __init__(
            self,
            config: EasyDict,
            device: str = 'cpu',
            tb_logger: 'SummaryWriter' = None,
            representation_network: nn.Module = None,
            target_representation_network: nn.Module = None,
            use_momentum_representation_network: bool = True
    ) -> None:
        """
        Overview:
            Initialize the RND reward model. Besides the default ``config`` above, the passed ``config``
            is expected to provide ``input_type``, ``obs_shape`` and/or ``latent_state_dim``,
            ``rnd_buffer_size`` and ``weight_decay``, which are read below.
        """
        super(RNDRewardModel, self).__init__()
        self.cfg = config
        self.representation_network = representation_network
        self.target_representation_network = target_representation_network
        self.use_momentum_representation_network = use_momentum_representation_network
        # ``input_type`` controls which quantity the RND networks consume.
        self.input_type = self.cfg.input_type
        assert self.input_type in ['obs', 'latent_state', 'obs_latent_state'], self.input_type
        self.device = device
        assert self.device == "cpu" or self.device.startswith("cuda")
        self.rnd_buffer_size = config.rnd_buffer_size
        self.intrinsic_reward_type = self.cfg.intrinsic_reward_type
        if tb_logger is None:
            from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter('rnd_reward_model')
        self.tb_logger = tb_logger

        if self.input_type == 'obs':
            self.input_shape = self.cfg.obs_shape
            self.reward_model = RNDNetwork(self.input_shape, self.cfg.hidden_size_list).to(self.device)
        elif self.input_type == 'latent_state':
            self.input_shape = self.cfg.latent_state_dim
            self.reward_model = RNDNetwork(self.input_shape, self.cfg.hidden_size_list).to(self.device)
        elif self.input_type == 'obs_latent_state':
            # Note that only ``hidden_size_list[:-1]`` is used for the encoders in this branch.
            if self.use_momentum_representation_network:
                self.reward_model = RNDNetworkRepr(
                    self.cfg.obs_shape, self.cfg.latent_state_dim, self.cfg.hidden_size_list[:-1],
                    self.target_representation_network
                ).to(self.device)
            else:
                self.reward_model = RNDNetworkRepr(
                    self.cfg.obs_shape, self.cfg.latent_state_dim, self.cfg.hidden_size_list[:-1],
                    self.representation_network
                ).to(self.device)

        assert self.intrinsic_reward_type in ['add', 'new', 'assign']
        if self.input_type in ['obs', 'obs_latent_state']:
            self.train_obs = []
        if self.input_type == 'latent_state':
            self.train_latent_state = []

        # Only the predictor is optimized; the target network stays fixed.
        self._optimizer_rnd = torch.optim.Adam(
            self.reward_model.predictor.parameters(), lr=self.cfg.learning_rate, weight_decay=self.cfg.weight_decay
        )

        self._running_mean_std_rnd_reward = RunningMeanStd(epsilon=1e-4)
        self._running_mean_std_rnd_obs = RunningMeanStd(epsilon=1e-4)
        self.estimate_cnt_rnd = 0
        self.train_cnt_rnd = 0

    def _train_with_data_one_step(self) -> None:
        # Sample a mini-batch from the buffer that matches the configured input type.
        if self.input_type in ['obs', 'obs_latent_state']:
            train_data = random.sample(self.train_obs, self.cfg.batch_size)
        elif self.input_type == 'latent_state':
            train_data = random.sample(self.train_latent_state, self.cfg.batch_size)

        train_data = torch.stack(train_data).to(self.device)

        if self.cfg.input_norm:
            # Normalize the input with running statistics, then clamp to the configured range.
            self._running_mean_std_rnd_obs.update(train_data.detach().cpu().numpy())
            normalized_train_data = (
                train_data - to_tensor(self._running_mean_std_rnd_obs.mean).to(self.device)
            ) / to_tensor(self._running_mean_std_rnd_obs.std).to(self.device)
            train_data = torch.clamp(
                normalized_train_data, min=self.cfg.input_norm_clamp_min, max=self.cfg.input_norm_clamp_max
            )

        # The RND loss is the MSE between the trainable predictor and the frozen random target.
        predict_feature, target_feature = self.reward_model(train_data)
        loss = F.mse_loss(predict_feature, target_feature)

        self.tb_logger.add_scalar('rnd_reward_model/rnd_mse_loss', loss, self.train_cnt_rnd)
        self._optimizer_rnd.zero_grad()
        loss.backward()
        self._optimizer_rnd.step()

    def train_with_data(self) -> None:
        # Run ``update_per_collect`` predictor updates on the currently buffered data.
        for _ in range(self.cfg.update_per_collect):
            self._train_with_data_one_step()
            self.train_cnt_rnd += 1

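    # Typical call order in a MuZero-style training loop (illustrative sketch; the exact call sites
    # depend on the entry point that owns this reward model):
    #   reward_model.collect_data(new_data)     # push newly collected obs / latent states into the buffer
    #   reward_model.train_with_data()          # fit the RND predictor for ``update_per_collect`` steps
    #   batch = reward_model.estimate(batch)    # augment the extrinsic reward in a training batch
    #   reward_model.clear_old_data()           # keep only the most recent ``rnd_buffer_size`` samples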
    def estimate(self, data: list) -> List[Dict]:
        """
        Rewrite the reward in the given training batch by adding (or replacing it with) the RND
        intrinsic reward, according to ``intrinsic_reward_type``.
        """
        obs_batch_orig = data[0][0]
        target_reward = data[1][0]
        batch_size = obs_batch_orig.shape[0]

        # Reshape the batched, unrolled observations into a flat batch of single observations.
        obs_batch_tmp = np.reshape(obs_batch_orig, (batch_size, self.cfg.obs_shape, 6))
        obs_batch_tmp = np.reshape(obs_batch_tmp, (batch_size * 6, self.cfg.obs_shape))

        if self.input_type == 'latent_state':
            with torch.no_grad():
                latent_state = self.representation_network(torch.from_numpy(obs_batch_tmp).to(self.device))
            input_data = latent_state
        elif self.input_type in ['obs', 'obs_latent_state']:
            input_data = to_tensor(obs_batch_tmp).to(self.device)

        # Deepcopy the reward so that the original data in the replay buffer is not modified in place.
        target_reward_augmented = copy.deepcopy(target_reward)
        target_reward_augmented = np.reshape(target_reward_augmented, (batch_size * 6, 1))

        if self.cfg.input_norm:
            # Normalize the input with the running statistics collected during training, then clamp.
            input_data = input_data.clone()
            input_data = (
                input_data - to_tensor(self._running_mean_std_rnd_obs.mean).to(self.device)
            ) / to_tensor(self._running_mean_std_rnd_obs.std).to(self.device)
            input_data = torch.clamp(
                input_data, min=self.cfg.input_norm_clamp_min, max=self.cfg.input_norm_clamp_max
            )

        with torch.no_grad():
            predict_feature, target_feature = self.reward_model(input_data)
            mse = F.mse_loss(predict_feature, target_feature, reduction='none').mean(dim=1)
            self._running_mean_std_rnd_reward.update(mse.detach().cpu().numpy())

            # Min-max normalize the per-sample prediction error to [0, 1] to obtain the intrinsic reward.
            rnd_reward = (mse - mse.min()) / (mse.max() - mse.min() + 1e-6)

        # Log intrinsic reward statistics.
        self.estimate_cnt_rnd += 1
        self.tb_logger.add_scalar('rnd_reward_model/rnd_reward_max', rnd_reward.max(), self.estimate_cnt_rnd)
        self.tb_logger.add_scalar('rnd_reward_model/rnd_reward_mean', rnd_reward.mean(), self.estimate_cnt_rnd)
        self.tb_logger.add_scalar('rnd_reward_model/rnd_reward_min', rnd_reward.min(), self.estimate_cnt_rnd)
        self.tb_logger.add_scalar('rnd_reward_model/rnd_reward_std', rnd_reward.std(), self.estimate_cnt_rnd)

        rnd_reward = rnd_reward.unsqueeze(1).cpu().numpy()
        # Fuse the intrinsic reward with the extrinsic reward according to ``intrinsic_reward_type``.
        if self.intrinsic_reward_type == 'add':
            if self.cfg.extrinsic_reward_norm:
                target_reward_augmented = (
                    target_reward_augmented / self.cfg.extrinsic_reward_norm_max
                    + rnd_reward * self.cfg.intrinsic_reward_weight
                )
            else:
                target_reward_augmented = target_reward_augmented + rnd_reward * self.cfg.intrinsic_reward_weight
        elif self.intrinsic_reward_type == 'new':
            if self.cfg.extrinsic_reward_norm:
                target_reward_augmented = target_reward_augmented / self.cfg.extrinsic_reward_norm_max
        elif self.intrinsic_reward_type == 'assign':
            target_reward_augmented = rnd_reward

        self.tb_logger.add_scalar('augmented_reward/reward_max', np.max(target_reward_augmented), self.estimate_cnt_rnd)
        self.tb_logger.add_scalar(
            'augmented_reward/reward_mean', np.mean(target_reward_augmented), self.estimate_cnt_rnd
        )
        self.tb_logger.add_scalar('augmented_reward/reward_min', np.min(target_reward_augmented), self.estimate_cnt_rnd)
        self.tb_logger.add_scalar('augmented_reward/reward_std', np.std(target_reward_augmented), self.estimate_cnt_rnd)

        # Restore the unrolled shape and write the augmented reward back into the batch.
        target_reward_augmented = np.reshape(target_reward_augmented, (batch_size, 6, 1))
        data[1][0] = target_reward_augmented
        train_data_augmented = data

        return train_data_augmented

    def collect_data(self, data: list) -> None:
        # Collect observations from the game segments; only the first 300 steps of each segment are used.
        collected_transitions = np.concatenate([game_segment.obs_segment[:300] for game_segment in data[0]], axis=0)
        if self.input_type == 'latent_state':
            with torch.no_grad():
                self.train_latent_state.extend(
                    self.representation_network(torch.from_numpy(collected_transitions).to(self.device))
                )
        elif self.input_type in ['obs', 'obs_latent_state']:
            self.train_obs.extend(to_tensor(collected_transitions).to(self.device))

    def clear_old_data(self) -> None:
        # Keep only the most recent ``rnd_buffer_size`` samples in the buffer.
        if self.input_type == 'latent_state':
            if len(self.train_latent_state) >= self.cfg.rnd_buffer_size:
                self.train_latent_state = self.train_latent_state[-self.cfg.rnd_buffer_size:]
        elif self.input_type in ['obs', 'obs_latent_state']:
            if len(self.train_obs) >= self.cfg.rnd_buffer_size:
                self.train_obs = self.train_obs[-self.cfg.rnd_buffer_size:]

    def state_dict(self) -> Dict:
        return self.reward_model.state_dict()

    def load_state_dict(self, _state_dict: Dict) -> None:
        self.reward_model.load_state_dict(_state_dict)

    def clear_data(self):
        # ``clear_data`` and ``train`` are part of the ``BaseRewardModel`` interface; this model uses
        # ``clear_old_data`` and ``train_with_data`` instead, so they are intentionally no-ops.
        pass

    def train(self):
        pass
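

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only): run a batch of random observations through a plain
    # RNDNetwork and turn the per-sample predictor-target MSE into an intrinsic reward, which is the
    # core computation performed by ``RNDRewardModel.estimate``. The toy dimensions below (obs_dim=8,
    # batch of 4) are arbitrary and not used anywhere else in this repository.
    toy_obs_dim, toy_batch = 8, 4
    rnd_net = RNDNetwork(obs_shape=toy_obs_dim, hidden_size_list=[64, 64, 128])
    toy_obs = torch.randn(toy_batch, toy_obs_dim)
    predict_feature, target_feature = rnd_net(toy_obs)
    intrinsic_reward = F.mse_loss(predict_feature, target_feature, reduction='none').mean(dim=1)
    print(intrinsic_reward.shape)  # expected: torch.Size([4])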