# Provenance (web-page residue, kept as a comment so the module still parses):
# uploaded by emedinac -- commit 2f1078d (verified), "adding amass and h36m models".
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from ..layers import deformable_conv, SE
# Seed PyTorch's global random number generator as soon as this module is imported.
# NOTE(review): this is an import-time side effect that re-seeds the RNG for
# every importer of the module -- confirm that this is intended.
torch.manual_seed(0)
# Simple CNN layer: performs a 2-D convolution that preserves the spatial dimensions of the input (only the feature/channel dimension changes).
class CNN_layer(nn.Module):
    """2-D convolution block that keeps the two trailing input dimensions.

    The block is ``Conv2d -> BatchNorm2d -> Dropout``. The convolution is
    zero-padded with half the kernel extent on each axis, so only the channel
    dimension of the input changes.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: pair ``(k0, k1)`` of odd kernel extents.
        dropout: drop probability of the trailing ``nn.Dropout``.
        bias: whether the convolution has a learnable bias. The argument is
            now forwarded to ``nn.Conv2d``; previously it was silently ignored.

    Raises:
        AssertionError: if either kernel extent is even.
    """

    def __init__(self,
                 in_ch,
                 out_ch,
                 kernel_size,
                 dropout,
                 bias=True):
        super(CNN_layer, self).__init__()
        self.kernel_size = kernel_size
        # Only odd kernels can be padded symmetrically.
        assert kernel_size[0] % 2 == 1 and kernel_size[1] % 2 == 1
        # Half-kernel padding so that both trailing dimensions are maintained.
        padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
        self.block1 = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding,
                      dilation=(1, 1), bias=bias),
            nn.BatchNorm2d(out_ch),
            nn.Dropout(dropout, inplace=True),
        )

    def forward(self, x):
        """Apply the block to a 4-D tensor ``(N, in_ch, H, W)``; returns ``(N, out_ch, H, W)``."""
        return self.block1(x)
class FPN(nn.Module):
    """Three parallel dilated convolutions fused with a global-average context.

    Each branch is ``Conv2d -> BatchNorm2d -> Dropout -> PReLU`` applied to the
    same input. Per axis, with ``p = (k - 1) // 2`` the half-width of the
    kernel, the branches use dilations ``1``, ``1 + p`` and ``1 + 2p``. The
    three branch outputs are concatenated with the globally averaged input
    (broadcast back to the input's spatial size) and projected to ``out_ch``
    channels by a 1x1 convolution. No activation follows the projection.

    Args:
        in_ch: number of input channels.
        out_ch: number of channels produced by every branch and by the block.
        kernel: odd kernel extent, either an int or a pair such as ``(3, 1)``.
        dropout: drop probability used inside every branch.
        reduction: accepted for interface compatibility; not used by this block.
    """

    def __init__(self, in_ch,
                 out_ch,
                 kernel,  # (3,1)
                 dropout,
                 reduction,
                 ):
        super(FPN, self).__init__()
        kernel_size = tuple(kernel) if isinstance(kernel, (tuple, list)) else (kernel, kernel)
        half = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
        branches = []
        for level in range(3):
            dilation = (1 + level * half[0], 1 + level * half[1])
            # Size-preserving padding of a dilated odd kernel is
            # dilation * half-width. The former accumulated padding
            # (p, 2p, 3p) matches this only for kernels of extent 1 or 3;
            # larger kernels shrank the dilated branches and broke the
            # concatenation in forward().
            padding = (dilation[0] * half[0], dilation[1] * half[1])
            branches.append(nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding, dilation=dilation),
                nn.BatchNorm2d(out_ch),
                nn.Dropout(dropout, inplace=True),
                nn.PReLU(),
            ))
        self.block1, self.block2, self.block3 = branches
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))  # Action Context.
        # The activation for the fused output is applied by the caller.
        self.compress = nn.Conv2d(out_ch * 3 + in_ch,
                                  out_ch,
                                  kernel_size=(1, 1))

    def forward(self, x):
        """Fuse the three dilated branches with the global context.

        Args:
            x: 4-D tensor ``(N, in_ch, H, W)``.

        Returns:
            Tensor ``(N, out_ch, H, W)``.
        """
        _, _, rows, cols = x.shape
        # Broadcast the per-channel global average back to the input size.
        global_action = F.interpolate(self.pooling(x), (rows, cols))
        fused = torch.cat((self.block1(x), self.block2(x), self.block3(x), global_action), dim=1)
        return self.compress(fused)
def mish(x):
    """Mish activation, ``x * tanh(softplus(x))``, applied element-wise."""
    gate = F.softplus(x).tanh()
    return x * gate
class ConvTemporalGraphical(nn.Module):
    # Source : https://github.com/yysijie/st-gcn/blob/master/net/st_gcn.py
    r"""Graph convolution that mixes one axis of the input with an adjacency ``A``.

    The input is contracted with ``A`` through ``torch.einsum``:

    * ``domain == "time"``: the last axis (``V``) is mixed separately for every
      index of axis 2 (``T``). ``A`` is ``(T, V, V)``, or ``(N, T, V, V)`` in
      interpretable mode.
    * ``domain == "space"``: axis 2 (``T``) is mixed separately for every index
      of the last axis (``V``). ``A`` is ``(V, T, T)``, or ``(N, V, T, T)`` in
      interpretable mode.

    When ``interpratable`` is false, ``A`` is a learnable parameter initialised
    uniformly in ``[-1/sqrt(size), 1/sqrt(size)]``. When it is true, no
    parameter is created: a per-sample adjacency must be assigned to the ``A``
    attribute before ``forward`` is called.

    Args:
        time_dim: length ``T`` of axis 2 of the input.
        joints_dim: length ``V`` of the last axis of the input.
        domain: ``"time"`` or ``"space"``.
        interpratable: if true, use an externally supplied per-sample adjacency.

    Raises:
        ValueError: if ``domain`` is neither ``"time"`` nor ``"space"``.

    Shape:
        - Input: :math:`(N, C, T, V)`
        - Output: :math:`(N, C, T, V)`
        where
            :math:`N` is a batch size,
            :math:`C` is the number of channels,
            :math:`T` is the length of the sequence axis,
            :math:`V` is the number of graph nodes.
    """

    def __init__(self, time_dim, joints_dim, domain, interpratable):
        super(ConvTemporalGraphical, self).__init__()
        if domain == "time":
            groups, size = time_dim, joints_dim
            static_eq, dynamic_eq = 'nctv,tvw->nctw', 'nctv,ntvw->nctw'
        elif domain == "space":
            groups, size = joints_dim, time_dim
            static_eq, dynamic_eq = 'nctv,vtq->ncqv', 'nctv,nvtq->ncqv'
        else:
            # Fail at construction instead of with an AttributeError later on.
            raise ValueError(
                "domain must be 'time' or 'space', got {!r}".format(domain))
        if interpratable:
            self.domain = dynamic_eq
        else:
            # learnable, graph-agnostic 3-d adjacency matrix (or edge importance matrix)
            self.A = nn.Parameter(torch.empty(groups, size, size))
            self.domain = static_eq
            stdv = 1. / math.sqrt(self.A.size(1))
            nn.init.uniform_(self.A, -stdv, stdv)

    def forward(self, x):
        """Contract ``x`` with the adjacency currently stored in ``self.A``."""
        x = torch.einsum(self.domain, x, self.A)
        return x.contiguous()
class Map2Adj(nn.Module):
    """Predict a per-sample adjacency tensor from the input features.

    Two convolutional stacks summarise the input ``(N, in_ch, T, V)``:
    ``time_compress`` collapses axis 2 with a ``(time_dim, 1)`` kernel and
    emits ``time_dim`` channels, ``joint_compress`` collapses axis 3 with a
    ``(1, joints_dim)`` kernel and emits ``joints_dim`` channels. Their outer
    product (after the domain-specific permutations) is refined by the
    ``expansor`` 1x1 convolutions.

    * ``domain == "space"``: output is ``(N, joints_dim, time_dim, time_dim)``.
    * ``domain == "time"``: output is ``(N, time_dim, joints_dim, joints_dim)``.

    Args:
        in_ch: number of input channels.
        time_dim: length of axis 2 of the input.
        joints_dim: length of axis 3 of the input.
        domain: ``"space"`` or ``"time"``.
        dropout: drop probability used inside the stacks.

    Raises:
        ValueError: if ``domain`` is neither ``"space"`` nor ``"time"``.
    """

    def __init__(self,
                 in_ch,
                 time_dim,
                 joints_dim,
                 domain,
                 dropout,
                 ):
        super(Map2Adj, self).__init__()
        if domain == "space":
            ch = joints_dim
            perm1, perm2 = (0, 1, 2, 3), (0, 3, 2, 1)
        elif domain == "time":
            ch = time_dim
            perm1, perm2 = (0, 2, 1, 3), (0, 1, 2, 3)
        else:
            # Previously an unknown domain left ``ch`` unbound and crashed
            # further down with an UnboundLocalError.
            raise ValueError(
                "domain must be 'space' or 'time', got {!r}".format(domain))
        self.domain = domain
        self.perm1 = perm1
        self.perm2 = perm2
        # Keep at least one hidden channel so that in_ch == 1 stays usable.
        inter_ch = max(in_ch // 2, 1)
        self.time_compress = self._compressor(in_ch, inter_ch, time_dim, (time_dim, 1), dropout)
        self.joint_compress = self._compressor(in_ch, inter_ch, joints_dim, (1, joints_dim), dropout)
        self.expansor = nn.Sequential(nn.Conv2d(ch, ch, kernel_size=1, bias=False),
                                      nn.BatchNorm2d(ch),
                                      nn.Dropout(dropout, inplace=True),
                                      nn.PReLU(),
                                      nn.Conv2d(ch, ch, kernel_size=1, bias=False),
                                      )
        self.time_compress.apply(self._init_weights)
        self.joint_compress.apply(self._init_weights)
        self.expansor.apply(self._init_weights)

    @staticmethod
    def _compressor(in_ch, inter_ch, out_ch, collapse_kernel, dropout):
        """Build the stack that collapses one axis with ``collapse_kernel``."""
        return nn.Sequential(nn.Conv2d(in_ch, inter_ch, kernel_size=1, bias=False),
                             nn.BatchNorm2d(inter_ch),
                             nn.PReLU(),
                             nn.Conv2d(inter_ch, inter_ch, kernel_size=collapse_kernel, bias=False),
                             nn.BatchNorm2d(inter_ch),
                             nn.Dropout(dropout, inplace=True),
                             nn.Conv2d(inter_ch, out_ch, kernel_size=1, bias=False),
                             )

    def _init_weights(self, m, gain=0.05):
        """Small-gain Xavier initialisation for conv/linear layers; PReLU slope 0.25."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight, gain=gain)
        elif isinstance(m, (nn.Conv2d, nn.Conv1d)):
            torch.nn.init.xavier_normal_(m.weight, gain=gain)
        elif isinstance(m, nn.PReLU):
            torch.nn.init.constant_(m.weight, 0.25)

    def forward(self, x):
        """Return the adjacency tensor predicted from ``x`` of shape ``(N, in_ch, T, V)``."""
        dim_seq = self.time_compress(x)     # (N, time_dim, 1, V)
        dim_space = self.joint_compress(x)  # (N, joints_dim, T, 1)
        # Outer product of the two summaries, arranged per domain.
        o = torch.matmul(dim_space.permute(self.perm1), dim_seq.permute(self.perm2))
        return self.expansor(o)
class Domain_GCNN_layer(nn.Module):
    """Graph convolution along one domain followed by a 2-D convolution.

    The input is first mixed by ``ConvTemporalGraphical`` along the chosen
    domain, then passed through ``Conv2d -> BatchNorm2d -> Dropout``. A
    residual branch (identity, or a 1x1 convolution when the channel count or
    stride changes) is added before the final PReLU.

    In interpretable mode the adjacency used by the graph convolution is
    predicted from the input by ``Map2Adj`` on every forward pass; otherwise
    the graph convolution uses its own learnable adjacency. The most recent
    output of ``map_to_adj`` is kept in ``self.Adj``.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: pair of odd kernel extents for the 2-D convolution.
        stride: stride applied on both trailing axes of the 2-D convolution.
        time_dim: length of axis 2 of the input.
        joints_dim: length of axis 3 of the input.
        domain: ``"time"`` or ``"space"``.
        interpratable: if true, predict the adjacency from the input.
        dropout: drop probability.
        bias: accepted for interface compatibility; not used by this layer.

    Raises:
        AssertionError: if either kernel extent is even.

    Shape:
        - Input: :math:`(N, in_ch, T, V)`
        - Output: :math:`(N, out_ch, T, V)` for ``stride == 1``
        where
            :math:`N` is a batch size,
            :math:`T` is the length of the sequence axis,
            :math:`V` is the number of graph nodes.
    """

    def __init__(self,
                 in_ch,
                 out_ch,
                 kernel_size,
                 stride,
                 time_dim,
                 joints_dim,
                 domain,
                 interpratable,
                 dropout,
                 bias=True):
        super(Domain_GCNN_layer, self).__init__()
        self.kernel_size = kernel_size
        assert self.kernel_size[0] % 2 == 1
        assert self.kernel_size[1] % 2 == 1
        padding = ((self.kernel_size[0] - 1) // 2, (self.kernel_size[1] - 1) // 2)
        self.interpratable = interpratable
        self.domain = domain
        self.gcn = ConvTemporalGraphical(time_dim, joints_dim, domain, interpratable)
        self.tcn = nn.Sequential(nn.Conv2d(in_ch,
                                           out_ch,
                                           kernel_size=(self.kernel_size[0], self.kernel_size[1]),
                                           stride=(stride, stride),
                                           padding=padding,
                                           ),
                                 nn.BatchNorm2d(out_ch),
                                 nn.Dropout(dropout, inplace=True),
                                 )
        if stride != 1 or in_ch != out_ch:
            # The projection must subsample exactly like the main branch,
            # otherwise the addition in forward() fails for stride != 1.
            self.residual = nn.Sequential(nn.Conv2d(in_ch,
                                                    out_ch,
                                                    kernel_size=1,
                                                    stride=(stride, stride)),
                                          nn.BatchNorm2d(out_ch),
                                          )
        else:
            self.residual = nn.Identity()
        if self.interpratable:
            self.map_to_adj = Map2Adj(in_ch,
                                      time_dim,
                                      joints_dim,
                                      domain,
                                      dropout,
                                      )
        else:
            self.map_to_adj = nn.Identity()
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Apply graph convolution, 2-D convolution, residual add and PReLU."""
        shortcut = self.residual(x)
        # In non-interpretable mode map_to_adj is the identity, so Adj is x.
        self.Adj = self.map_to_adj(x)
        if self.interpratable:
            # Hand the freshly predicted adjacency to the graph convolution.
            self.gcn.A = self.Adj
        mixed = self.gcn(x)
        convolved = self.tcn(mixed)
        return self.prelu(convolved + shortcut)
# Dynamic SpatioTemporal Decompose Graph Convolutions (DSTD-GC)
class DSTD_GC(nn.Module):
    """Dynamic spatio-temporal decomposed graph convolution block.

    The batch-normalised input feeds two ``Domain_GCNN_layer`` branches, one
    for the ``"space"`` domain and one for the ``"time"`` domain. Each branch
    output is scaled by a per-sample, per-channel weight. The weight comes
    from a small convolutional summary of the input concatenated with global
    mean/std statistics and mapped by an MLP. The two weighted branches are
    concatenated, compressed back to ``out_ch`` channels (with a
    squeeze-and-excitation layer) and added to a residual of the normalised
    input.

    The last computed branch weights are kept in ``self.w1`` (space) and
    ``self.w2`` (time).

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        interpratable: forwarded to both ``Domain_GCNN_layer`` branches.
        kernel_size: pair of odd kernel extents for the branch convolutions.
        stride: stride of the branch convolutions and of the residual projection.
        time_dim: length of axis 2 of the input.
        joints_dim: length of axis 3 of the input.
        reduction: reduction factor of the squeeze-and-excitation layer.
        dropout: drop probability.

    Shape:
        - Input: :math:`(N, in_ch, T, V)` with ``T == time_dim`` and
          ``V == joints_dim``
        - Output: :math:`(N, out_ch, T, V)` for ``stride == 1``
    """

    def __init__(self,
                 in_ch,
                 out_ch,
                 interpratable,
                 kernel_size,
                 stride,
                 time_dim,
                 joints_dim,
                 reduction,
                 dropout):
        super(DSTD_GC, self).__init__()
        self.dsgn = Domain_GCNN_layer(in_ch, out_ch, kernel_size, stride,
                                      time_dim, joints_dim, "space", interpratable, dropout)
        self.tsgn = Domain_GCNN_layer(in_ch, out_ch, kernel_size, stride,
                                      time_dim, joints_dim, "time", interpratable, dropout)
        self.compressor = nn.Sequential(nn.Conv2d(out_ch * 2, out_ch, 1, bias=False),
                                        nn.BatchNorm2d(out_ch),
                                        nn.PReLU(),
                                        SE.SELayer2d(out_ch, reduction=reduction),
                                        )
        if stride != 1 or in_ch != out_ch:
            # Subsample the shortcut like the main branches so the final
            # addition has matching shapes for stride != 1.
            self.residual = nn.Sequential(nn.Conv2d(in_ch,
                                                    out_ch,
                                                    kernel_size=1,
                                                    stride=(stride, stride)),
                                          nn.BatchNorm2d(out_ch),
                                          )
        else:
            self.residual = nn.Identity()
        # Weighting features
        out_ch_c = max(out_ch // 2, 1)
        # Statistics appended to the conv summary: 1 + time_dim means and
        # 1 + time_dim standard deviations (see _get_stats_).
        n_features = out_ch + 2 + time_dim * 2
        self.global_norm = nn.BatchNorm2d(in_ch)
        self.conv_s = self._summary_conv(in_ch, out_ch_c, out_ch, time_dim, joints_dim, dropout)
        self.conv_t = self._summary_conv(in_ch, out_ch_c, out_ch, time_dim, joints_dim, dropout)
        self.map_s = self._weight_mlp(n_features, out_ch, dropout)
        self.map_t = self._weight_mlp(n_features, out_ch, dropout)
        self.prelu1 = nn.Sequential(nn.BatchNorm2d(out_ch),
                                    nn.PReLU(),
                                    )
        self.prelu2 = nn.Sequential(nn.BatchNorm2d(out_ch),
                                    nn.PReLU(),
                                    )

    @staticmethod
    def _summary_conv(in_ch, mid_ch, out_ch, time_dim, joints_dim, dropout):
        """Build the stack that collapses both trailing axes to a ``1x1`` map."""
        return nn.Sequential(nn.Conv2d(in_ch, mid_ch, (time_dim, 1), bias=False),
                             nn.BatchNorm2d(mid_ch),
                             nn.Dropout(dropout, inplace=True),
                             nn.PReLU(),
                             nn.Conv2d(mid_ch, out_ch, (1, joints_dim), bias=False),
                             nn.BatchNorm2d(out_ch),
                             nn.Dropout(dropout, inplace=True),
                             nn.PReLU(),
                             )

    @staticmethod
    def _weight_mlp(n_features, out_ch, dropout):
        """Build the MLP that maps summary features to per-channel weights."""
        return nn.Sequential(nn.Linear(n_features, out_ch, bias=False),
                             nn.BatchNorm1d(out_ch),
                             nn.Dropout(dropout, inplace=True),
                             nn.PReLU(),
                             nn.Linear(out_ch, out_ch, bias=False),
                             )

    def _get_stats_(self, x):
        """Return global mean/std statistics of ``x`` as ``(N, 2 + 2 * T)``."""
        global_avg_pool = x.mean((3, 2)).mean(1, keepdim=True)
        global_avg_pool_features = x.mean(3).mean(1)
        global_std_pool = x.std((3, 2)).std(1, keepdim=True)
        global_std_pool_features = x.std(3).std(1)
        return torch.cat((
            global_avg_pool,
            global_avg_pool_features,
            global_std_pool,
            global_std_pool_features,
        ),
            dim=1)

    def forward(self, x):
        """Run both domain branches and fuse their re-weighted outputs."""
        b, dim, seq, joints = x.shape  # 64, 3, 10, 22
        xn = self.global_norm(x)
        # Both weighting heads share the same statistics; compute them once.
        stats = self._get_stats_(xn)
        w1 = torch.cat((self.conv_s(xn).view(b, -1), stats), dim=1)
        w2 = torch.cat((self.conv_t(xn).view(b, -1), stats), dim=1)
        self.w1 = self.map_s(w1)
        self.w2 = self.map_t(w2)
        # Broadcast the (N, out_ch) weights over the two trailing axes.
        w1 = self.w1[..., None, None]
        w2 = self.w2[..., None, None]
        x1 = self.dsgn(xn)
        x2 = self.tsgn(xn)
        out = torch.cat((self.prelu1(w1 * x1), self.prelu2(w2 * x2)), dim=1)
        out = self.compressor(out)
        return out + self.residual(xn)
class ContextLayer(nn.Module):
    """Build a global context tensor from a whole sequence.

    Three pooled summaries of the input (double max, max after a convolution
    spanning ``input_seq`` rows, and spatial mean) are each mapped to
    ``output_seq`` features and concatenated. Two linear heads turn this
    vector into a per-joint vector and a per-frame vector; their outer
    product gives an ``(output_seq, joints)`` map that is normalised, lifted
    to ``dims`` channels and passed through a squeeze-and-excitation layer.

    Intermediate tensors of the last forward pass are kept as attributes
    (``joints``, ``displacements``, ``seq_joints``, ``seq_joints_n``,
    ``seq_joints_dims``).

    Args:
        in_ch: number of input channels.
        hidden_ch: channels of the three pooled summaries.
        output_seq: number of output frames.
        input_seq: kernel height of the second summary convolution.
        joints: size of the per-joint vector.
        dims: number of channels produced per frame and joint.
        reduction: reduction factor of the squeeze-and-excitation layers.
        dropout: drop probability.
    """

    def __init__(self,
                 in_ch,
                 hidden_ch,
                 output_seq,
                 input_seq,
                 joints,
                 dims=3,
                 reduction=8,
                 dropout=0.1,
                 ):
        super(ContextLayer, self).__init__()
        self.n_output = output_seq
        self.n_joints = joints
        self.n_input = input_seq

        def summary_conv(kernel):
            # Conv -> BN -> PReLU feeding one of the pooled summaries.
            return nn.Sequential(nn.Conv2d(in_ch, hidden_ch, kernel, bias=False),
                                 nn.BatchNorm2d(hidden_ch),
                                 nn.PReLU(),
                                 )

        def summary_map():
            # Projects one pooled summary to ``n_output`` features.
            return nn.Sequential(nn.Linear(hidden_ch, self.n_output, bias=False),
                                 nn.Dropout(dropout, inplace=True),
                                 nn.PReLU(),
                                 )

        def fused_map(width):
            # Maps the concatenated summaries to a vector of length ``width``.
            return nn.Sequential(nn.Linear(self.n_output * 3, width, bias=False),
                                 nn.BatchNorm1d(width),
                                 nn.Dropout(dropout, inplace=True),
                                 )

        def norm_stage():
            # One Conv1d -> BN -> Dropout -> PReLU stage of ``norm_map``.
            return [nn.Conv1d(self.n_output, self.n_output, 1, bias=False),
                    nn.BatchNorm1d(self.n_output),
                    nn.Dropout(dropout, inplace=True),
                    nn.PReLU(),
                    ]

        # Modules are created in a fixed order so that parameter
        # initialisation and registration order stay stable.
        self.context_conv1 = summary_conv(1)
        self.context_conv2 = summary_conv((input_seq, 1))
        self.context_conv3 = summary_conv(1)
        self.map1 = summary_map()
        self.map2 = summary_map()
        self.map3 = summary_map()
        self.fmap_s = fused_map(self.n_joints)
        self.fmap_t = fused_map(self.n_output)
        self.norm_map = nn.Sequential(*norm_stage(),
                                      SE.SELayer1d(self.n_output, reduction=reduction),
                                      *norm_stage(),
                                      )
        self.fconv = nn.Sequential(nn.Conv2d(1, dims, 1, bias=False),
                                   nn.BatchNorm2d(dims),
                                   nn.PReLU(),
                                   nn.Conv2d(dims, dims, 1, bias=False),
                                   nn.BatchNorm2d(dims),
                                   nn.PReLU(),
                                   )
        self.SE = SE.SELayer2d(self.n_output, reduction=reduction)

    def forward(self, x):
        """Return the context tensor ``(N, n_output, n_joints, dims)`` for a 4-D input."""
        batch, _, _, width = x.shape
        # Three pooled views of the input, each of shape (N, hidden_ch).
        peak = self.context_conv1(x).max(-1).values.max(-1).values
        seq_peak = self.context_conv2(x).view(batch, -1, width).max(-1).values
        average = self.context_conv3(x).mean((2, 3))
        summary = torch.cat((self.map1(peak), self.map2(seq_peak), self.map3(average)), dim=1)
        self.joints = self.fmap_s(summary)
        self.displacements = self.fmap_t(summary)
        # Outer product: (N, n_output, 1) x (N, 1, n_joints).
        self.seq_joints = torch.bmm(self.displacements.unsqueeze(2), self.joints.unsqueeze(1))
        self.seq_joints_n = self.norm_map(self.seq_joints)
        as_image = self.seq_joints_n.view(batch, 1, self.n_output, self.n_joints)
        self.seq_joints_dims = self.fconv(as_image)
        return self.SE(self.seq_joints_dims.permute(0, 2, 3, 1))
class MlpMixer_ext(nn.Module):
    """Sequence-to-sequence pose forecasting network.

    The observed sequence is augmented with finite-difference velocity,
    acceleration and the velocity norm (10 channels for 3-D coordinates),
    encoded by a stack of ``DSTD_GC`` blocks, extrapolated along the frame
    axis by ``FPN`` blocks, projected back to 3 channels, accumulated over
    frames, refined by a second ``DSTD_GC`` stack plus a ``ContextLayer``
    term, and finally added to the last observed frame.

    Args:
        arch: configuration object. The fields read from ``arch.model_params``
            are ``clipping``, ``input_n``, ``output_n``, ``joints``,
            ``n_txcnn_layers``, ``txc_kernel_size``, ``input_gcn``,
            ``output_gcn``, ``reduction`` and ``hidden_dim``. ``input_gcn``
            and ``output_gcn`` provide ``model_complexity`` (channel widths)
            and ``interpretable`` (one flag per block). The configuration is
            not modified.
        learn: configuration object providing ``dropout``.

    Shape:
        - Input[0]: Input sequence in :math:`(N, T_{in}, V, 3)` format
        - Output[0]: a 1-tuple holding the output sequence in
          :math:`(N, T_{out}, V, 3)` format
        where
            :math:`N` is a batch size,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
    """

    def __init__(self, arch, learn):
        super(MlpMixer_ext, self).__init__()
        self.clipping = arch.model_params.clipping
        self.n_input = arch.model_params.input_n
        self.n_output = arch.model_params.output_n
        self.n_joints = arch.model_params.joints
        self.n_txcnn_layers = arch.model_params.n_txcnn_layers
        self.txc_kernel_size = [arch.model_params.txc_kernel_size] * 2
        self.input_gcn = arch.model_params.input_gcn
        self.output_gcn = arch.model_params.output_gcn
        self.reduction = arch.model_params.reduction
        self.hidden_dim = arch.model_params.hidden_dim
        self.st_gcnns = nn.ModuleList()
        self.txcnns = nn.ModuleList()
        self.se = nn.ModuleList()
        self.in_conv = nn.ModuleList()
        self.context_layer = nn.ModuleList()
        self.trans = nn.ModuleList()
        self.in_ch = 10
        self.model_tx = self.input_gcn.model_complexity.copy()
        self.model_tx.insert(0, 1)  # add 1 in the position 0.
        # Build the channel schedules on private lists. Editing the config
        # lists in place made them grow each time a model was created from the
        # same ``arch``, so a second model got a different architecture.
        encoder_widths = [self.in_ch, *self.input_gcn.model_complexity, self.in_ch]
        for i in range(len(encoder_widths) - 1):
            self.st_gcnns.append(DSTD_GC(encoder_widths[i],
                                         encoder_widths[i + 1],
                                         self.input_gcn.interpretable[i],
                                         [1, 1], 1, self.n_input, self.n_joints, self.reduction, learn.dropout))
        self.context_layer = ContextLayer(1, self.hidden_dim,
                                          self.n_output, self.n_output, self.n_joints,
                                          3, self.reduction, learn.dropout
                                          )
        # at this point, we must permute the dimensions of the gcn network, from (N,C,T,V) into (N,T,C,V)
        # with kernel_size[3,3] the dimensions of C,V will be maintained
        self.txcnns.append(FPN(self.n_input, self.n_output, self.txc_kernel_size, 0., self.reduction))
        for _ in range(1, self.n_txcnn_layers):
            self.txcnns.append(FPN(self.n_output, self.n_output, self.txc_kernel_size, 0., self.reduction))
        self.prelus = nn.ModuleList()
        for _ in range(self.n_txcnn_layers):
            self.prelus.append(nn.PReLU())
        self.dim_conversor = nn.Sequential(nn.Conv2d(self.in_ch, 3, 1, bias=False),
                                           nn.BatchNorm2d(3),
                                           nn.PReLU(),
                                           nn.Conv2d(3, 3, 1, bias=False),
                                           nn.PReLU(3), )
        self.st_gcnns_o = nn.ModuleList()
        decoder_widths = [3, *self.output_gcn.model_complexity]
        for i in range(len(decoder_widths) - 1):
            # Joint and frame axes are swapped for the output stack (see forward).
            self.st_gcnns_o.append(DSTD_GC(decoder_widths[i],
                                           decoder_widths[i + 1],
                                           self.output_gcn.interpretable[i],
                                           [1, 1], 1, self.n_joints, self.n_output, self.reduction, learn.dropout))
        self.st_gcnns_o.apply(self._init_weights)
        self.st_gcnns.apply(self._init_weights)
        self.txcnns.apply(self._init_weights)

    def _init_weights(self, m, gain=0.1):
        """Xavier-initialise linear layers and reset PReLU slopes to 0.25."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight, gain=gain)
        # if isinstance(m, (nn.Conv2d, nn.Conv1d)):
        #     torch.nn.init.xavier_normal_(m.weight, gain=gain)
        if isinstance(m, nn.PReLU):
            torch.nn.init.constant_(m.weight, 0.25)

    def forward(self, x):
        """Predict the future sequence from the observed one.

        Args:
            x: tensor ``(N, T_in, V, 3)``.

        Returns:
            A 1-tuple whose only element is the prediction, expressed as an
            offset added to the last observed frame.
        """
        b, seq, joints, dim = x.shape
        # Finite differences along the frame axis; the last slot of each is
        # filled with the last frame of the tensor it was derived from.
        vel = torch.zeros_like(x)
        vel[:, :-1] = torch.diff(x, dim=1)
        vel[:, -1] = x[:, -1]
        acc = torch.zeros_like(x)
        acc[:, :-1] = torch.diff(vel, dim=1)
        acc[:, -1] = vel[:, -1]
        x1 = torch.cat((x, acc, vel, torch.norm(vel, dim=-1, keepdim=True)), dim=-1)
        x3 = x1.permute((0, 3, 1, 2))  # (N, T, V, C) -> (N, C, T, V)
        for layer in self.st_gcnns:
            x3 = layer(x3)
        x5 = x3.permute(0, 2, 1, 3)  # prepare the input for the Time-Extrapolator-CNN (NCTV->NTCV)
        x6 = self.prelus[0](self.txcnns[0](x5))
        for i in range(1, self.n_txcnn_layers):
            x6 = self.prelus[i](self.txcnns[i](x6)) + x6  # residual connection
        x6 = self.dim_conversor(x6.permute(0, 2, 1, 3)).permute(0, 2, 3, 1)
        # Accumulate the per-frame outputs along the frame axis.
        x7 = x6.cumsum(1)
        act = self.context_layer(x7.reshape(b, 1, self.n_output, joints * x7.shape[-1]))
        x8 = x7.permute(0, 3, 2, 1)
        for layer in self.st_gcnns_o:
            x8 = layer(x8)
        x9 = x8.permute(0, 3, 2, 1) + act
        # Callers receive a 1-tuple, as before.
        return (x[:, -1:] + x9,)