File size: 26,750 Bytes
2f1078d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 |
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from ..layers import deformable_conv, SE
# NOTE(review): this seeds torch's global RNG at import time, so merely
# importing this module resets the random state for the whole process.
torch.manual_seed(0)
# A simple CNN layer that performs a 2-D convolution while preserving the input's spatial dimensions (only the feature/channel dimension changes).
class CNN_layer(nn.Module):
    """2-D convolution block that keeps the spatial size of its input.

    Applies ``Conv2d -> BatchNorm2d -> Dropout``. Padding is ``(k - 1) // 2``
    per axis, which preserves height and width for the odd kernel sizes
    enforced below; only the channel dimension changes.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: odd kernel size, either an int (square kernel) or a
            ``(kh, kw)`` pair.
        dropout: dropout probability applied after batch normalisation.
        bias: whether the convolution has a learnable bias.
    """

    def __init__(self,
                 in_ch,
                 out_ch,
                 kernel_size,
                 dropout,
                 bias=True):
        super(CNN_layer, self).__init__()
        # Accept an int like FPN does; pairs/lists are kept as given.
        if not isinstance(kernel_size, (tuple, list)):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = kernel_size
        # Odd kernels are required for symmetric "same" padding.
        assert kernel_size[0] % 2 == 1 and kernel_size[1] % 2 == 1
        padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
        self.block1 = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding,
                      dilation=(1, 1), bias=bias),
            nn.BatchNorm2d(out_ch),
            nn.Dropout(dropout, inplace=True),
        )

    def forward(self, x):
        """Map ``(N, in_ch, H, W)`` to ``(N, out_ch, H, W)``."""
        return self.block1(x)
class FPN(nn.Module):
    """Multi-dilation convolution block with a global-context branch.

    Three parallel ``Conv2d -> BatchNorm2d -> Dropout -> PReLU`` branches use
    dilations ``1``, ``1 + p`` and ``1 + 2p`` (``p = (k - 1) // 2`` per axis),
    each padded so the spatial size is preserved. A fourth branch is the
    global average of the input broadcast back to the input's spatial size.
    The four outputs are concatenated along channels and fused by a 1x1
    convolution.

    Args:
        in_ch: input channels.
        out_ch: output channels of each dilated branch and of the block.
        kernel: int or ``(kh, kw)`` kernel size, e.g. ``(3, 1)``.
        dropout: dropout probability inside each branch.
        reduction: unused; kept for signature compatibility.

    Shape:
        - Input: (N, in_ch, H, W)
        - Output: (N, out_ch, H, W) for odd kernel sizes.
    """

    def __init__(self, in_ch,
                 out_ch,
                 kernel,  # (3,1)
                 dropout,
                 reduction,
                 ):
        super(FPN, self).__init__()
        kernel_size = kernel if isinstance(kernel, (tuple, list)) else (kernel, kernel)
        base_pad = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
        dilations = (
            (1, 1),
            (1 + base_pad[0], 1 + base_pad[1]),
            (1 + 2 * base_pad[0], 1 + 2 * base_pad[1]),
        )
        branches = []
        for dil in dilations:
            # A kernel of size k dilated by d spans d * (k - 1) + 1 cells, so
            # padding d * (k - 1) / 2 keeps H and W unchanged for any odd k.
            pad = (dil[0] * base_pad[0], dil[1] * base_pad[1])
            branches.append(self._dilated_branch(in_ch, out_ch, kernel_size, pad, dil, dropout))
        self.block1, self.block2, self.block3 = branches
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))  # Action Context.
        # No activation here: the PReLU is applied outside this block.
        self.compress = nn.Conv2d(out_ch * 3 + in_ch,
                                  out_ch,
                                  kernel_size=(1, 1))

    @staticmethod
    def _dilated_branch(in_ch, out_ch, kernel_size, padding, dilation, dropout):
        """One ``Conv2d -> BatchNorm2d -> Dropout -> PReLU`` branch."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding, dilation=dilation),
            nn.BatchNorm2d(out_ch),
            nn.Dropout(dropout, inplace=True),
            nn.PReLU(),
        )

    def forward(self, x):
        height, width = x.shape[2], x.shape[3]
        # Broadcasting the 1x1 pooled map equals nearest upsampling to (H, W).
        global_action = self.pooling(x).expand(-1, -1, height, width)
        out = torch.cat((self.block1(x), self.block2(x), self.block3(x), global_action), dim=1)
        return self.compress(out)
def mish(x):
    """Mish activation, computed as ``x * tanh(softplus(x))``."""
    smooth = F.softplus(x)
    return x * smooth.tanh()
class ConvTemporalGraphical(nn.Module):
    # Source : https://github.com/yysijie/st-gcn/blob/master/net/st_gcn.py
    r"""Graph convolution implemented as an einsum against an adjacency tensor.

    Args:
        time_dim: length ``T`` of the time axis.
        joints_dim: number ``V`` of graph nodes.
        domain: ``"time"`` mixes nodes within each time step using an
            adjacency of shape ``(T, V, V)``; ``"space"`` mixes time steps
            within each node using an adjacency of shape ``(V, T, T)``.
        interpratable: if False, the adjacency is a learnable parameter
            created and initialised here. If True, no parameter is created and
            the caller must assign a per-sample adjacency to ``self.A``
            (``(N, T, V, V)`` for "time", ``(N, V, T, T)`` for "space") before
            calling ``forward``.

    Raises:
        ValueError: if ``domain`` is neither ``"time"`` nor ``"space"``.

    Shape:
        - Input: Input graph sequence in :math:`(N, C, T, V)` format
        - Output: Output graph sequence in :math:`(N, C, T, V)` format
        where
            :math:`N` is a batch size,
            :math:`T` is the length of the sequence,
            :math:`V` is the number of graph nodes.
    """

    def __init__(self, time_dim, joints_dim, domain, interpratable):
        super(ConvTemporalGraphical, self).__init__()
        if domain == "time":
            # Learnable, graph-agnostic 3-D adjacency (edge importance) matrix.
            a_shape = (time_dim, joints_dim, joints_dim)
            self.domain = 'nctv,ntvw->nctw' if interpratable else 'nctv,tvw->nctw'
        elif domain == "space":
            a_shape = (joints_dim, time_dim, time_dim)
            self.domain = 'nctv,nvtq->ncqv' if interpratable else 'nctv,vtq->ncqv'
        else:
            raise ValueError(f"domain must be 'time' or 'space', got {domain!r}")
        if not interpratable:
            self.A = nn.Parameter(torch.FloatTensor(*a_shape))
            stdv = 1. / math.sqrt(a_shape[1])
            self.A.data.uniform_(-stdv, stdv)

    def forward(self, x):
        return torch.einsum(self.domain, x, self.A).contiguous()
class Map2Adj(nn.Module):
    """Predict a per-sample adjacency tensor from a feature map.

    ``time_compress`` collapses axis 2 with a ``(time_dim, 1)`` kernel and
    ``joint_compress`` collapses axis 3 with a ``(1, joints_dim)`` kernel. The
    two results are multiplied with ``torch.matmul`` (permuted according to
    ``domain``) and refined by ``expansor``.

    Args:
        in_ch: channels of the input feature map.
        time_dim: size of axis 2 of the input.
        joints_dim: size of axis 3 of the input.
        domain: ``"space"`` or ``"time"``; selects the output layout.
        dropout: dropout probability.

    Raises:
        ValueError: if ``domain`` is neither ``"space"`` nor ``"time"``.

    Shape:
        - Input: (N, in_ch, time_dim, joints_dim)
        - Output: (N, joints_dim, time_dim, time_dim) for "space",
          (N, time_dim, joints_dim, joints_dim) for "time".
    """

    def __init__(self,
                 in_ch,
                 time_dim,
                 joints_dim,
                 domain,
                 dropout,
                 ):
        super(Map2Adj, self).__init__()
        if domain not in ("space", "time"):
            raise ValueError(f"domain must be 'space' or 'time', got {domain!r}")
        self.domain = domain
        # Bottleneck width, clamped to 1 so in_ch == 1 does not produce
        # zero-channel convolutions.
        inter_ch = max(in_ch // 2, 1)
        self.time_compress = self._compress_branch(in_ch, inter_ch, (time_dim, 1), time_dim, dropout)
        self.joint_compress = self._compress_branch(in_ch, inter_ch, (1, joints_dim), joints_dim, dropout)
        if self.domain == "space":
            ch = joints_dim
            self.perm1 = (0, 1, 2, 3)
            self.perm2 = (0, 3, 2, 1)
        else:
            ch = time_dim
            self.perm1 = (0, 2, 1, 3)
            self.perm2 = (0, 1, 2, 3)
        self.expansor = nn.Sequential(nn.Conv2d(ch, ch, kernel_size=1, bias=False),
                                      nn.BatchNorm2d(ch),
                                      nn.Dropout(dropout, inplace=True),
                                      nn.PReLU(),
                                      nn.Conv2d(ch, ch, kernel_size=1, bias=False),
                                      )
        for branch in (self.time_compress, self.joint_compress, self.expansor):
            branch.apply(self._init_weights)

    @staticmethod
    def _compress_branch(in_ch, inter_ch, kernel, out_ch, dropout):
        """1x1 reduce -> axis-collapsing conv -> 1x1 projection to ``out_ch``."""
        return nn.Sequential(nn.Conv2d(in_ch, inter_ch, kernel_size=1, bias=False),
                             nn.BatchNorm2d(inter_ch),
                             nn.PReLU(),
                             nn.Conv2d(inter_ch, inter_ch, kernel_size=kernel, bias=False),
                             nn.BatchNorm2d(inter_ch),
                             nn.Dropout(dropout, inplace=True),
                             nn.Conv2d(inter_ch, out_ch, kernel_size=1, bias=False),
                             )

    def _init_weights(self, m, gain=0.05):
        """Small-gain Xavier init for Linear/Conv weights; PReLU slopes to 0.25."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight, gain=gain)
        if isinstance(m, (nn.Conv2d, nn.Conv1d)):
            torch.nn.init.xavier_normal_(m.weight, gain=gain)
        if isinstance(m, nn.PReLU):
            torch.nn.init.constant_(m.weight, 0.25)

    def forward(self, x):
        dim_seq = self.time_compress(x)      # (N, time_dim, 1, joints_dim)
        dim_space = self.joint_compress(x)   # (N, joints_dim, time_dim, 1)
        o = torch.matmul(dim_space.permute(self.perm1), dim_seq.permute(self.perm2))
        return self.expansor(o)
class Domain_GCNN_layer(nn.Module):
    """Graph convolution along one domain, then a 2-D conv, with a residual.

    ``forward``: ``x -> gcn (einsum) -> tcn (Conv2d + BatchNorm2d + Dropout)``,
    plus ``residual(x)``, followed by PReLU.

    When ``interpratable`` is True, the adjacency used by ``gcn`` is predicted
    from the input by ``Map2Adj`` on every forward pass and stored on
    ``self.Adj``. Otherwise ``gcn`` uses its own learned parameter, and
    ``self.Adj`` is simply the input (``map_to_adj`` is an identity).

    Args:
        in_ch: input channels.
        out_ch: output channels.
        kernel_size: odd ``(kh, kw)`` kernel of the ``tcn`` convolution.
        stride: stride applied on both spatial axes by ``tcn`` and the
            residual projection.
        time_dim: size of axis 2 of the input.
        joints_dim: size of axis 3 of the input.
        domain: ``"space"`` or ``"time"``, forwarded to the graph conv.
        interpratable: use an input-dependent adjacency (see above).
        dropout: dropout probability.
        bias: whether the ``tcn`` convolution has a bias.

    Shape:
        - Input: (N, in_ch, time_dim, joints_dim)
        - Output: (N, out_ch, ceil(time_dim / stride), ceil(joints_dim / stride))
    """

    def __init__(self,
                 in_ch,
                 out_ch,
                 kernel_size,
                 stride,
                 time_dim,
                 joints_dim,
                 domain,
                 interpratable,
                 dropout,
                 bias=True):
        super(Domain_GCNN_layer, self).__init__()
        self.kernel_size = kernel_size
        # Odd kernels keep the spatial size with symmetric padding.
        assert self.kernel_size[0] % 2 == 1
        assert self.kernel_size[1] % 2 == 1
        padding = ((self.kernel_size[0] - 1) // 2, (self.kernel_size[1] - 1) // 2)
        self.interpratable = interpratable
        self.domain = domain
        self.gcn = ConvTemporalGraphical(time_dim, joints_dim, domain, interpratable)
        self.tcn = nn.Sequential(nn.Conv2d(in_ch,
                                           out_ch,
                                           (self.kernel_size[0], self.kernel_size[1]),
                                           (stride, stride),
                                           padding,
                                           bias=bias,
                                           ),
                                 nn.BatchNorm2d(out_ch),
                                 nn.Dropout(dropout, inplace=True),
                                 )
        if stride != 1 or in_ch != out_ch:
            # Stride the skip path like ``tcn`` so ``x2 + res`` shapes match.
            self.residual = nn.Sequential(nn.Conv2d(in_ch,
                                                    out_ch,
                                                    kernel_size=1,
                                                    stride=(stride, stride)),
                                          nn.BatchNorm2d(out_ch),
                                          )
        else:
            self.residual = nn.Identity()
        if self.interpratable:
            self.map_to_adj = Map2Adj(in_ch,
                                      time_dim,
                                      joints_dim,
                                      domain,
                                      dropout,
                                      )
        else:
            self.map_to_adj = nn.Identity()
        self.prelu = nn.PReLU()

    def forward(self, x):
        res = self.residual(x)
        self.Adj = self.map_to_adj(x)
        if self.interpratable:
            # Inject the per-sample adjacency predicted for this batch.
            self.gcn.A = self.Adj
        x1 = self.gcn(x)
        x2 = self.tcn(x1)
        return self.prelu(x2 + res)
# Dynamic SpatioTemporal Decompose Graph Convolutions (DSTD-GC)
class DSTD_GC(nn.Module):
    """Gated fusion of a "space" and a "time" Domain_GCNN_layer.

    ``forward``:
      1. ``xn = global_norm(x)`` (BatchNorm2d over ``in_ch``).
      2. Two channel gates (each (N, out_ch)) are predicted from ``xn``:
         ``conv_s``/``conv_t`` collapse the time and joint axes, their output
         is concatenated with summary statistics of ``xn`` and passed through
         ``map_s``/``map_t``. The gates are stored as ``self.w1``/``self.w2``.
      3. The space branch (``dsgn``) and time branch (``tsgn``) are scaled by
         their gates, passed through BatchNorm2d + PReLU, concatenated and
         fused by ``compressor``.
      4. A residual projection of ``xn`` is added and the result is clipped
         to ``[-1e5, 1e5]``.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        interpratable: forwarded to both Domain_GCNN_layer branches.
        kernel_size: odd ``(kh, kw)`` kernel for the branches' convolutions.
        stride: stride of the branches' convolutions.
        time_dim: size of axis 2 of the input.
        joints_dim: size of axis 3 of the input.
        reduction: ``reduction`` argument of the SE layer in ``compressor``.
        dropout: dropout probability.

    Shape:
        - Input: (N, in_ch, time_dim, joints_dim)
        - Output: (N, out_ch, T_out, V_out); same spatial size when stride == 1.
    """

    def __init__(self,
                 in_ch,
                 out_ch,
                 interpratable,
                 kernel_size,
                 stride,
                 time_dim,
                 joints_dim,
                 reduction,
                 dropout):
        super(DSTD_GC, self).__init__()
        self.dsgn = Domain_GCNN_layer(in_ch, out_ch, kernel_size, stride,
                                      time_dim, joints_dim, "space", interpratable, dropout)
        self.tsgn = Domain_GCNN_layer(in_ch, out_ch, kernel_size, stride,
                                      time_dim, joints_dim, "time", interpratable, dropout)
        self.compressor = nn.Sequential(nn.Conv2d(out_ch * 2, out_ch, 1, bias=False),
                                        nn.BatchNorm2d(out_ch),
                                        nn.PReLU(),
                                        SE.SELayer2d(out_ch, reduction=reduction),
                                        )
        if stride != 1 or in_ch != out_ch:
            # Stride the skip path like the branches, otherwise the residual
            # sum has mismatched shapes for stride != 1.
            self.residual = nn.Sequential(nn.Conv2d(in_ch,
                                                    out_ch,
                                                    kernel_size=1,
                                                    stride=(stride, stride)),
                                          nn.BatchNorm2d(out_ch),
                                          )
        else:
            self.residual = nn.Identity()
        # Weighting features (creation order kept stable for seeded runs).
        out_ch_c = max(out_ch // 2, 1)
        self.global_norm = nn.BatchNorm2d(in_ch)
        self.conv_s = self._gate_encoder(in_ch, out_ch_c, out_ch, time_dim, joints_dim, dropout)
        self.conv_t = self._gate_encoder(in_ch, out_ch_c, out_ch, time_dim, joints_dim, dropout)
        # Gate MLP input = encoder features (out_ch) + stats (2 + 2 * time_dim).
        self.map_s = self._gate_mlp(out_ch + 2 + time_dim * 2, out_ch, dropout)
        self.map_t = self._gate_mlp(out_ch + 2 + time_dim * 2, out_ch, dropout)
        self.prelu1 = self._bn_prelu(out_ch)
        self.prelu2 = self._bn_prelu(out_ch)

    @staticmethod
    def _gate_encoder(in_ch, mid_ch, out_ch, time_dim, joints_dim, dropout):
        """Collapse axis 2 then axis 3: (N, in_ch, time_dim, joints_dim) -> (N, out_ch, 1, 1)."""
        return nn.Sequential(nn.Conv2d(in_ch, mid_ch, (time_dim, 1), bias=False),
                             nn.BatchNorm2d(mid_ch),
                             nn.Dropout(dropout, inplace=True),
                             nn.PReLU(),
                             nn.Conv2d(mid_ch, out_ch, (1, joints_dim), bias=False),
                             nn.BatchNorm2d(out_ch),
                             nn.Dropout(dropout, inplace=True),
                             nn.PReLU(),
                             )

    @staticmethod
    def _gate_mlp(in_features, out_ch, dropout):
        """Two-layer MLP producing one gate value per output channel."""
        return nn.Sequential(nn.Linear(in_features, out_ch, bias=False),
                             nn.BatchNorm1d(out_ch),
                             nn.Dropout(dropout, inplace=True),
                             nn.PReLU(),
                             nn.Linear(out_ch, out_ch, bias=False),
                             )

    @staticmethod
    def _bn_prelu(ch):
        """BatchNorm2d followed by PReLU."""
        return nn.Sequential(nn.BatchNorm2d(ch),
                             nn.PReLU(),
                             )

    def _get_stats_(self, x):
        """Summary statistics of ``x`` (N, C, T, V) -> (N, 2 + 2 * T).

        Columns: global mean (1), per-step mean (T), std-of-std over channels
        (1), per-step std over channels (T).
        NOTE: with C == 1 the std over dim 1 is NaN (unbiased estimator over a
        single element).
        """
        global_avg_pool = x.mean((3, 2)).mean(1, keepdim=True)
        global_avg_pool_features = x.mean(3).mean(1)
        global_std_pool = x.std((3, 2)).std(1, keepdim=True)
        global_std_pool_features = x.std(3).std(1)
        return torch.cat((
            global_avg_pool,
            global_avg_pool_features,
            global_std_pool,
            global_std_pool_features,
        ),
            dim=1)

    def forward(self, x):
        b = x.shape[0]
        xn = self.global_norm(x)
        # Both gates use the same statistics of xn; compute them once.
        stats = self._get_stats_(xn)
        w1 = torch.cat((self.conv_s(xn).view(b, -1), stats), dim=1)
        w2 = torch.cat((self.conv_t(xn).view(b, -1), stats), dim=1)
        # Gates are kept on the module so they can be inspected after forward.
        self.w1 = self.map_s(w1)
        self.w2 = self.map_t(w2)
        x1 = self.dsgn(xn)
        x2 = self.tsgn(xn)
        out = torch.cat((self.prelu1(self.w1[..., None, None] * x1),
                         self.prelu2(self.w2[..., None, None] * x2)), dim=1)
        out = self.compressor(out)
        return torch.clip(out + self.residual(xn), -1e5, 1e5)
class ContextLayer(nn.Module):
    """Predict a (N, output_seq, joints, dims) context tensor from a feature map.

    The input is summarised three ways, each giving (N, hidden_ch):
      * a 1x1 conv followed by a max over both spatial axes,
      * a conv spanning ``input_seq`` rows followed by a max over the last axis,
      * a 1x1 conv followed by a mean over both spatial axes.
    Each summary is mapped to ``output_seq`` features. The concatenation
    yields a per-joint vector (``fmap_s``) and a per-step vector (``fmap_t``).
    Their outer product is refined by ``norm_map``, lifted to ``dims`` channels
    by ``fconv`` and passed through an SE layer.

    Intermediates are stored on the module after ``forward``: ``joints``,
    ``displacements``, ``seq_joints``, ``seq_joints_n``, ``seq_joints_dims``.

    Shape:
        - Input: (N, in_ch, input_seq, W) for any W
        - Output: (N, output_seq, joints, dims) going into the SE layer; the SE
          layer presumably preserves this shape (TODO confirm in SE.SELayer2d).
    """

    def __init__(self,
                 in_ch,
                 hidden_ch,
                 output_seq,
                 input_seq,
                 joints,
                 dims=3,
                 reduction=8,
                 dropout=0.1,
                 ):
        super(ContextLayer, self).__init__()
        self.n_output = output_seq
        self.n_joints = joints
        self.n_input = input_seq

        def conv_bn_prelu(kernel):
            return nn.Sequential(nn.Conv2d(in_ch, hidden_ch, kernel, bias=False),
                                 nn.BatchNorm2d(hidden_ch),
                                 nn.PReLU())

        def summary_head():
            return nn.Sequential(nn.Linear(hidden_ch, output_seq, bias=False),
                                 nn.Dropout(dropout, inplace=True),
                                 nn.PReLU())

        def fused_head(out_features):
            return nn.Sequential(nn.Linear(output_seq * 3, out_features, bias=False),
                                 nn.BatchNorm1d(out_features),
                                 nn.Dropout(dropout, inplace=True))

        def pointwise_1d():
            return [nn.Conv1d(output_seq, output_seq, 1, bias=False),
                    nn.BatchNorm1d(output_seq),
                    nn.Dropout(dropout, inplace=True),
                    nn.PReLU()]

        # Creation order matches the module layout (and seeded initialisation).
        self.context_conv1 = conv_bn_prelu(1)
        self.context_conv2 = conv_bn_prelu((input_seq, 1))
        self.context_conv3 = conv_bn_prelu(1)
        self.map1 = summary_head()
        self.map2 = summary_head()
        self.map3 = summary_head()
        self.fmap_s = fused_head(joints)
        self.fmap_t = fused_head(output_seq)
        self.norm_map = nn.Sequential(*pointwise_1d(),
                                      SE.SELayer1d(output_seq, reduction=reduction),
                                      *pointwise_1d())
        self.fconv = nn.Sequential(nn.Conv2d(1, dims, 1, bias=False),
                                   nn.BatchNorm2d(dims),
                                   nn.PReLU(),
                                   nn.Conv2d(dims, dims, 1, bias=False),
                                   nn.BatchNorm2d(dims),
                                   nn.PReLU())
        self.SE = SE.SELayer2d(output_seq, reduction=reduction)

    def forward(self, x):
        batch, _, _, width = x.shape
        pooled_max = self.context_conv1(x).max(-1).values.max(-1).values
        pooled_seq = self.context_conv2(x).view(batch, -1, width).max(-1).values
        pooled_mean = self.context_conv3(x).mean((2, 3))
        summary = torch.cat([self.map1(pooled_max),
                             self.map2(pooled_seq),
                             self.map3(pooled_mean)], dim=1)
        self.joints = self.fmap_s(summary)
        self.displacements = self.fmap_t(summary)
        # Outer product: (N, output_seq, 1) x (N, 1, joints) -> (N, output_seq, joints).
        step_col = self.displacements.unsqueeze(2)
        joint_row = self.joints.unsqueeze(1)
        self.seq_joints = torch.bmm(step_col, joint_row)
        self.seq_joints_n = self.norm_map(self.seq_joints)
        grid = self.seq_joints_n.view(batch, 1, self.n_output, self.n_joints)
        self.seq_joints_dims = self.fconv(grid)
        channels_last = self.seq_joints_dims.permute(0, 2, 3, 1)
        return self.SE(channels_last)
class CISTGCN(nn.Module):
    """CIST-GCN sequence model over graph joints.

    Shape:
        - Input ``x``: (N, T_in, V, D), with ``T_in == input_n`` and
          ``V == joints``. ``forward`` builds ``3 * D + 1`` feature channels
          (position, acceleration, velocity, speed), which must equal
          ``self.in_ch`` (10), i.e. ``D == 3``.
        - Output: a 1-tuple whose element is (N, T_out, V, 3) with
          ``T_out == output_n``. It is the last input frame plus the
          cumulatively summed prediction and the context-layer term.

    Args:
        arch: config whose ``model_params`` supplies ``clipping``,
            ``input_n``, ``output_n``, ``joints``, ``n_txcnn_layers``,
            ``txc_kernel_size``, ``input_gcn``/``output_gcn`` (each with
            ``model_complexity`` and ``interpretable`` sequences),
            ``reduction`` and ``hidden_dim``. The config is not modified.
        learn: config providing ``dropout``.
    """

    def __init__(self, arch, learn):
        super(CISTGCN, self).__init__()
        params = arch.model_params
        self.clipping = params.clipping
        self.n_input = params.input_n
        self.n_output = params.output_n
        self.n_joints = params.joints
        self.n_txcnn_layers = params.n_txcnn_layers
        self.txc_kernel_size = [params.txc_kernel_size] * 2
        self.input_gcn = params.input_gcn
        self.output_gcn = params.output_gcn
        self.reduction = params.reduction
        self.hidden_dim = params.hidden_dim
        # Container registration order is kept (it fixes parameters() order,
        # which optimizer checkpoints depend on), including the placeholders.
        self.st_gcnns = nn.ModuleList()
        self.txcnns = nn.ModuleList()
        self.se = nn.ModuleList()
        self.in_conv = nn.ModuleList()
        self.context_layer = nn.ModuleList()
        self.trans = nn.ModuleList()
        # 3 * D + 1 feature channels built in forward (assumes D == 3).
        self.in_ch = 10
        # Build local width lists instead of inserting into the config's lists:
        # in-place mutation made every further construction from the same
        # `arch` see already-extended widths (extra layers, wrong channels).
        self.model_tx = [1, *self.input_gcn.model_complexity]
        in_widths = [self.in_ch, *self.input_gcn.model_complexity, self.in_ch]
        for i, (c_in, c_out) in enumerate(zip(in_widths[:-1], in_widths[1:])):
            self.st_gcnns.append(DSTD_GC(c_in,
                                         c_out,
                                         self.input_gcn.interpretable[i],
                                         [1, 1], 1, self.n_input, self.n_joints, self.reduction, learn.dropout))
        self.context_layer = ContextLayer(1, self.hidden_dim,
                                          self.n_output, self.n_output, self.n_joints,
                                          3, self.reduction, learn.dropout
                                          )
        # At this point the GCN output is permuted from (N,C,T,V) to (N,T,C,V);
        # the time-extrapolator CNN keeps the C and V dimensions.
        self.txcnns.append(FPN(self.n_input, self.n_output, self.txc_kernel_size, 0., self.reduction))
        for _ in range(1, self.n_txcnn_layers):
            self.txcnns.append(FPN(self.n_output, self.n_output, self.txc_kernel_size, 0., self.reduction))
        self.prelus = nn.ModuleList(nn.PReLU() for _ in range(self.n_txcnn_layers))
        self.dim_conversor = nn.Sequential(nn.Conv2d(self.in_ch, 3, 1, bias=False),
                                           nn.BatchNorm2d(3),
                                           nn.PReLU(),
                                           nn.Conv2d(3, 3, 1, bias=False),
                                           nn.PReLU(3), )
        self.st_gcnns_o = nn.ModuleList()
        # The output GCN stack consumes the 3 channels produced by dim_conversor.
        out_widths = [3, *self.output_gcn.model_complexity]
        for i, (c_in, c_out) in enumerate(zip(out_widths[:-1], out_widths[1:])):
            self.st_gcnns_o.append(DSTD_GC(c_in,
                                           c_out,
                                           self.output_gcn.interpretable[i],
                                           [1, 1], 1, self.n_joints, self.n_output, self.reduction, learn.dropout))
        self.st_gcnns_o.apply(self._init_weights)
        self.st_gcnns.apply(self._init_weights)
        self.txcnns.apply(self._init_weights)

    def _init_weights(self, m, gain=0.1):
        """Xavier-uniform init for Linear weights; reset PReLU slopes to 0.25."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight, gain=gain)
        if isinstance(m, nn.PReLU):
            torch.nn.init.constant_(m.weight, 0.25)

    def forward(self, x):
        b, _, joints, _ = x.shape
        # Finite differences along time. The last frame has no successor; it
        # is filled with the position (vel) and with that filled velocity (acc).
        vel = torch.zeros_like(x)
        vel[:, :-1] = torch.diff(x, dim=1)
        vel[:, -1] = x[:, -1]
        acc = torch.zeros_like(x)
        acc[:, :-1] = torch.diff(vel, dim=1)
        acc[:, -1] = vel[:, -1]
        feats = torch.cat((x, acc, vel, torch.norm(vel, dim=-1, keepdim=True)), dim=-1)
        h = feats.permute((0, 3, 1, 2))  # (N, C, T, V)
        for gcn in self.st_gcnns:
            h = gcn(h)
        h = h.permute(0, 2, 1, 3)  # NCTV -> NTCV for the Time-Extrapolator-CNN
        h = self.prelus[0](self.txcnns[0](h))
        for i in range(1, self.n_txcnn_layers):
            h = self.prelus[i](self.txcnns[i](h)) + h  # residual connection
        deltas = self.dim_conversor(h.permute(0, 2, 1, 3)).permute(0, 2, 3, 1)
        traj = deltas.cumsum(1)
        act = self.context_layer(traj.reshape(b, 1, self.n_output, joints * traj.shape[-1]))
        h = traj.permute(0, 3, 2, 1)
        for gcn in self.st_gcnns_o:
            h = gcn(h)
        out = h.permute(0, 3, 2, 1) + act
        # Returns a 1-tuple (trailing comma), which callers presumably index.
        return x[:, -1:] + out,
|