from typing import Union, List, Dict
from collections import namedtuple
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.utils import list_split, MODEL_REGISTRY, squeeze, SequenceType
def extract(a, t, x_shape):
"""
    Overview:
        Extract the values of ``a`` at the indices given by ``t``, reshaped so that they
        broadcast against a tensor of shape ``x_shape``.
    Arguments:
        - a (:obj:`torch.Tensor`): the tensor to index into
        - t (:obj:`torch.Tensor`): index tensor of shape (batch, )
        - x_shape (:obj:`tuple`): shape of the tensor the output will broadcast against
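    Examples:
        A minimal shape sketch (sizes are illustrative assumptions):

        >>> a = torch.linspace(0, 1, 10)
        >>> t = torch.tensor([0, 9])
        >>> extract(a, t, (2, 3, 4)).shape
        torch.Size([2, 1, 1])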
"""
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
def cosine_beta_schedule(timesteps: int, s: float = 0.008, dtype=torch.float32):
"""
    Overview:
        Cosine beta schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ.
    Arguments:
        - timesteps (:obj:`int`): number of diffusion timesteps
        - s (:obj:`float`): small offset that keeps the betas away from zero near t = 0
        - dtype (:obj:`torch.dtype`): dtype of the returned betas
    Return:
        Tensor of betas with shape (timesteps, ), computed with the cosine schedule.
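    Examples:
        A minimal sketch (10 diffusion steps assumed for illustration):

        >>> betas = cosine_beta_schedule(10)
        >>> betas.shape
        torch.Size([10])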
"""
steps = timesteps + 1
x = np.linspace(0, steps, steps)
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
betas_clipped = np.clip(betas, a_min=0, a_max=0.999)
return torch.tensor(betas_clipped, dtype=dtype)
def apply_conditioning(x, conditions, action_dim):
"""
    Overview:
        Overwrite the observation part of ``x`` with the given conditions.
    Arguments:
        - x (:obj:`torch.Tensor`): trajectory tensor of shape (batch, horizon, transition_dim)
        - conditions (:obj:`dict`): condition dict whose keys are timesteps and whose values are the states to enforce
        - action_dim (:obj:`int`): action dimension; only the observation part ``x[:, t, action_dim:]`` is overwritten
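    Examples:
        A minimal sketch (batch 2, horizon 4, action_dim 2, obs_dim 3 assumed):

        >>> x = torch.zeros(2, 4, 5)
        >>> cond = {0: torch.ones(2, 3)}
        >>> apply_conditioning(x, cond, action_dim=2)[0, 0]
        tensor([0., 0., 1., 1., 1.])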
"""
for t, val in conditions.items():
x[:, t, action_dim:] = val.clone()
return x
class DiffusionConv1d(nn.Module):
"""
Overview:
Conv1d with activation and normalization for diffusion models.
Interfaces:
``__init__``, ``forward``
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
padding: int,
activation: nn.Module = None,
n_groups: int = 8
) -> None:
"""
        Overview:
            Create a 1-dim convolution layer with activation and normalization. The GroupNorm
            is computed on a temporarily added extra dimension (see ``forward``).
        Arguments:
            - in_channels (:obj:`int`): Number of channels in the input tensor
            - out_channels (:obj:`int`): Number of channels in the output tensor
            - kernel_size (:obj:`int`): Size of the convolving kernel
            - padding (:obj:`int`): Zero-padding added to both sides of the input
            - activation (:obj:`nn.Module`): the optional activation function
            - n_groups (:obj:`int`): number of groups for GroupNorm
"""
super().__init__()
self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
self.norm = nn.GroupNorm(n_groups, out_channels)
self.act = activation
def forward(self, inputs) -> torch.Tensor:
"""
Overview:
            Compute convolution, group normalization and activation for ``inputs``.
Arguments:
- inputs (:obj:`torch.Tensor`): input tensor
Return:
- out (:obj:`torch.Tensor`): output tensor
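        Examples:
            A minimal shape sketch (sizes are illustrative assumptions):

            >>> conv = DiffusionConv1d(8, 16, kernel_size=3, padding=1, activation=nn.Mish())
            >>> conv(torch.randn(2, 8, 20)).shape
            torch.Size([2, 16, 20])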
"""
x = self.conv1(inputs)
# [batch, channels, horizon] -> [batch, channels, 1, horizon]
x = x.unsqueeze(-2)
x = self.norm(x)
# [batch, channels, 1, horizon] -> [batch, channels, horizon]
x = x.squeeze(-2)
        # the activation is optional (default None), so skip it when absent
        out = self.act(x) if self.act is not None else x
return out
class SinusoidalPosEmb(nn.Module):
"""
Overview:
        Class for computing the sinusoidal position embedding.
Interfaces:
``__init__``, ``forward``
"""
def __init__(self, dim: int) -> None:
"""
Overview:
Initialization of SinusoidalPosEmb class
Arguments:
            - dim (:obj:`int`): dimension of the embedding (assumed even)
"""
super().__init__()
self.dim = dim
def forward(self, x) -> torch.Tensor:
"""
Overview:
            Compute the sinusoidal position embedding of ``x``.
Arguments:
- x (:obj:`torch.Tensor`): input tensor
Return:
- emb (:obj:`torch.Tensor`): output tensor
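        Examples:
            A minimal shape sketch (embedding dim 32 assumed; ``dim`` should be even):

            >>> pos_emb = SinusoidalPosEmb(32)
            >>> pos_emb(torch.arange(4)).shape
            torch.Size([4, 32])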
"""
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=1)
return emb
class Residual(nn.Module):
"""
Overview:
Basic Residual block
Interfaces:
``__init__``, ``forward``
"""
def __init__(self, fn):
"""
Overview:
Initialization of Residual class
Arguments:
- fn (:obj:`nn.Module`): function of residual block
"""
super().__init__()
self.fn = fn
    def forward(self, x, *args, **kwargs):
        """
        Overview:
            Compute the residual connection: ``fn(x) + x``.
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor
        """
        return self.fn(x, *args, **kwargs) + x
class LayerNorm(nn.Module):
"""
Overview:
        LayerNorm that normalizes over dim=1, because the temporal input ``x`` is shaped [batch, dim, horizon].
Interfaces:
``__init__``, ``forward``
"""
def __init__(self, dim, eps=1e-5) -> None:
"""
Overview:
Initialization of LayerNorm class
Arguments:
- dim (:obj:`int`): dimension of input
- eps (:obj:`float`): eps of LayerNorm
"""
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1))
def forward(self, x):
"""
Overview:
compute LayerNorm
Arguments:
- x (:obj:`torch.Tensor`): input tensor
"""
var = torch.var(x, dim=1, unbiased=False, keepdim=True)
mean = torch.mean(x, dim=1, keepdim=True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
"""
Overview:
        Apply LayerNorm over dim=1 before calling ``fn``, because the temporal input ``x`` is shaped [batch, dim, horizon].
Interfaces:
``__init__``, ``forward``
"""
def __init__(self, dim, fn) -> None:
"""
Overview:
Initialization of PreNorm class
Arguments:
- dim (:obj:`int`): dimension of input
- fn (:obj:`nn.Module`): function of residual block
"""
super().__init__()
self.fn = fn
self.norm = LayerNorm(dim)
def forward(self, x):
"""
Overview:
compute PreNorm
Arguments:
- x (:obj:`torch.Tensor`): input tensor
"""
x = self.norm(x)
return self.fn(x)
class LinearAttention(nn.Module):
"""
Overview:
Linear Attention head
Interfaces:
``__init__``, ``forward``
"""
def __init__(self, dim, heads=4, dim_head=32) -> None:
"""
Overview:
Initialization of LinearAttention class
Arguments:
- dim (:obj:`int`): dimension of input
            - heads (:obj:`int`): number of attention heads
            - dim_head (:obj:`int`): dimension of each attention head
"""
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv1d(hidden_dim, dim, 1)
def forward(self, x):
"""
Overview:
            Compute linear attention over the horizon dimension of ``x``.
Arguments:
- x (:obj:`torch.Tensor`): input tensor
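        Examples:
            A minimal shape sketch (sizes are illustrative assumptions):

            >>> attn = LinearAttention(32)
            >>> attn(torch.randn(2, 32, 16)).shape
            torch.Size([2, 32, 16])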
"""
qkv = self.to_qkv(x).chunk(3, dim=1)
q, k, v = map(lambda t: t.reshape(t.shape[0], self.heads, -1, t.shape[-1]), qkv)
q = q * self.scale
k = k.softmax(dim=-1)
context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
out = out.reshape(out.shape[0], -1, out.shape[-1])
return self.to_out(out)
class ResidualTemporalBlock(nn.Module):
"""
Overview:
        Residual block for temporal (1-dim convolutional) data.
Interfaces:
``__init__``, ``forward``
"""
def __init__(
self, in_channels: int, out_channels: int, embed_dim: int, kernel_size: int = 5, mish: bool = True
) -> None:
"""
Overview:
            Initialization of the ResidualTemporalBlock class.
        Arguments:
            - in_channels (:obj:`int`): number of input channels
            - out_channels (:obj:`int`): number of output channels
            - embed_dim (:obj:`int`): dimension of the time embedding
            - kernel_size (:obj:`int`): kernel size of the conv1d layers
            - mish (:obj:`bool`): whether to use Mish (otherwise SiLU) as the activation function
"""
super().__init__()
if mish:
act = nn.Mish()
else:
act = nn.SiLU()
self.blocks = nn.ModuleList(
[
DiffusionConv1d(in_channels, out_channels, kernel_size, kernel_size // 2, act),
DiffusionConv1d(out_channels, out_channels, kernel_size, kernel_size // 2, act),
]
)
self.time_mlp = nn.Sequential(
act,
nn.Linear(embed_dim, out_channels),
)
self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
if in_channels != out_channels else nn.Identity()
def forward(self, x, t):
"""
Overview:
            Compute the residual block: two temporal convolutions, with the time embedding added in between.
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor
            - t (:obj:`torch.Tensor`): time embedding tensor
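        Examples:
            A minimal shape sketch (sizes are illustrative assumptions):

            >>> block = ResidualTemporalBlock(8, 16, embed_dim=32)
            >>> x, t = torch.randn(2, 8, 20), torch.randn(2, 32)
            >>> block(x, t).shape
            torch.Size([2, 16, 20])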
"""
out = self.blocks[0](x) + self.time_mlp(t).unsqueeze(-1)
out = self.blocks[1](out)
return out + self.residual_conv(x)
class DiffusionUNet1d(nn.Module):
"""
Overview:
        Diffusion UNet for 1-dim vector data.
Interfaces:
``__init__``, ``forward``, ``get_pred``
"""
def __init__(
self,
transition_dim: int,
dim: int = 32,
dim_mults: SequenceType = [1, 2, 4, 8],
returns_condition: bool = False,
condition_dropout: float = 0.1,
calc_energy: bool = False,
kernel_size: int = 5,
attention: bool = False,
) -> None:
"""
Overview:
            Initialization of the DiffusionUNet1d class.
        Arguments:
            - transition_dim (:obj:`int`): dimension of the transition, i.e. obs_dim + action_dim
            - dim (:obj:`int`): base dimension of the layers
            - dim_mults (:obj:`SequenceType`): dimension multipliers for each resolution level
            - returns_condition (:obj:`bool`): whether to condition on the return
            - condition_dropout (:obj:`float`): dropout probability of the returns condition
            - calc_energy (:obj:`bool`): whether to compute the energy-based gradient in ``forward``
            - kernel_size (:obj:`int`): kernel size of the conv1d layers
            - attention (:obj:`bool`): whether to use attention blocks
"""
super().__init__()
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
if calc_energy:
mish = False
act = nn.SiLU()
else:
mish = True
act = nn.Mish()
self.time_dim = dim
self.returns_dim = dim
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
act,
nn.Linear(dim * 4, dim),
)
self.returns_condition = returns_condition
self.condition_dropout = condition_dropout
        self.calc_energy = calc_energy
if self.returns_condition:
self.returns_mlp = nn.Sequential(
nn.Linear(1, dim),
act,
nn.Linear(dim, dim * 4),
act,
nn.Linear(dim * 4, dim),
)
self.mask_dist = torch.distributions.Bernoulli(probs=1 - self.condition_dropout)
embed_dim = 2 * dim
else:
embed_dim = dim
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolution = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolution - 1)
self.downs.append(
nn.ModuleList(
[
ResidualTemporalBlock(dim_in, dim_out, embed_dim, kernel_size, mish=mish),
ResidualTemporalBlock(dim_out, dim_out, embed_dim, kernel_size, mish=mish),
Residual(PreNorm(dim_out, LinearAttention(dim_out))) if attention else nn.Identity(),
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity()
]
)
)
mid_dim = dims[-1]
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim, kernel_size, mish)
self.mid_atten = Residual(PreNorm(mid_dim, LinearAttention(mid_dim))) if attention else nn.Identity()
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim, kernel_size, mish)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolution - 1)
self.ups.append(
nn.ModuleList(
[
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim, kernel_size, mish=mish),
ResidualTemporalBlock(dim_in, dim_in, embed_dim, kernel_size, mish=mish),
Residual(PreNorm(dim_in, LinearAttention(dim_in))) if attention else nn.Identity(),
nn.ConvTranspose1d(dim_in, dim_in, 4, 2, 1) if not is_last else nn.Identity()
]
)
)
self.final_conv = nn.Sequential(
DiffusionConv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, activation=act),
nn.Conv1d(dim, transition_dim, 1),
)
def forward(self, x, cond, time, returns=None, use_dropout: bool = True, force_dropout: bool = False):
"""
Overview:
            Compute the diffusion UNet forward pass.
        Arguments:
            - x (:obj:`torch.Tensor`): noisy trajectory of shape (batch, horizon, transition_dim)
            - cond (:obj:`dict`): conditions dict {timestep: state}; the state is usually the initial env state at timestep 0
            - time (:obj:`torch.Tensor`): timestep of the diffusion process
            - returns (:obj:`torch.Tensor`): normalized return used as condition
            - use_dropout (:obj:`bool`): whether to apply the returns-condition mask
            - force_dropout (:obj:`bool`): whether to drop the returns condition entirely
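        Examples:
            A minimal shape sketch (transition_dim 11 and horizon 32 are illustrative;
            the horizon should be divisible by ``2 ** (len(dim_mults) - 1)``):

            >>> model = DiffusionUNet1d(transition_dim=11)
            >>> x = torch.randn(2, 32, 11)
            >>> time = torch.randint(0, 1000, (2, ))
            >>> model(x, None, time).shape
            torch.Size([2, 32, 11])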
"""
        if self.calc_energy:
            x_inp = x
        # [batch, horizon, transition] -> [batch, transition, horizon]
        x = x.transpose(1, 2)
t = self.time_mlp(time)
if self.returns_condition:
assert returns is not None
returns_embed = self.returns_mlp(returns)
if use_dropout:
mask = self.mask_dist.sample(sample_shape=(returns_embed.size(0), 1)).to(returns_embed.device)
returns_embed = mask * returns_embed
if force_dropout:
returns_embed = 0 * returns_embed
t = torch.cat([t, returns_embed], dim=-1)
h = []
for resnet, resnet2, atten, downsample in self.downs:
x = resnet(x, t)
x = resnet2(x, t)
x = atten(x)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_atten(x)
x = self.mid_block2(x, t)
for resnet, resnet2, atten, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, t)
x = resnet2(x, t)
x = atten(x)
x = upsample(x)
x = self.final_conv(x)
        # [batch, transition, horizon] -> [batch, horizon, transition]
x = x.transpose(1, 2)
        if self.calc_energy:
# Energy function
energy = ((x - x_inp) ** 2).mean()
grad = torch.autograd.grad(outputs=energy, inputs=x_inp, create_graph=True)
return grad[0]
else:
return x
    def get_pred(self, x, cond, time, returns=None, use_dropout: bool = True, force_dropout: bool = False):
"""
Overview:
            Compute the diffusion UNet prediction (the same computation as ``forward``, without the energy branch).
        Arguments:
            - x (:obj:`torch.Tensor`): noisy trajectory of shape (batch, horizon, transition_dim)
            - cond (:obj:`dict`): conditions dict {timestep: state}; the state is usually the initial env state at timestep 0
            - time (:obj:`torch.Tensor`): timestep of the diffusion process
            - returns (:obj:`torch.Tensor`): normalized return used as condition
            - use_dropout (:obj:`bool`): whether to apply the returns-condition mask
            - force_dropout (:obj:`bool`): whether to drop the returns condition entirely
"""
        # [batch, horizon, transition] -> [batch, transition, horizon]
x = x.transpose(1, 2)
t = self.time_mlp(time)
if self.returns_condition:
assert returns is not None
returns_embed = self.returns_mlp(returns)
if use_dropout:
mask = self.mask_dist.sample(sample_shape=(returns_embed.size(0), 1)).to(returns_embed.device)
returns_embed = mask * returns_embed
if force_dropout:
returns_embed = 0 * returns_embed
t = torch.cat([t, returns_embed], dim=-1)
h = []
        # each entry of self.downs holds four modules (see __init__), so all of them must be unpacked
        for resnet, resnet2, atten, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = atten(x)
            h.append(x)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_atten(x)
        x = self.mid_block2(x, t)
        for resnet, resnet2, atten, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = atten(x)
            x = upsample(x)
x = self.final_conv(x)
        # [batch, transition, horizon] -> [batch, horizon, transition]
x = x.transpose(1, 2)
return x
class TemporalValue(nn.Module):
"""
Overview:
        Temporal network for the value function.
Interfaces:
``__init__``, ``forward``
"""
def __init__(
self,
horizon: int,
transition_dim: int,
dim: int = 32,
time_dim: int = None,
out_dim: int = 1,
kernel_size: int = 5,
dim_mults: SequenceType = [1, 2, 4, 8],
) -> None:
"""
Overview:
            Initialization of the TemporalValue class.
        Arguments:
            - horizon (:obj:`int`): horizon of the trajectory
            - transition_dim (:obj:`int`): dimension of the transition, i.e. obs_dim + action_dim
            - dim (:obj:`int`): base dimension of the layers
            - time_dim (:obj:`int`): dimension of the time embedding
            - out_dim (:obj:`int`): dimension of the output
            - kernel_size (:obj:`int`): kernel size of the conv1d layers
            - dim_mults (:obj:`SequenceType`): dimension multipliers for each resolution level
"""
super().__init__()
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
time_dim = time_dim or dim
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.Mish(),
nn.Linear(dim * 4, dim),
)
self.blocks = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
self.blocks.append(
nn.ModuleList(
[
ResidualTemporalBlock(dim_in, dim_out, kernel_size=kernel_size, embed_dim=time_dim),
ResidualTemporalBlock(dim_out, dim_out, kernel_size=kernel_size, embed_dim=time_dim),
nn.Conv1d(dim_out, dim_out, 3, 2, 1)
]
)
)
horizon = horizon // 2
mid_dim = dims[-1]
mid_dim_2 = mid_dim // 2
mid_dim_3 = mid_dim // 4
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim_2, kernel_size=kernel_size, embed_dim=time_dim)
self.mid_down1 = nn.Conv1d(mid_dim_2, mid_dim_2, 3, 2, 1)
horizon = horizon // 2
self.mid_block2 = ResidualTemporalBlock(mid_dim_2, mid_dim_3, kernel_size=kernel_size, embed_dim=time_dim)
self.mid_down2 = nn.Conv1d(mid_dim_3, mid_dim_3, 3, 2, 1)
horizon = horizon // 2
fc_dim = mid_dim_3 * max(horizon, 1)
self.final_block = nn.Sequential(
nn.Linear(fc_dim + time_dim, fc_dim // 2),
nn.Mish(),
nn.Linear(fc_dim // 2, out_dim),
)
def forward(self, x, cond, time, *args):
"""
Overview:
            Compute the temporal value forward pass.
        Arguments:
            - x (:obj:`torch.Tensor`): noisy trajectory of shape (batch, horizon, transition_dim)
            - cond (:obj:`dict`): conditions dict {timestep: state}; unused here, kept for interface compatibility
            - time (:obj:`torch.Tensor`): timestep of the diffusion process
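        Examples:
            A minimal shape sketch (horizon 32 and transition_dim 11 are illustrative assumptions):

            >>> value = TemporalValue(horizon=32, transition_dim=11)
            >>> x = torch.randn(2, 32, 11)
            >>> time = torch.randint(0, 1000, (2, ))
            >>> value(x, None, time).shape
            torch.Size([2, 1])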
"""
        # [batch, horizon, transition] -> [batch, transition, horizon]
x = x.transpose(1, 2)
t = self.time_mlp(time)
for resnet, resnet2, downsample in self.blocks:
x = resnet(x, t)
x = resnet2(x, t)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_down1(x)
x = self.mid_block2(x, t)
x = self.mid_down2(x)
x = x.view(len(x), -1)
out = self.final_block(torch.cat([x, t], dim=-1))
return out