Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,422 Bytes
55ca09f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import torch.nn as nn
from .clip import FrozenCLIPEmbedder
from .switti import Switti
from .vqvae import VQVAE
from .pipeline import SwittiPipeline
def build_models(
    # Shared args
    device,
    patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
    # VQVAE args
    V=4096,
    Cvae=32,
    ch=160,
    share_quant_resi=4,
    # Switti args
    depth=16,
    rope=True,
    rope_theta=10000,
    rope_size=128,
    use_swiglu_ffn=True,
    use_ar=False,
    use_crop_cond=True,
    attn_l2_norm=True,
    init_adaln=0.5,
    init_adaln_gamma=1e-5,
    init_head=0.02,
    init_std=-1,  # init_std < 0: automated
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dpr=0,
    norm_eps=1e-6,
    # pipeline args
    text_encoder_path="openai/clip-vit-large-patch14",
    text_encoder_2_path="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
) -> tuple[VQVAE, Switti, SwittiPipeline]:
    """Build the VQVAE, the Switti model, and a SwittiPipeline wrapping both.

    Side effect: ``reset_parameters`` is replaced with a no-op on several
    ``torch.nn`` layer classes (see the loop below). The patch is applied to
    the classes themselves, so it stays in effect for the whole process after
    this function returns.

    Args:
        device: Target passed to ``.to(device)`` for both models and handed to
            the pipeline.
        patch_nums: Forwarded to ``VQVAE`` (as ``v_patch_nums``) and to
            ``Switti`` (as ``patch_nums``).
        V: Forwarded to ``VQVAE`` as ``vocab_size``.
        Cvae: Forwarded to ``VQVAE`` as ``z_channels``.
        ch: Forwarded to ``VQVAE`` as ``ch``.
        share_quant_resi: Forwarded to ``VQVAE``.
        depth: Forwarded to ``Switti`` as ``depth``; also determines
            ``num_heads`` (= ``depth``) and ``embed_dim`` (= ``depth * 64``).
        rope, rope_theta, rope_size, use_swiglu_ffn, use_ar, use_crop_cond,
        attn_l2_norm, drop_rate, attn_drop_rate, norm_eps: Forwarded to
            ``Switti`` unchanged.
        init_adaln, init_adaln_gamma, init_head, init_std: Forwarded to
            ``Switti.init_weights``.
        dpr: Forwarded to ``Switti`` as ``drop_path_rate``; when positive it
            is first rescaled by ``depth / 24``.
        text_encoder_path: Argument for the first ``FrozenCLIPEmbedder``.
        text_encoder_2_path: Argument for the second ``FrozenCLIPEmbedder``.

    Returns:
        A 3-tuple ``(vae, switti, pipeline)``.
    """
    heads = depth
    width = depth * 64
    if dpr > 0:
        dpr = dpr * depth / 24

    # disable built-in initialization for speed
    # NOTE: this patches the torch.nn classes globally, not just for the
    # modules constructed here.
    for clz in (
        nn.Linear,
        nn.LayerNorm,
        nn.BatchNorm2d,
        nn.SyncBatchNorm,
        nn.Conv1d,
        nn.Conv2d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
    ):
        setattr(clz, "reset_parameters", lambda self: None)

    # build models
    vae_local = VQVAE(
        vocab_size=V,
        z_channels=Cvae,
        ch=ch,
        test_mode=True,
        share_quant_resi=share_quant_resi,
        v_patch_nums=patch_nums,
    ).to(device)

    switti_wo_ddp = Switti(
        depth=depth,
        embed_dim=width,
        num_heads=heads,
        drop_rate=drop_rate,
        attn_drop_rate=attn_drop_rate,
        drop_path_rate=dpr,
        norm_eps=norm_eps,
        attn_l2_norm=attn_l2_norm,
        patch_nums=patch_nums,
        rope=rope,
        rope_theta=rope_theta,
        rope_size=rope_size,
        use_swiglu_ffn=use_swiglu_ffn,
        use_ar=use_ar,
        use_crop_cond=use_crop_cond,
    ).to(device)
    # Built-in initialization was disabled above, so weights are set here.
    switti_wo_ddp.init_weights(
        init_adaln=init_adaln,
        init_adaln_gamma=init_adaln_gamma,
        init_head=init_head,
        init_std=init_std,
    )

    text_encoder = FrozenCLIPEmbedder(text_encoder_path)
    text_encoder_2 = FrozenCLIPEmbedder(text_encoder_2_path)
    pipe = SwittiPipeline(switti_wo_ddp, vae_local, text_encoder, text_encoder_2, device)

    return vae_local, switti_wo_ddp, pipe
|