# Snapshot metadata: file size 1,576 bytes, revision 4e3cd77.
import torch
import torch.nn as nn
import numpy as np
from .utils import activations, forward_default, get_activation, Transpose
def forward_swin(pretrained, x):
    """Run the wrapped Swin backbone on ``x``.

    No Swin-specific logic lives here: the call is handed unchanged to
    ``forward_default`` and its result is returned as-is.
    """
    outputs = forward_default(pretrained, x)
    return outputs
def _make_swin_backbone(
    model,
    hooks=(1, 1, 17, 1),
    patch_grid=(96, 96),
):
    """Wrap a Swin-style ``model`` with activation hooks and reshape heads.

    Args:
        model: Backbone exposing ``model.layers[i].blocks[j]`` for the four
            stages ``i = 0..3``. If it has a ``patch_grid`` attribute, that
            value takes precedence over the ``patch_grid`` argument.
        hooks: For each of the four stages, the index of the block whose
            forward output is captured. Entry ``i`` is registered via
            ``get_activation(str(i + 1))``, i.e. under keys "1".."4".
            Only the first four entries are used; fewer than four raises
            ``IndexError``. (Tuple default avoids a shared mutable default;
            lists are still accepted.)
        patch_grid: Token-grid size for the first stage, used only when
            ``model`` has no ``patch_grid`` attribute.

    Returns:
        A fresh ``nn.Module`` holding ``model`` as ``.model``, the shared
        ``activations`` store as ``.activations``, and
        ``act_postprocess1`` .. ``act_postprocess4``. Each of these applies
        ``Transpose(1, 2)`` and then unflattens dim 2 into the grid divided by
        1, 2, 4 and 8 respectively.
    """
    num_stages = 4

    pretrained = nn.Module()
    pretrained.model = model

    # Hook block hooks[stage] of each stage. Indexing hooks[stage] (rather
    # than iterating hooks) keeps the original IndexError when fewer than
    # four entries are given.
    for stage in range(num_stages):
        block = pretrained.model.layers[stage].blocks[hooks[stage]]
        block.register_forward_hook(get_activation(str(stage + 1)))
    pretrained.activations = activations

    # A grid size declared on the model itself wins over the argument.
    used_patch_grid = getattr(model, "patch_grid", patch_grid)
    patch_grid_size = np.array(used_patch_grid, dtype=int)

    # NOTE(review): the divisors 1, 2, 4, 8 assume each stage halves the
    # spatial resolution of the previous one, and that hooked outputs are
    # (batch, tokens, channels) so Transpose(1, 2) puts tokens on dim 2
    # -- confirm against the concrete backbone.
    for stage in range(num_stages):
        stage_grid = patch_grid_size // (2 ** stage)
        setattr(
            pretrained,
            f"act_postprocess{stage + 1}",
            nn.Sequential(
                Transpose(1, 2),
                nn.Unflatten(2, torch.Size(stage_grid.tolist())),
            ),
        )
    return pretrained
|