import math

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch import distributions as torchd

from ding.torch_utils import MLP
from ding.rl_utils import symlog, inv_symlog


def _resolve_nn_class(spec):
    """
    Overview:
        Return the ``torch.nn`` class named by ``spec``; a non-string ``spec`` is returned unchanged.
    Arguments:
        - spec (:obj:`str` or class): Name of a ``torch.nn`` class, or the class itself.
    """
    return getattr(torch.nn, spec) if isinstance(spec, str) else spec


class Conv2dSame(torch.nn.Conv2d):
    """
    Overview:
        Conv2dSame Network for dreamerv3. Pads the input before convolving so that the output \
        spatial size is ``ceil(input / stride)``.
    Interfaces:
        ``__init__``, ``forward``
    """

    def calc_same_pad(self, i, k, s, d):
        """
        Overview:
            Calculate the total "same" padding needed along one spatial axis.
        Arguments:
            - i (:obj:`int`): Input size.
            - k (:obj:`int`): Kernel size.
            - s (:obj:`int`): Stride size.
            - d (:obj:`int`): Dilation size.
        Returns:
            - pad (:obj:`int`): Non-negative total padding for this axis.
        """
        return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

    def forward(self, x):
        """
        Overview:
            Compute the forward of Conv2dSame.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor whose last two dims are (height, width).
        Returns:
            - ret (:obj:`torch.Tensor`): Convolution output.
        """
        ih, iw = x.size()[-2:]
        pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
        pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])

        if pad_h > 0 or pad_w > 0:
            # F.pad takes (left, right, top, bottom); any odd remainder goes to right/bottom.
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])

        # ``self.padding`` (the constructor argument) is still applied on top of the manual padding.
        return F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


class DreamerLayerNorm(nn.Module):
    """
    Overview:
        DreamerLayerNorm Network for dreamerv3. Applies LayerNorm over the channel dim of a \
        channel-first 4D tensor.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, ch, eps=1e-03):
        """
        Overview:
            Init the DreamerLayerNorm class.
        Arguments:
            - ch (:obj:`int`): Input channel.
            - eps (:obj:`float`): Epsilon.
        """
        super(DreamerLayerNorm, self).__init__()
        self.norm = torch.nn.LayerNorm(ch, eps=eps)

    def forward(self, x):
        """
        Overview:
            Compute the forward of DreamerLayerNorm.
        Arguments:
            - x (:obj:`torch.Tensor`): 4D input tensor with channels at dim 1.
        Returns:
            - x (:obj:`torch.Tensor`): Normalized tensor in the input layout.
        """
        x = x.permute(0, 2, 3, 1)  # move channels last so LayerNorm normalizes over them
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)  # restore channel-first layout
        return x


class DenseHead(nn.Module):
    """
    Overview:
        DenseHead Network for value head, reward head, and discount head of dreamerv3.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            inp_dim,
            shape,  # (255,)
            layer_num,
            units,  # 512
            act='SiLU',
            norm='LN',
            dist='normal',
            std=1.0,
            outscale=1.0,
            device='cpu',
    ):
        """
        Overview:
            Init the DenseHead class.
        Arguments:
            - inp_dim (:obj:`int`): Input dimension.
            - shape (:obj:`tuple`): Output shape. An int is wrapped into a 1-tuple, ``()`` becomes ``(1, )``.
            - layer_num (:obj:`int`): Number of layers.
            - units (:obj:`int`): Number of units.
            - act (:obj:`str`): Name of the ``torch.nn`` activation class.
            - norm (:obj:`str`): Normalization type passed to ``MLP`` as ``norm_type``.
            - dist (:obj:`str`): Distribution, one of 'normal', 'huber', 'binary', 'twohot_symlog'.
            - std (:obj:`float`): Standard deviation, or the string "learned" to predict it.
            - outscale (:obj:`float`): Output scale.
            - device (:obj:`str`): Device.
        """
        super(DenseHead, self).__init__()
        self._shape = (shape, ) if isinstance(shape, int) else shape
        if len(self._shape) == 0:
            self._shape = (1, )
        self._layer_num = layer_num
        self._units = units
        self._act = getattr(torch.nn, act)()
        self._norm = norm
        self._dist = dist
        self._std = std
        self._device = device

        self.mlp = MLP(
            inp_dim,
            self._units,
            self._units,
            self._layer_num,
            layer_fn=nn.Linear,
            activation=self._act,
            norm_type=self._norm
        )
        self.mlp.apply(weight_init)

        # np.prod returns a numpy scalar; hand nn.Linear a plain Python int.
        out_features = int(np.prod(self._shape))
        self.mean_layer = nn.Linear(self._units, out_features)
        self.mean_layer.apply(uniform_weight_init(outscale))

        if self._std == "learned":
            self.std_layer = nn.Linear(self._units, out_features)
            self.std_layer.apply(uniform_weight_init(outscale))

    def forward(self, features):
        """
        Overview:
            Compute the forward of DenseHead.
        Arguments:
            - features (:obj:`torch.Tensor`): Input tensor.
        Returns:
            - dist: Distribution wrapper selected by ``dist``.
        """
        out = self.mlp(features)  # presumably (batch, time, units) -- confirm against caller
        mean = self.mean_layer(out)
        if self._std == "learned":
            # NOTE(review): the raw linear output is used as the scale and may be non-positive.
            std = self.std_layer(out)
        else:
            std = self._std

        event_dims = len(self._shape)
        if self._dist == "normal":
            return ContDist(torchd.independent.Independent(torchd.normal.Normal(mean, std), event_dims))
        elif self._dist == "huber":
            return ContDist(torchd.independent.Independent(UnnormalizedHuber(mean, std, 1.0), event_dims))
        elif self._dist == "binary":
            return Bernoulli(torchd.independent.Independent(torchd.bernoulli.Bernoulli(logits=mean), event_dims))
        elif self._dist == "twohot_symlog":
            return TwoHotDistSymlog(logits=mean, device=self._device)
        raise NotImplementedError(self._dist)


class ActionHead(nn.Module):
    """
    Overview:
        ActionHead Network for action head of dreamerv3.
    Interfaces:
        ``__init__``, ``forward``
    """

    # Distributions whose final layer predicts both a mean and a std per action dim.
    _TWO_PARAM_DISTS = ("tanh_normal", "tanh_normal_5", "normal", "trunc_normal")
    # Distributions whose final layer predicts one value per action dim.
    # "onehot_gumble" is the historical misspelling; both spellings are accepted.
    _ONE_PARAM_DISTS = ("normal_1", "onehot", "onehot_gumbel", "onehot_gumble")

    def __init__(
            self,
            inp_dim,
            size,
            layers,
            units,
            act=nn.ELU,
            norm=nn.LayerNorm,
            dist="trunc_normal",
            init_std=0.0,
            min_std=0.1,
            max_std=1.0,
            temp=0.1,
            outscale=1.0,
            unimix_ratio=0.01,
    ):
        """
        Overview:
            Initialize the ActionHead class.
        Arguments:
            - inp_dim (:obj:`int`): Input dimension.
            - size (:obj:`int`): Output size.
            - layers (:obj:`int`): Number of layers.
            - units (:obj:`int`): Number of units.
            - act (:obj:`str` or class): Activation, as a ``torch.nn`` class or its name.
            - norm (:obj:`str` or class): Normalization, as a ``torch.nn`` class or its name.
            - dist (:obj:`str`): Distribution function.
            - init_std (:obj:`float`): Initial standard deviation.
            - min_std (:obj:`float`): Minimum standard deviation.
            - max_std (:obj:`float`): Maximum standard deviation.
            - temp (:obj:`float`): Temperature, or a callable returning it.
            - outscale (:obj:`float`): Output scale.
            - unimix_ratio (:obj:`float`): Unimix ratio.
        """
        super(ActionHead, self).__init__()
        self._size = size
        self._layers = layers
        self._units = units
        self._dist = dist
        # The defaults are classes while callers may pass names; getattr() on a class would raise.
        self._act = _resolve_nn_class(act)
        self._norm = _resolve_nn_class(norm)
        self._min_std = min_std
        self._max_std = max_std
        self._init_std = init_std
        self._unimix_ratio = unimix_ratio
        self._temp = temp() if callable(temp) else temp

        pre_layers = []
        for index in range(self._layers):
            pre_layers.append(nn.Linear(inp_dim, self._units, bias=False))
            pre_layers.append(self._norm(self._units, eps=1e-03))
            pre_layers.append(self._act())
            if index == 0:
                # Every layer after the first consumes the hidden width.
                inp_dim = self._units
        self._pre_layers = nn.Sequential(*pre_layers)
        self._pre_layers.apply(weight_init)

        if self._dist in self._TWO_PARAM_DISTS:
            self._dist_layer = nn.Linear(self._units, 2 * self._size)
            self._dist_layer.apply(uniform_weight_init(outscale))
        elif self._dist in self._ONE_PARAM_DISTS:
            self._dist_layer = nn.Linear(self._units, self._size)
            self._dist_layer.apply(uniform_weight_init(outscale))

    def _split_mean_std(self, x):
        """
        Overview:
            Split the final-layer output into two halves of width ``size`` along the last dim.
        """
        return torch.split(x, [self._size] * 2, -1)

    def forward(self, features):
        """
        Overview:
            Compute the forward of ActionHead.
        Arguments:
            - features (:obj:`torch.Tensor`): Input tensor.
        Returns:
            - dist: Action distribution selected by ``dist``.
        """
        x = self._pre_layers(features)
        if self._dist == "tanh_normal":
            mean, std = self._split_mean_std(self._dist_layer(x))
            mean = torch.tanh(mean)
            std = F.softplus(std + self._init_std) + self._min_std
            dist = torchd.normal.Normal(mean, std)
            dist = torchd.transformed_distribution.TransformedDistribution(dist, TanhBijector())
            dist = torchd.independent.Independent(dist, 1)
            dist = SampleDist(dist)
        elif self._dist == "tanh_normal_5":
            mean, std = self._split_mean_std(self._dist_layer(x))
            mean = 5 * torch.tanh(mean / 5)
            std = F.softplus(std + 5) + 5
            dist = torchd.normal.Normal(mean, std)
            dist = torchd.transformed_distribution.TransformedDistribution(dist, TanhBijector())
            dist = torchd.independent.Independent(dist, 1)
            dist = SampleDist(dist)
        elif self._dist == "normal":
            mean, std = self._split_mean_std(self._dist_layer(x))
            std = (self._max_std - self._min_std) * torch.sigmoid(std + 2.0) + self._min_std
            dist = torchd.normal.Normal(torch.tanh(mean), std)
            dist = ContDist(torchd.independent.Independent(dist, 1))
        elif self._dist == "normal_1":
            # The single-parameter layer output is the mean; the scale is fixed to 1.
            mean = self._dist_layer(x)
            dist = torchd.normal.Normal(mean, 1)
            dist = ContDist(torchd.independent.Independent(dist, 1))
        elif self._dist == "trunc_normal":
            mean, std = self._split_mean_std(self._dist_layer(x))
            mean = torch.tanh(mean)
            std = 2 * torch.sigmoid(std / 2) + self._min_std
            dist = SafeTruncatedNormal(mean, std, -1, 1)
            dist = ContDist(torchd.independent.Independent(dist, 1))
        elif self._dist == "onehot":
            x = self._dist_layer(x)
            dist = OneHotDist(x, unimix_ratio=self._unimix_ratio)
        elif self._dist in ("onehot_gumbel", "onehot_gumble"):
            x = self._dist_layer(x)
            temp = self._temp
            dist = ContDist(torchd.gumbel.Gumbel(x, 1 / temp))
        else:
            raise NotImplementedError(self._dist)
        return dist


class SampleDist:
    """
    Overview:
        A kind of sample Dist for ActionHead of dreamerv3. Estimates statistics of the wrapped \
        distribution from ``samples`` draws.
    Interfaces:
        ``__init__``, ``mean``, ``mode``, ``entropy``
    """

    def __init__(self, dist, samples=100):
        """
        Overview:
            Initialize the SampleDist class.
        Arguments:
            - dist: Wrapped distribution providing ``sample`` and ``log_prob``.
            - samples (:obj:`int`): Number of samples.
        """
        self._dist = dist
        self._samples = samples

    def _draw(self):
        """
        Overview:
            Draw ``samples`` values; the sample shape has to be a tuple, not a bare int.
        """
        return self._dist.sample((self._samples, ))

    def mean(self):
        """
        Overview:
            Calculate the mean of the distribution as the average over the drawn samples.
        """
        return torch.mean(self._draw(), 0)

    def mode(self):
        """
        Overview:
            Calculate the mode of the distribution as the drawn sample with the highest \
            log-probability, selected independently for every batch element.
        """
        sample = self._draw()
        logprob = self._dist.log_prob(sample)
        best = torch.argmax(logprob, dim=0)
        # Append singleton event dims so the index broadcasts over each sample's event shape.
        index = best.reshape((1, ) + tuple(best.shape) + (1, ) * (sample.dim() - logprob.dim()))
        index = index.expand((1, ) + tuple(sample.shape[1:]))
        return torch.gather(sample, 0, index).squeeze(0)

    def entropy(self):
        """
        Overview:
            Calculate the entropy of the distribution as the negative mean log-probability of \
            the drawn samples.
        """
        sample = self._draw()
        logprob = self._dist.log_prob(sample)
        return -torch.mean(logprob, 0)


class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
    """
    Overview:
        A kind of onehot Dist for dreamerv3, with optional uniform mixing and straight-through \
        terms on ``mode`` and ``sample``.
    Interfaces:
        ``__init__``, ``mode``, ``sample``
    """

    def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
        """
        Overview:
            Initialize the OneHotDist class.
        Arguments:
            - logits (:obj:`torch.Tensor`): Logits.
            - probs (:obj:`torch.Tensor`): Probabilities.
            - unimix_ratio (:obj:`float`): Unimix ratio; when positive (and logits are given) the \
                softmax probabilities are blended with a uniform distribution.
        """
        if logits is not None and unimix_ratio > 0.0:
            probs = F.softmax(logits, dim=-1)
            probs = probs * (1.0 - unimix_ratio) + unimix_ratio / probs.shape[-1]
            logits = torch.log(probs)
            super().__init__(logits=logits, probs=None)
        else:
            super().__init__(logits=logits, probs=probs)

    def mode(self):
        """
        Overview:
            Calculate the mode of the distribution. The value is the argmax one-hot; the added \
            ``logits - logits.detach()`` term is zero-valued and lets gradients reach the logits.
        """
        logits = self.logits
        _mode = F.one_hot(torch.argmax(logits, dim=-1), logits.shape[-1])
        return _mode.detach() + logits - logits.detach()

    def sample(self, sample_shape=(), seed=None):
        """
        Overview:
            Sample from the distribution, adding a zero-valued ``probs - probs.detach()`` term so \
            gradients reach the probabilities.
        Arguments:
            - sample_shape (:obj:`tuple`): Sample shape.
            - seed (:obj:`int`): Seed. Not supported; must be None.
        """
        if seed is not None:
            raise ValueError('need to check')
        sample = super().sample(sample_shape)
        probs = self.probs
        # Left-pad probs with singleton dims until its rank matches the sample's.
        while len(probs.shape) < len(sample.shape):
            probs = probs[None]
        return sample + (probs - probs.detach())


class TwoHotDistSymlog:
    """
    Overview:
        A kind of twohotsymlog Dist for dreamerv3. A categorical distribution over evenly spaced \
        buckets in symlog space.
    Interfaces:
        ``__init__``, ``mode``, ``mean``, ``log_prob``, ``log_prob_target``
    """

    def __init__(self, logits=None, low=-20.0, high=20.0, device='cpu'):
        """
        Overview:
            Initialize the TwoHotDistSymlog class.
        Arguments:
            - logits (:obj:`torch.Tensor`): Logits; the last dim indexes the buckets.
            - low (:obj:`float`): Lowest bucket value (symlog space).
            - high (:obj:`float`): Highest bucket value (symlog space).
            - device (:obj:`str`): Device holding the bucket tensor.
        """
        self.logits = logits
        self.probs = torch.softmax(logits, -1)
        # One bucket per logit; this equals the former hard-coded 255 for 255-way logits.
        num_buckets = logits.shape[-1]
        self.buckets = torch.linspace(low, high, steps=num_buckets).to(device)
        self.width = (self.buckets[-1] - self.buckets[0]) / num_buckets

    def mean(self):
        """
        Overview:
            Calculate the mean of the distribution: the probability-weighted bucket value, \
            mapped back through ``inv_symlog``.
        """
        _mean = self.probs * self.buckets
        return inv_symlog(torch.sum(_mean, dim=-1, keepdim=True))

    def mode(self):
        """
        Overview:
            Calculate the mode of the distribution. Computed exactly like ``mean``.
        """
        _mode = self.probs * self.buckets
        return inv_symlog(torch.sum(_mode, dim=-1, keepdim=True))

    # Inside OneHotCategorical, log_prob is calculated using only max element in targets
    def log_prob(self, x):
        """
        Overview:
            Calculate the log probability of ``x`` against a two-hot target built from the two \
            buckets that bracket ``symlog(x)``.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor; the code squeezes dim -2 of the target, so a \
                trailing singleton dim is presumably expected, e.g. (time, batch, 1) -- confirm.
        """
        x = symlog(x)
        num_buckets = len(self.buckets)
        # Index of the last bucket <= x and of the first bucket > x, clipped into range.
        below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
        above = num_buckets - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1)
        below = torch.clip(below, 0, num_buckets - 1)
        above = torch.clip(above, 0, num_buckets - 1)
        equal = (below == above)

        # When both indices coincide, use unit distances so each weight becomes 0.5.
        dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
        dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
        total = dist_to_below + dist_to_above
        weight_below = dist_to_above / total
        weight_above = dist_to_below / total
        target = (
            F.one_hot(below, num_classes=num_buckets) * weight_below[..., None] +
            F.one_hot(above, num_classes=num_buckets) * weight_above[..., None]
        )
        log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
        target = target.squeeze(-2)

        return (target * log_pred).sum(-1)

    def log_prob_target(self, target):
        """
        Overview:
            Calculate the log probability of an already-built target distribution.
        Arguments:
            - target (:obj:`torch.Tensor`): Target tensor over the buckets.
        """
        # This class has no base class providing ``logits``; use the stored tensor.
        log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
        return (target * log_pred).sum(-1)


class SymlogDist:
    """
    Overview:
        A kind of Symlog Dist for dreamerv3. Scores values by their distance to the mode in \
        symlog space.
    Interfaces:
        ``__init__``, ``entropy``, ``mode``, ``mean``, ``log_prob``
    """

    def __init__(self, mode, dist='mse', aggregation='sum', tol=1e-8, dim_to_reduce=(-1, -2, -3)):
        """
        Overview:
            Initialize the SymlogDist class.
        Arguments:
            - mode (:obj:`torch.Tensor`): Mode, in symlog space.
            - dist (:obj:`str`): Distance function, 'mse' or 'abs'.
            - aggregation (:obj:`str`): Aggregation function, 'mean' or 'sum'.
            - tol (:obj:`float`): Distances below this tolerance are treated as zero.
            - dim_to_reduce (:obj:`list`): Dimensions to reduce.
        """
        self._mode = mode
        self._dist = dist
        self._aggregation = aggregation
        self._tol = tol
        self._dim_to_reduce = dim_to_reduce

    def mode(self):
        """
        Overview:
            Calculate the mode of the distribution.
        """
        return inv_symlog(self._mode)

    def mean(self):
        """
        Overview:
            Calculate the mean of the distribution. Identical to ``mode``.
        """
        return inv_symlog(self._mode)

    def log_prob(self, value):
        """
        Overview:
            Calculate the log probability of the distribution as the negated, aggregated distance.
        Arguments:
            - value (:obj:`torch.Tensor`): Input tensor with the same shape as the mode.
        """
        assert self._mode.shape == value.shape
        if self._dist == 'mse':
            distance = (self._mode - symlog(value)) ** 2.0
        elif self._dist == 'abs':
            distance = torch.abs(self._mode - symlog(value))
        else:
            raise NotImplementedError(self._dist)
        distance = torch.where(distance < self._tol, 0, distance)

        if self._aggregation == 'mean':
            loss = distance.mean(self._dim_to_reduce)
        elif self._aggregation == 'sum':
            loss = distance.sum(self._dim_to_reduce)
        else:
            raise NotImplementedError(self._aggregation)
        return -loss


class ContDist:
    """
    Overview:
        A kind of ordinary Dist for dreamerv3. Wraps a distribution and forwards unknown \
        attributes to it.
    Interfaces:
        ``__init__``, ``entropy``, ``mode``, ``sample``, ``log_prob``
    """

    def __init__(self, dist=None):
        """
        Overview:
            Initialize the ContDist class.
        Arguments:
            - dist: Wrapped distribution.
        """
        self._dist = dist
        self.mean = dist.mean

    def __getattr__(self, name):
        """
        Overview:
            Get attribute from the wrapped distribution.
        Arguments:
            - name (:obj:`str`): Attribute name.
        """
        # Only reached when normal lookup fails; guard '_dist' to avoid infinite recursion
        # when it has not been set yet (e.g. during copy or unpickling).
        if name == '_dist':
            raise AttributeError(name)
        return getattr(self._dist, name)

    def entropy(self):
        """
        Overview:
            Calculate the entropy of the distribution.
        """
        return self._dist.entropy()

    def mode(self):
        """
        Overview:
            Calculate the mode of the distribution. Returns the wrapped distribution's mean.
        """
        return self._dist.mean

    def sample(self, sample_shape=()):
        """
        Overview:
            Sample from the distribution through ``rsample``.
        Arguments:
            - sample_shape (:obj:`tuple`): Sample shape.
        """
        return self._dist.rsample(sample_shape)

    def log_prob(self, x):
        """
        Overview:
            Calculate the log probability of the distribution.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        """
        return self._dist.log_prob(x)


class Bernoulli:
    """
    Overview:
        A kind of Bernoulli Dist for dreamerv3. Wraps a distribution and forwards unknown \
        attributes to it.
    Interfaces:
        ``__init__``, ``entropy``, ``mode``, ``sample``, ``log_prob``
    """

    def __init__(self, dist=None):
        """
        Overview:
            Initialize the Bernoulli distribution.
        Arguments:
            - dist: Wrapped distribution; ``log_prob`` reads ``dist.base_dist.logits``.
        """
        self._dist = dist
        self.mean = dist.mean

    def __getattr__(self, name):
        """
        Overview:
            Get attribute from the wrapped distribution.
        Arguments:
            - name (:obj:`str`): Attribute name.
        """
        # Guard '_dist' to avoid infinite recursion when it has not been set yet.
        if name == '_dist':
            raise AttributeError(name)
        return getattr(self._dist, name)

    def entropy(self):
        """
        Overview:
            Calculate the entropy of the distribution.
        """
        return self._dist.entropy()

    def mode(self):
        """
        Overview:
            Calculate the mode of the distribution: the rounded mean, plus a zero-valued \
            ``mean - mean.detach()`` term that lets gradients reach the mean.
        """
        mean = self._dist.mean
        _mode = torch.round(mean)
        return _mode.detach() + mean - mean.detach()

    def sample(self, sample_shape=()):
        """
        Overview:
            Sample from the distribution.
        Arguments:
            - sample_shape (:obj:`tuple`): Sample shape.
        """
        # A Bernoulli has no reparameterized sampler; fall back to plain sampling.
        if self._dist.has_rsample:
            return self._dist.rsample(sample_shape)
        return self._dist.sample(sample_shape)

    def log_prob(self, x):
        """
        Overview:
            Calculate the element-wise log probability from the base distribution's logits.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        """
        _logits = self._dist.base_dist.logits
        log_probs0 = -F.softplus(_logits)
        log_probs1 = -F.softplus(-_logits)

        return log_probs0 * (1 - x) + log_probs1 * x


class UnnormalizedHuber(torchd.normal.Normal):
    """
    Overview:
        A kind of UnnormalizedHuber Dist for dreamerv3. A Normal whose ``log_prob`` is replaced \
        by a negated smooth Huber-style distance to the mean.
    Interfaces:
        ``__init__``, ``mode``, ``log_prob``
    """

    def __init__(self, loc, scale, threshold=1, **kwargs):
        """
        Overview:
            Initialize the UnnormalizedHuber class.
        Arguments:
            - loc (:obj:`torch.Tensor`): Location.
            - scale (:obj:`torch.Tensor`): Scale.
            - threshold (:obj:`float`): Threshold.
        """
        super().__init__(loc, scale, **kwargs)
        self._threshold = threshold

    def log_prob(self, event):
        """
        Overview:
            Calculate the log probability of the distribution.
        Arguments:
            - event (:obj:`torch.Tensor`): Event.
        """
        return -(torch.sqrt((event - self.mean) ** 2 + self._threshold ** 2) - self._threshold)

    def mode(self):
        """
        Overview:
            Calculate the mode of the distribution.
        """
        return self.mean


class SafeTruncatedNormal(torchd.normal.Normal):
    """
    Overview:
        A kind of SafeTruncatedNormal Dist for dreamerv3. A Normal whose ``sample`` clips the \
        values into ``[low + clip, high - clip]``.
    Interfaces:
        ``__init__``, ``sample``
    """

    def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
        """
        Overview:
            Initialize the SafeTruncatedNormal class.
        Arguments:
            - loc (:obj:`torch.Tensor`): Location.
            - scale (:obj:`torch.Tensor`): Scale.
            - low (:obj:`float`): Low.
            - high (:obj:`float`): High.
            - clip (:obj:`float`): Margin kept from both bounds; a falsy value disables clipping.
            - mult (:obj:`float`): Multiplier applied to the sample; a falsy value disables it.
        """
        super().__init__(loc, scale)
        self._low = low
        self._high = high
        self._clip = clip
        self._mult = mult

    def sample(self, sample_shape=torch.Size()):
        """
        Overview:
            Sample from the distribution.
        Arguments:
            - sample_shape (:obj:`tuple`): Sample shape.
        """
        # NOTE(review): only ``sample`` is overridden; ``rsample`` (used by ContDist.sample)
        # is inherited from Normal and does not clip -- confirm this is intended.
        event = super().sample(sample_shape)
        if self._clip:
            clipped = torch.clip(event, self._low + self._clip, self._high - self._clip)
            # The value is the clipped sample; the gradient path is the unclipped one.
            event = event - event.detach() + clipped.detach()
        if self._mult:
            event = event * self._mult
        return event


class TanhBijector(torchd.Transform):
    """
    Overview:
        A kind of TanhBijector Dist for dreamerv3. The tanh transform, usable inside \
        ``TransformedDistribution``.
    Interfaces:
        ``__init__``, ``_forward``, ``_inverse``, ``_forward_log_det_jacobian``
    """
    # Attributes and hooks that torch's Transform / TransformedDistribution machinery reads.
    domain = torchd.constraints.real
    codomain = torchd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, validate_args=False, name='tanh'):
        """
        Overview:
            Initialize the TanhBijector class.
        Arguments:
            - validate_args (:obj:`bool`): Validate arguments. Accepted but not used.
            - name (:obj:`str`): Name. Accepted but not used.
        """
        super().__init__()

    def _forward(self, x):
        """
        Overview:
            Calculate the forward of the distribution.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        """
        return torch.tanh(x)

    def _call(self, x):
        """
        Overview:
            Forward hook invoked by ``Transform.__call__``; delegates to ``_forward``.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        """
        return self._forward(x)

    def _inverse(self, y):
        """
        Overview:
            Calculate the inverse of the distribution.
        Arguments:
            - y (:obj:`torch.Tensor`): Input tensor.
        """
        # Keep in-range values strictly inside (-1, 1) so atanh stays finite.
        y = torch.where((torch.abs(y) <= 1.), torch.clamp(y, -0.99999997, 0.99999997), y)
        y = torch.atanh(y)
        return y

    def _forward_log_det_jacobian(self, x):
        """
        Overview:
            Calculate the forward log det jacobian of the distribution.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        """
        # Stable form of log(1 - tanh(x)^2).
        log2 = math.log(2.0)
        return 2.0 * (log2 - x - F.softplus(-2.0 * x))

    def log_abs_det_jacobian(self, x, y):
        """
        Overview:
            Log-det-Jacobian hook invoked by ``TransformedDistribution.log_prob``.
        Arguments:
            - x (:obj:`torch.Tensor`): Pre-transform value.
            - y (:obj:`torch.Tensor`): Post-transform value (unused).
        """
        return self._forward_log_det_jacobian(x)


def _stack_steps(items):
    """
    Overview:
        Stack per-step outputs along a new leading dim. A list of dicts is stacked key by key; \
        a list of tensors is stacked directly.
    """
    first = items[0]
    if isinstance(first, dict):
        return {key: torch.stack([item[key] for item in items], dim=0) for key in first.keys()}
    return torch.stack(items, dim=0)


def static_scan(fn, inputs, start):
    """
    Overview:
        Static scan function. Repeatedly applies ``fn`` along the leading dim of ``inputs``, \
        feeding each result back as the next state, and stacks every step's result.
    Arguments:
        - fn (:obj:`function`): Step function called as ``fn(last, *step_inputs)``.
        - inputs (:obj:`tuple`): Tensors sharing the same leading (step) dim.
        - start (:obj:`torch.Tensor`): Initial state passed to the first ``fn`` call.
    Returns:
        - outputs (:obj:`list`): One stacked entry per element of ``fn``'s result; a dict result \
            is wrapped into a one-element list.
    Raises:
        - ValueError: If ``inputs`` has zero steps, since the output structure is then unknown.
    """
    steps = inputs[0].shape[0]
    if steps == 0:
        raise ValueError("static_scan requires at least one step in inputs")

    last = start
    history = []
    for index in range(steps):
        last = fn(last, *(_input[index] for _input in inputs))
        history.append(last)

    # Stack once at the end instead of re-concatenating the growing result on every step.
    if isinstance(history[0], dict):
        return [_stack_steps(history)]
    return [_stack_steps([step[j] for step in history]) for j in range(len(history[0]))]


def _zero_bias(m):
    """
    Overview:
        Fill the module's bias with zeros when it has one.
    """
    if hasattr(m.bias, 'data'):
        m.bias.data.fill_(0.0)


def _reset_layer_norm(m):
    """
    Overview:
        Reset a LayerNorm to weight one and bias zero, skipping parameters that are absent.
    """
    if hasattr(m.weight, 'data'):
        m.weight.data.fill_(1.0)
    _zero_bias(m)


def weight_init(m):
    """
    Overview:
        weight_init for Linear, Conv2d, ConvTranspose2d, and LayerNorm.
    Arguments:
        - m (:obj:`torch.nn`): Module.
    """
    if isinstance(m, nn.LayerNorm):
        _reset_layer_norm(m)
        return

    if isinstance(m, nn.Linear):
        in_num = m.in_features
        out_num = m.out_features
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        space = m.kernel_size[0] * m.kernel_size[1]
        in_num = space * m.in_channels
        out_num = space * m.out_channels
    else:
        return

    # Variance is scaled by the average of fan-in and fan-out.
    denoms = (in_num + out_num) / 2.0
    scale = 1.0 / denoms
    std = np.sqrt(scale) / 0.87962566103423978
    nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
    _zero_bias(m)


def uniform_weight_init(given_scale):
    """
    Overview:
        weight_init for Linear and LayerNorm.
    Arguments:
        - given_scale (:obj:`float`): Given scale.
    Returns:
        - f (:obj:`function`): Initializer suitable for ``nn.Module.apply``.
    """

    def f(m):
        if isinstance(m, nn.Linear):
            in_num = m.in_features
            out_num = m.out_features
            denoms = (in_num + out_num) / 2.0
            scale = given_scale / denoms
            limit = np.sqrt(3 * scale)
            nn.init.uniform_(m.weight.data, a=-limit, b=limit)
            _zero_bias(m)
        elif isinstance(m, nn.LayerNorm):
            _reset_layer_norm(m)

    return f