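# Group-wise building blocks operating on five-dimensional tensors of shape
# (batch, num_objects, channels, H, W): each module folds the object axis into
# the batch axis before applying standard 2-D operations, then restores it.
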
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from tracker.model.channel_attn import CAResBlock

def interpolate_groups(g: torch.Tensor, ratio: float, mode: str,
                       align_corners: Optional[bool]) -> torch.Tensor:
    # Fold the object axis into the batch axis so F.interpolate sees a
    # standard 4-D (N, C, H, W) tensor, then restore the group shape.
    batch_size, num_objects = g.shape[:2]
    g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
                      scale_factor=ratio,
                      mode=mode,
                      align_corners=align_corners)
    g = g.view(batch_size, num_objects, *g.shape[1:])
    return g

def upsample_groups(g: torch.Tensor,
                    ratio: float = 2,
                    mode: str = 'bilinear',
                    align_corners: bool = False) -> torch.Tensor:
    return interpolate_groups(g, ratio, mode, align_corners)

def downsample_groups(g: torch.Tensor,
                      ratio: float = 1 / 2,
                      mode: str = 'area',
                      align_corners: Optional[bool] = None) -> torch.Tensor:
    return interpolate_groups(g, ratio, mode, align_corners)

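# Usage sketch (not part of the original file; the shapes below are assumptions
# for illustration). Dim 1 is the object axis, so spatial size changes while
# batch and object dims stay intact:
#   g = torch.randn(2, 3, 64, 32, 32)      # (batch, num_objects, C, H, W)
#   upsample_groups(g).shape    -> (2, 3, 64, 64, 64)
#   downsample_groups(g).shape  -> (2, 3, 64, 16, 16)
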
class GConv2d(nn.Conv2d):
    """A Conv2d applied independently to every object in a group tensor."""
    def forward(self, g: torch.Tensor) -> torch.Tensor:
        batch_size, num_objects = g.shape[:2]
        # Merge batch and object dims, convolve, then split them back out.
        g = super().forward(g.flatten(start_dim=0, end_dim=1))
        return g.view(batch_size, num_objects, *g.shape[1:])

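# Usage sketch (hypothetical dimensions, for illustration only):
#   conv = GConv2d(64, 128, kernel_size=3, padding=1)
#   conv(torch.randn(2, 3, 64, 32, 32)).shape  -> (2, 3, 128, 32, 32)
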
class GroupResBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()

        if in_dim == out_dim:
            self.downsample = nn.Identity()
        else:
            # 1x1 projection so the residual connection matches out_dim.
            self.downsample = GConv2d(in_dim, out_dim, kernel_size=1)

        self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1)
        self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1)

    def forward(self, g: torch.Tensor) -> torch.Tensor:
        # Pre-activation residual block: ReLU -> conv -> ReLU -> conv, plus skip.
        out_g = self.conv1(F.relu(g))
        out_g = self.conv2(F.relu(out_g))

        g = self.downsample(g)

        return out_g + g

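# Usage sketch (shapes are assumptions):
#   block = GroupResBlock(64, 128)
#   block(torch.randn(2, 3, 64, 32, 32)).shape  -> (2, 3, 128, 32, 32)
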
class MainToGroupDistributor(nn.Module):
    """Broadcasts a shared feature map x to every object and fuses it with g."""
    def __init__(self,
                 x_transform: Optional[nn.Module] = None,
                 g_transform: Optional[nn.Module] = None,
                 method: str = 'cat',
                 reverse_order: bool = False):
        super().__init__()

        self.x_transform = x_transform
        self.g_transform = g_transform
        self.method = method
        self.reverse_order = reverse_order

    def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor:
        num_objects = g.shape[1]

        if self.x_transform is not None:
            x = self.x_transform(x)
        if self.g_transform is not None:
            g = self.g_transform(g)

        if not skip_expand:
            # Replicate the shared features along the object dimension.
            x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
        if self.method == 'cat':
            if self.reverse_order:
                g = torch.cat([g, x], 2)
            else:
                g = torch.cat([x, g], 2)
        elif self.method == 'add':
            g = x + g
        elif self.method == 'mulcat':
            g = torch.cat([x * g, g], dim=2)
        elif self.method == 'muladd':
            g = x * g + g
        else:
            raise NotImplementedError(f'Unknown fusion method: {self.method}')

        return g

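# Usage sketch (shapes are assumptions): with the default method='cat', the
# channel dims of x and g are concatenated per object.
#   dist = MainToGroupDistributor()
#   x = torch.randn(2, 64, 32, 32)       # shared features (B, C, H, W)
#   g = torch.randn(2, 3, 32, 32, 32)    # group features (B, N, C', H, W)
#   dist(x, g).shape  -> (2, 3, 96, 32, 32)
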
class GroupFeatureFusionBlock(nn.Module):
    def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int):
        super().__init__()

        # Project both inputs to out_dim so they can be fused by addition.
        x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1)
        g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1)

        self.distributor = MainToGroupDistributor(x_transform=x_transform,
                                                  g_transform=g_transform,
                                                  method='add')
        self.block1 = CAResBlock(out_dim, out_dim)
        self.block2 = CAResBlock(out_dim, out_dim)

    def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        batch_size, num_objects = g.shape[:2]

        g = self.distributor(x, g)

        # Channel-attention residual blocks run on the flattened (B*N) batch.
        g = g.flatten(start_dim=0, end_dim=1)
        g = self.block1(g)
        g = self.block2(g)
        g = g.view(batch_size, num_objects, *g.shape[1:])

        return g
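

# Usage sketch (dims are assumptions; CAResBlock is taken to preserve spatial size):
#   fuse = GroupFeatureFusionBlock(x_in_dim=1024, g_in_dim=256, out_dim=256)
#   x = torch.randn(2, 1024, 32, 32)
#   g = torch.randn(2, 3, 256, 32, 32)
#   fuse(x, g).shape  -> (2, 3, 256, 32, 32)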