import torch.nn as nn

from .clip import FrozenCLIPEmbedder
from .switti import Switti
from .vqvae import VQVAE
from .pipeline import SwittiPipeline


def build_models(
    # Shared args
    device,
    patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
    # VQVAE args
    V=4096,
    Cvae=32,
    ch=160,
    share_quant_resi=4,
    # Switti args
    depth=16,
    rope=True,
    rope_theta=10000,
    rope_size=128,
    use_swiglu_ffn=True,
    use_ar=False,
    use_crop_cond=True,
    attn_l2_norm=True,
    init_adaln=0.5,
    init_adaln_gamma=1e-5,
    init_head=0.02,
    init_std=-1,  # init_std < 0: automated
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dpr=0,
    norm_eps=1e-6,
    # Pipeline args
    text_encoder_path="openai/clip-vit-large-patch14",
    text_encoder_2_path="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
) -> tuple[VQVAE, Switti, SwittiPipeline]:
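    """Build the Switti text-to-image stack: a VQVAE tokenizer, the Switti
    transformer, and a SwittiPipeline wrapping two frozen CLIP text encoders.

    Note: the models are returned freshly initialized; pretrained weights,
    if any, must be loaded separately. Returns (vae, switti, pipeline).
    """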
    heads = depth
    width = depth * 64  # fixed 64 channels per attention head
    if dpr > 0:
        dpr = dpr * depth / 24  # scale drop path rate relative to a 24-layer model

    # Disable PyTorch's built-in parameter initialization for speed;
    # weights are set explicitly by init_weights below or loaded from a checkpoint.
    for clz in (
        nn.Linear,
        nn.LayerNorm,
        nn.BatchNorm2d,
        nn.SyncBatchNorm,
        nn.Conv1d,
        nn.Conv2d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
    ):
        setattr(clz, "reset_parameters", lambda self: None)

    # build models
    vae_local = VQVAE(
        vocab_size=V,
        z_channels=Cvae,
        ch=ch,
        test_mode=True,
        share_quant_resi=share_quant_resi,
        v_patch_nums=patch_nums,
    ).to(device)

    switti_wo_ddp = Switti(
        depth=depth,
        embed_dim=width,
        num_heads=heads,
        drop_rate=drop_rate,
        attn_drop_rate=attn_drop_rate,
        drop_path_rate=dpr,
        norm_eps=norm_eps,
        attn_l2_norm=attn_l2_norm,
        patch_nums=patch_nums,
        rope=rope,
        rope_theta=rope_theta,
        rope_size=rope_size,
        use_swiglu_ffn=use_swiglu_ffn,
        use_ar=use_ar,
        use_crop_cond=use_crop_cond,
    ).to(device)

    switti_wo_ddp.init_weights(
        init_adaln=init_adaln,
        init_adaln_gamma=init_adaln_gamma,
        init_head=init_head,
        init_std=init_std,
    )
    text_encoder = FrozenCLIPEmbedder(text_encoder_path)
    text_encoder_2 = FrozenCLIPEmbedder(text_encoder_2_path)
    pipe = SwittiPipeline(switti_wo_ddp, vae_local, text_encoder, text_encoder_2, device)

    return vae_local, switti_wo_ddp, pipe
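

# --- Usage sketch (illustrative; not part of the original module) ---
# build_models returns freshly initialized models, so pretrained weights, if
# any, must be loaded separately. The import path below is an assumption
# based on this file's relative imports:
#
#   import torch
#   from models import build_models  # hypothetical package path
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   vae, switti, pipe = build_models(device, depth=16)
#   n_params = sum(p.numel() for p in switti.parameters())
#   print(f"Switti params: {n_params / 1e6:.0f}M")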