#!/usr/bin/env bash
#
# Causal VAE training launcher.
#
# Training runs in two stages:
#   Stage-1: image and video mixed training.
#   Stage-2: pure video training, using context parallelism to load clips
#            with more frames (up to 257 frames).
#
# Replace the PATH / stage1_path placeholders below before running.
# The script stops at the first failing command, so Stage-2 is never
# launched if Stage-1 fails.

set -euo pipefail

# Print an error message to stderr and abort the script.
die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

# Abort unless $1 has the "x * 8 + 1" form required for the frame count.
check_num_frames() {
  local frames=$1
  (( frames >= 1 && (frames - 1) % 8 == 0 )) \
    || die "NUM_FRAMES must be of the form x * 8 + 1, got ${frames}"
}

GPUS=8                                  # The gpu number (processes per node)
VAE_MODEL_PATH=PATH/vae_ckpt            # The vae model dir
LPIPS_CKPT=vgg_lpips.pth                # The LPIPS VGG ckpt path, used for the lpips loss
OUTPUT_DIR=/PATH/output_dir             # The checkpoint saving dir
IMAGE_ANNO=annotation/image_text.jsonl  # The image annotation file path
VIDEO_ANNO=annotation/video_text.jsonl  # The video annotation file path
RESOLUTION=256                          # The training resolution, default is 256
NUM_FRAMES=17                           # x * 8 + 1, the number of video frames
BATCH_SIZE=2

# Loss, optimizer and schedule arguments passed identically to both stages
# (appended after the stage-specific arguments, as in the original commands).
# NOTE(review): 100 epochs * 2000 iters/epoch = 200000 iterations, which is
# below --disc_start 250000. If disc_start counts training iterations, the
# discriminator loss is never enabled in either stage; confirm this is intended.
SHARED_TRAIN_ARGS=(
  --disc_start 250000
  --kl_weight 1e-12
  --pixelloss_weight 10.0
  --perceptual_weight 1.0
  --disc_weight 0.5
  --batch_size "$BATCH_SIZE"
  --opt adamw
  --opt_betas 0.9 0.95
  --seed 42
  --weight_decay 1e-3
  --clip_grad 1.0
  --lr 1e-4
  --lr_disc 1e-4
  --warmup_epochs 1
  --epochs 100
  --iters_per_epoch 2000
  --print_freq 40
  --save_ckpt_freq 1
)

# ---------------------------------------------------------------------------
# Stage-1: image and video mixed training
# ---------------------------------------------------------------------------
check_num_frames "$NUM_FRAMES"

torchrun --nproc_per_node "$GPUS" \
  train/train_video_vae.py \
  --num_workers 6 \
  --model_path "$VAE_MODEL_PATH" \
  --model_dtype bf16 \
  --lpips_ckpt "$LPIPS_CKPT" \
  --output_dir "$OUTPUT_DIR" \
  --image_anno "$IMAGE_ANNO" \
  --video_anno "$VIDEO_ANNO" \
  --use_image_video_mixed_training \
  --image_mix_ratio 0.1 \
  --resolution "$RESOLUTION" \
  --max_frames "$NUM_FRAMES" \
  "${SHARED_TRAIN_ARGS[@]}"

# ---------------------------------------------------------------------------
# Stage-2: pure video training with context parallelism
# ---------------------------------------------------------------------------
CONTEXT_SIZE=2              # context parallel size; GPUS % CONTEXT_SIZE must be 0
NUM_FRAMES=33               # 16 * CONTEXT_SIZE + 1 (also of the form x * 8 + 1)
VAE_CKPT_PATH=stage1_path   # The stage-1 trained ckpt

# CONTEXT_SIZE > 0 is tested first; bash's && short-circuits, avoiding % 0.
(( CONTEXT_SIZE > 0 && GPUS % CONTEXT_SIZE == 0 )) \
  || die "GPUS (${GPUS}) must be divisible by CONTEXT_SIZE (${CONTEXT_SIZE})"
check_num_frames "$NUM_FRAMES"

# NOTE(review): Stage-2 writes to the same OUTPUT_DIR as Stage-1. If the
# trainer reuses checkpoint file names, Stage-1 checkpoints (possibly
# including VAE_CKPT_PATH) could be overwritten; consider a separate dir.
torchrun --nproc_per_node "$GPUS" \
  train/train_video_vae.py \
  --num_workers 6 \
  --model_path "$VAE_MODEL_PATH" \
  --model_dtype bf16 \
  --pretrained_vae_weight "$VAE_CKPT_PATH" \
  --use_context_parallel \
  --context_size "$CONTEXT_SIZE" \
  --lpips_ckpt "$LPIPS_CKPT" \
  --output_dir "$OUTPUT_DIR" \
  --video_anno "$VIDEO_ANNO" \
  --image_mix_ratio 0.0 \
  --resolution "$RESOLUTION" \
  --max_frames "$NUM_FRAMES" \
  "${SHARED_TRAIN_ARGS[@]}"