File size: 4,627 Bytes
b7f3942 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
from arch.hourglass import image_transformer_v2 as itv2
from arch.hourglass.image_transformer_v2 import ImageTransformerDenoiserModelV2
from arch.swinir.swinir import SwinIR
def create_arch(arch, condition_channels=0):
    """Build a model from an ``"<family>_<size>"`` architecture string.

    ``arch`` is split at its FIRST underscore into a family name and a size
    name. The family selects both a preset group in ``arch_configs`` and a
    factory in ``arch_name_to_object``. The size selects one preset inside
    that group. Valid examples with the presets defined in this module:
    ``"hdit_ImageNet256Sp4"``, ``"hdit_XL2"``, ``"swinir_M"``,
    ``"swinir_L"``.

    Args:
        arch: Architecture identifier of the form ``"<family>_<size>"``.
        condition_channels: Number of channels added to the preset's
            ``in_channels`` before the model is built.

    Returns:
        Whatever the selected family's factory returns for the preset.

    Raises:
        ValueError: if ``arch`` contains no underscore.
        KeyError: if the family or size name is unknown.
    """
    import copy

    arch_name, sep, arch_size = arch.partition('_')
    if not sep:
        raise ValueError(
            f"architecture must look like '<family>_<size>', got {arch!r}"
        )

    try:
        presets = arch_configs[arch_name]
    except KeyError:
        raise KeyError(
            f"unknown architecture family {arch_name!r}; "
            f"available: {sorted(arch_configs)}"
        ) from None
    try:
        preset = presets[arch_size]
    except KeyError:
        raise KeyError(
            f"unknown size {arch_size!r} for architecture {arch_name!r}; "
            f"available: {sorted(presets)}"
        ) from None

    # Deep copy so the factory receives its own lists/dicts rather than the
    # objects stored in the shared module-level preset table; any in-place
    # change made downstream can then not leak into later create_arch calls.
    arch_config = copy.deepcopy(preset)
    arch_config['in_channels'] += condition_channels
    return arch_name_to_object[arch_name](**arch_config)
# Preset hyper-parameters, keyed by architecture family and then by size
# name. create_arch("<family>_<size>") looks up one preset here and passes
# it as keyword arguments to the family's factory in arch_name_to_object,
# so every key in a preset must match a parameter of that factory.
arch_configs = {
    'hdit': {
        # Three levels: neighborhood attention on the first two,
        # global attention on the last.
        "ImageNet256Sp4": dict(
            in_channels=3,
            out_channels=3,
            widths=[256, 512, 1024],
            depths=[2, 2, 8],
            patch_size=[4, 4],
            self_attns=[
                dict(type="neighborhood", d_head=64, kernel_size=7),
                dict(type="neighborhood", d_head=64, kernel_size=7),
                dict(type="global", d_head=64),
            ],
            mapping_depth=2,
            mapping_width=768,
            dropout_rate=[0, 0, 0],
            mapping_dropout_rate=0.0,
        ),
        # Two levels: neighborhood attention, then global attention.
        "XL2": dict(
            in_channels=3,
            out_channels=3,
            widths=[384, 768],
            depths=[2, 11],
            patch_size=[4, 4],
            self_attns=[
                dict(type="neighborhood", d_head=64, kernel_size=7),
                dict(type="global", d_head=64),
            ],
            mapping_depth=2,
            mapping_width=768,
            dropout_rate=[0, 0],
            mapping_dropout_rate=0.0,
        ),
    },
    'swinir': {
        # Five stages of depth 6 at embedding width 120.
        "M": dict(
            in_channels=3,
            out_channels=3,
            embed_dim=120,
            depths=[6, 6, 6, 6, 6],
            num_heads=[6, 6, 6, 6, 6],
            resi_connection='1conv',
            sf=8,
        ),
        # Eight stages of depth 6 at embedding width 180.
        "L": dict(
            in_channels=3,
            out_channels=3,
            embed_dim=180,
            depths=[6, 6, 6, 6, 6, 6, 6, 6],
            num_heads=[6, 6, 6, 6, 6, 6, 6, 6],
            resi_connection='1conv',
            sf=8,
        ),
    },
}
def create_swinir_model(in_channels, out_channels, embed_dim, depths, num_heads, resi_connection,
                        sf, *, img_size=64, window_size=8, mlp_ratio=2, unshuffle_scale=8):
    """Build a ``SwinIR`` network from SwinIR preset hyper-parameters.

    Args:
        in_channels: Forwarded to ``SwinIR`` as ``in_chans``.
        out_channels: Forwarded to ``SwinIR`` as ``num_out_ch``.
        embed_dim: Forwarded as ``embed_dim``.
        depths: Forwarded as ``depths``.
        num_heads: Forwarded as ``num_heads``.
        resi_connection: Forwarded as ``resi_connection``.
        sf: Forwarded as ``sf``.
        img_size: Forwarded as ``img_size`` (default 64).
        window_size: Forwarded as ``window_size`` (default 8).
        mlp_ratio: Forwarded as ``mlp_ratio`` (default 2).
        unshuffle_scale: Forwarded as ``unshuffle_scale`` (default 8).

    The keyword-only defaults equal the values every existing preset has
    always been built with. Presets that omit these keys therefore produce
    the same network; a preset may override them by adding the key.

    Returns:
        The constructed ``SwinIR`` instance.
    """
    # NOTE(review): unshuffle_scale and sf both default to 8 but are set
    # independently here; confirm whether they are meant to stay equal.
    return SwinIR(
        img_size=img_size,
        patch_size=1,
        in_chans=in_channels,
        num_out_ch=out_channels,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=window_size,
        mlp_ratio=mlp_ratio,
        sf=sf,
        img_range=1.0,
        upsampler="nearest+conv",
        resi_connection=resi_connection,
        unshuffle=True,
        unshuffle_scale=unshuffle_scale,
    )
def _attn_spec_from_config(attn_cfg):
    """Translate one ``self_attns`` entry (a dict) into an ``itv2`` attention spec.

    Recognised ``'type'`` values and the keys each one reads:
      * ``'global'``         -> GlobalAttentionSpec(d_head=64 default)
      * ``'neighborhood'``   -> NeighborhoodAttentionSpec(d_head=64, kernel_size=7 defaults)
      * ``'shifted-window'`` -> ShiftedWindowAttentionSpec(d_head=64 default, window_size required)
      * ``'none'``           -> NoAttentionSpec()

    Raises:
        KeyError: if ``'type'`` is missing, or if ``'window_size'`` is missing
            for a ``'shifted-window'`` entry.
        ValueError: if ``'type'`` is not one of the values above.
    """
    attn_type = attn_cfg['type']
    if attn_type == 'global':
        return itv2.GlobalAttentionSpec(attn_cfg.get('d_head', 64))
    if attn_type == 'neighborhood':
        return itv2.NeighborhoodAttentionSpec(
            attn_cfg.get('d_head', 64), attn_cfg.get('kernel_size', 7)
        )
    if attn_type == 'shifted-window':
        return itv2.ShiftedWindowAttentionSpec(
            attn_cfg.get('d_head', 64), attn_cfg['window_size']
        )
    if attn_type == 'none':
        return itv2.NoAttentionSpec()
    raise ValueError(f'unsupported self attention type {attn_type}')


def create_hdit_model(widths,
                      depths,
                      self_attns,
                      dropout_rate,
                      mapping_depth,
                      mapping_width,
                      mapping_dropout_rate,
                      in_channels,
                      out_channels,
                      patch_size
                      ):
    """Build an ``ImageTransformerDenoiserModelV2`` from HDiT preset values.

    ``widths``, ``depths``, ``self_attns`` and ``dropout_rate`` are parallel
    per-level sequences. Entry ``i`` of each describes level ``i`` and is
    turned into one ``itv2.LevelSpec``. Each level's feed-forward width is
    3x its width; the mapping network's feed-forward width is likewise 3x
    ``mapping_width``. The model is built with ``num_classes=0`` and
    ``mapping_cond_dim=0``.

    Args:
        widths: Per-level width.
        depths: Per-level depth.
        self_attns: Per-level attention config dicts (see
            ``_attn_spec_from_config`` for the accepted keys).
        dropout_rate: Per-level dropout rate.
        mapping_depth: Depth passed to ``itv2.MappingSpec``.
        mapping_width: Width passed to ``itv2.MappingSpec``.
        mapping_dropout_rate: Dropout passed to ``itv2.MappingSpec``.
        in_channels: Forwarded to the model constructor.
        out_channels: Forwarded to the model constructor.
        patch_size: Forwarded to the model constructor.

    Returns:
        The constructed ``ImageTransformerDenoiserModelV2``.

    Raises:
        ValueError: if the per-level sequences differ in length from
            ``widths``, or an attention type is unsupported.
        KeyError: if an attention config is missing a required key.
    """
    # Explicit checks instead of `assert`: asserts vanish under `python -O`,
    # after which zip() would silently truncate to the shortest sequence.
    n_levels = len(widths)
    for name, values in (('depths', depths),
                         ('self_attns', self_attns),
                         ('dropout_rate', dropout_rate)):
        if len(values) != n_levels:
            raise ValueError(
                f'len({name})={len(values)} does not match len(widths)={n_levels}'
            )

    levels = []
    for depth, width, attn_cfg, dropout in zip(depths, widths, self_attns, dropout_rate):
        attn_spec = _attn_spec_from_config(attn_cfg)
        levels.append(itv2.LevelSpec(depth, width, width * 3, attn_spec, dropout))

    mapping = itv2.MappingSpec(
        mapping_depth, mapping_width, mapping_width * 3, mapping_dropout_rate
    )
    return ImageTransformerDenoiserModelV2(
        levels=levels,
        mapping=mapping,
        in_channels=in_channels,
        out_channels=out_channels,
        patch_size=patch_size,
        num_classes=0,
        mapping_cond_dim=0,
    )
# Factory for each architecture family. The key is the part of the
# create_arch() identifier before the first underscore, and each factory
# accepts the keys of that family's presets in arch_configs as kwargs.
arch_name_to_object = dict(
    hdit=create_hdit_model,
    swinir=create_swinir_model,
)
|