File size: 1,576 Bytes
4e3cd77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch

import torch.nn as nn
import numpy as np

from .utils import activations, forward_default, get_activation, Transpose


def forward_swin(pretrained, x):
    """Run the Swin backbone wrapper on ``x``.

    This is a named entry point for the Swin backbone. All work is delegated
    to ``forward_default``, and its result is returned unchanged.

    Args:
        pretrained: Module produced by ``_make_swin_backbone``.
        x: Input passed straight through to ``forward_default``.

    Returns:
        Whatever ``forward_default(pretrained, x)`` returns.
    """
    features = forward_default(pretrained, x)
    return features


def _make_swin_backbone(
        model,
        hooks=(1, 1, 17, 1),
        patch_grid=(96, 96)
):
    """Wrap a Swin model so four intermediate block outputs can be captured.

    A forward hook is registered on one block per stage. The hooked block is
    ``model.layers[i].blocks[hooks[i]]`` for ``i`` in 0..3, and the hooks come
    from ``get_activation("1")`` .. ``get_activation("4")``.

    Four post-processing modules are also attached. Each applies
    ``Transpose(1, 2)`` and then unflattens dim 2 into a spatial grid. That
    grid is the base grid floor-divided by 1, 2, 4 and 8 for stages 1..4.

    Args:
        model: Swin model exposing ``layers[i].blocks`` for ``i`` in 0..3.
            If it has a ``patch_grid`` attribute, that value is used instead
            of the ``patch_grid`` argument.
        hooks: Block index to hook within each of the four stages. Only the
            first four entries are read; fewer than four raises
            ``IndexError``.
        patch_grid: Base ``(h, w)`` grid, used only when ``model`` has no
            ``patch_grid`` attribute.

    Returns:
        An ``nn.Module`` with these attributes:
        ``model``, ``activations`` and ``act_postprocess1`` ..
        ``act_postprocess4``.
    """
    # Defaults are tuples (not lists) so no mutable object is shared between
    # calls; indexing and ``np.array`` treat them exactly like lists.
    pretrained = nn.Module()

    pretrained.model = model
    # One hook per stage; the keys "1".."4" line up with act_postprocess1..4.
    for stage in range(4):
        block = pretrained.model.layers[stage].blocks[hooks[stage]]
        block.register_forward_hook(get_activation(str(stage + 1)))

    pretrained.activations = activations

    # A grid stored on the model takes precedence over the argument.
    used_patch_grid = getattr(model, "patch_grid", patch_grid)
    patch_grid_size = np.array(used_patch_grid, dtype=int)

    # NOTE(review): presumably each hooked output is (batch, tokens, channels)
    # with tokens laid out row-major over the grid -- TODO confirm.
    # Stage k (1-based) uses the base grid floor-divided by 2**(k-1).
    # Modules are registered in order 1..4, as in the original.
    for stage in range(4):
        stage_grid = (patch_grid_size // (2 ** stage)).tolist()
        setattr(
            pretrained,
            f"act_postprocess{stage + 1}",
            nn.Sequential(
                Transpose(1, 2),
                nn.Unflatten(2, torch.Size(stage_grid)),
            ),
        )

    return pretrained