Training a VQGAN VAE
VQVAEs were first introduced in Neural Discrete Representation Learning and were combined with a GAN in the paper Taming Transformers for High-Resolution Image Synthesis. The basic idea of a VQVAE is that it is a variational autoencoder whose latent space consists of discrete tokens, similar to the tokens used by LLMs. This script was adapted from a PR to Hugging Face's open-muse project, with the general code following lucidrains' implementation of the VQGAN training script; both of these implementations derive from the taming-transformers repo.
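The "tokens as latent space" idea can be illustrated with a toy example. The sketch below shows only the vector-quantization lookup: each encoder output vector is replaced by its nearest entry in a codebook, and the index of that entry is the token. The function names and the tiny 2-D codebook are hypothetical; real implementations operate on tensors and typically add a commitment loss and a straight-through gradient, which are omitted here.

```python
def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(vectors, codebook):
    """Map each latent vector to the index (token id) of its nearest codebook entry."""
    return [
        min(range(len(codebook)), key=lambda i: squared_distance(v, codebook[i]))
        for v in vectors
    ]


def dequantize(tokens, codebook):
    """Look token ids back up to get the quantized vectors the decoder receives."""
    return [codebook[t] for t in tokens]


codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 0.5]]
latents = [[0.9, 1.2], [0.1, -0.1], [-0.8, 0.4]]
tokens = quantize(latents, codebook)
print(tokens)  # [1, 0, 2]
```

Only the token ids need to be stored or modeled downstream; the decoder reconstructs the image from the looked-up codebook vectors.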
Creating a training image set is described in a different document.
Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
Important
To make sure you can successfully run the latest versions of the example scripts, we highly recommend installing from source and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
Then cd into the example folder and run
pip install -r requirements.txt
And initialize an 🤗 Accelerate environment with:
accelerate config
Training on CIFAR10
The command to train a VQGAN model on the CIFAR10 dataset is:
accelerate launch train_vqgan.py \
--dataset_name=cifar10 \
--image_column=img \
--validation_images images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \
--resolution=128 \
--train_batch_size=2 \
--gradient_accumulation_steps=8 \
--report_to=wandb
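With gradient accumulation, the optimizer only steps after several micro-batches, so the batch size that governs training dynamics is larger than --train_batch_size. A quick sketch of the arithmetic, assuming standard data-parallel training where each process launched by accelerate loads its own micro-batch (num_processes depends on your accelerate config; the helper name is hypothetical):

```python
def effective_batch_size(train_batch_size, gradient_accumulation_steps, num_processes=1):
    """Number of images whose gradients are combined into one optimizer step."""
    return train_batch_size * gradient_accumulation_steps * num_processes


# The CIFAR10 command above: --train_batch_size=2 --gradient_accumulation_steps=8
print(effective_batch_size(2, 8))                   # 16 on a single GPU
print(effective_batch_size(2, 8, num_processes=4))  # 64 on four GPUs
```

Keeping the product constant lets you trade per-step memory (train_batch_size) against more accumulation steps.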
An example training run is here by @sayakpaul and a lower-scale one here. The validation images can be obtained from here. The simplest way to improve the quality of a VQGAN model is to maximize the amount of information present in the bottleneck. The easiest way to do this is to increase the image resolution. Other ways include, but are not limited to, lowering the compression by downsampling fewer times, or increasing the vocabulary size, which can be at most around 16384. How to do this is shown below.
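One rough way to compare these options is to count what the quantized grid can hold: with a downsampling factor f and a codebook of K entries, an image of side R becomes (R/f)^2 tokens, each carrying at most log2(K) bits. This is an upper bound that assumes every code is used. The hypothetical helper below also assumes a diffusers-style encoder in which every block except the last halves the spatial size, so the 4-level config shown in the next section gives f = 8.

```python
import math


def bottleneck_capacity(resolution, num_levels, codebook_size):
    """Return (tokens per image, upper bound on bits per image) for a VQ bottleneck."""
    # Assumption: every encoder level except the last halves height and width.
    factor = 2 ** (num_levels - 1)
    side = resolution // factor
    tokens = side * side
    # Each token is one of `codebook_size` ids, so it carries at most log2(K) bits.
    return tokens, tokens * math.log2(codebook_size)


print(bottleneck_capacity(128, 4, 16384))  # (256, 3584.0): --resolution=128, 4 levels
print(bottleneck_capacity(256, 4, 16384))  # (1024, 14336.0): higher resolution
print(bottleneck_capacity(128, 3, 16384))  # (1024, 14336.0): one fewer downsampling
print(bottleneck_capacity(128, 4, 8192))   # (256, 3328.0): smaller vocabulary
```

Resolution and downsampling change the token count quadratically, while the vocabulary size only enters through a logarithm.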
Modifying the architecture
To modify the architecture of the VQGAN model, you can save the config taken from here and pass it to the script with the --model_config_name_or_path option. This config is shown below:
{
"_class_name": "VQModel",
"_diffusers_version": "0.17.0.dev0",
"act_fn": "silu",
"block_out_channels": [
128,
256,
256,
512
],
"down_block_types": [
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"AttnDownEncoderBlock2D"
],
"in_channels": 3,
"latent_channels": 4,
"layers_per_block": 2,
"norm_num_groups": 32,
"norm_type": "spatial",
"num_vq_embeddings": 16384,
"out_channels": 3,
"sample_size": 32,
"scaling_factor": 0.18215,
"up_block_types": [
"AttnUpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D"
],
"vq_embed_dim": 4
}
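If you prefer to generate the config file from Python rather than copy it by hand, the sketch below writes the default config above to disk with the standard json module. The path vqgan_config.json is a hypothetical example; you would then pass it to the script as --model_config_name_or_path=vqgan_config.json.

```python
import json

# Default VQModel config from above, as a Python dict.
DEFAULT_VQGAN_CONFIG = {
    "_class_name": "VQModel",
    "_diffusers_version": "0.17.0.dev0",
    "act_fn": "silu",
    "block_out_channels": [128, 256, 256, 512],
    "down_block_types": [
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "AttnDownEncoderBlock2D",
    ],
    "in_channels": 3,
    "latent_channels": 4,
    "layers_per_block": 2,
    "norm_num_groups": 32,
    "norm_type": "spatial",
    "num_vq_embeddings": 16384,
    "out_channels": 3,
    "sample_size": 32,
    "scaling_factor": 0.18215,
    "up_block_types": [
        "AttnUpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
    ],
    "vq_embed_dim": 4,
}


def save_config(config, path):
    """Write `config` as pretty-printed JSON so it can be passed to the training script."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)


save_config(DEFAULT_VQGAN_CONFIG, "vqgan_config.json")
```

Editing the dict before calling save_config is a convenient way to produce the variants described below.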
To reduce the number of layers in a VQGAN, remove entries from block_out_channels, down_block_types, and up_block_types, as shown below:
{
"_class_name": "VQModel",
"_diffusers_version": "0.17.0.dev0",
"act_fn": "silu",
"block_out_channels": [
128,
256,
256
],
"down_block_types": [
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D"
],
"in_channels": 3,
"latent_channels": 4,
"layers_per_block": 2,
"norm_num_groups": 32,
"norm_type": "spatial",
"num_vq_embeddings": 16384,
"out_channels": 3,
"sample_size": 32,
"scaling_factor": 0.18215,
"up_block_types": [
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D"
],
"vq_embed_dim": 4
}
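This edit removes the lowest-resolution level (the 512-channel block, which is also the one with attention) from both the encoder and the decoder. The sketch below performs the same edit on a dict and reports the resulting downsampling factor. It assumes a diffusers-style encoder in which every block except the last halves the spatial size, so four levels give a factor of 8 (a 128×128 image becomes a 16×16 token grid) and three give a factor of 4 (32×32 tokens). The helper names are hypothetical.

```python
import copy

# Only the keys that change are shown; the rest of the config stays as-is.
FULL_CONFIG = {
    "block_out_channels": [128, 256, 256, 512],
    "down_block_types": ["DownEncoderBlock2D"] * 3 + ["AttnDownEncoderBlock2D"],
    "up_block_types": ["AttnUpDecoderBlock2D"] + ["UpDecoderBlock2D"] * 3,
}

LEVEL_KEYS = ("block_out_channels", "down_block_types", "up_block_types")


def drop_deepest_level(config):
    """Remove the lowest-resolution level: the last encoder block, the first
    decoder block (the decoder runs in reverse order), and the last channel width."""
    lengths = {len(config[k]) for k in LEVEL_KEYS}
    if len(lengths) != 1:
        raise ValueError("block lists must all have the same length")
    if lengths.pop() < 2:
        raise ValueError("need at least two levels to drop one")
    smaller = copy.deepcopy(config)
    smaller["block_out_channels"] = smaller["block_out_channels"][:-1]
    smaller["down_block_types"] = smaller["down_block_types"][:-1]
    smaller["up_block_types"] = smaller["up_block_types"][1:]
    return smaller


def downsample_factor(config):
    # Assumption: every encoder block except the last halves height and width.
    return 2 ** (len(config["block_out_channels"]) - 1)


small = drop_deepest_level(FULL_CONFIG)
print(small["block_out_channels"], downsample_factor(FULL_CONFIG), downsample_factor(small))
# [128, 256, 256] 8 4
```

The length check matters because the three lists describe the same levels; a mismatch would produce an encoder and decoder that do not mirror each other.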
To increase the vocabulary size, increase num_vq_embeddings. However, some research suggests that VQGAN representations start degrading beyond 2^14 = 16384 VQ embeddings, so going past that is not recommended.
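A small guard can make the 2^14 ceiling explicit when generating configs. The sketch below is hypothetical (it is not part of the training script): it returns a copy of a config with a new num_vq_embeddings and emits a warning past 16384, alongside a helper for the bits each token can carry.

```python
import math
import warnings

RECOMMENDED_MAX_CODES = 2 ** 14  # 16384, the ceiling suggested above


def bits_per_token(num_vq_embeddings):
    """Maximum information one token can carry, in bits."""
    return math.log2(num_vq_embeddings)


def with_codebook_size(config, num_vq_embeddings):
    """Return a copy of `config` with a new num_vq_embeddings, warning past 2**14."""
    if num_vq_embeddings < 2:
        raise ValueError("num_vq_embeddings must be at least 2")
    if num_vq_embeddings > RECOMMENDED_MAX_CODES:
        warnings.warn(
            f"num_vq_embeddings={num_vq_embeddings} exceeds {RECOMMENDED_MAX_CODES}; "
            "VQGAN representations may start to degrade",
            stacklevel=2,
        )
    return {**config, "num_vq_embeddings": num_vq_embeddings}


config = with_codebook_size({"num_vq_embeddings": 16384}, 8192)
print(config["num_vq_embeddings"], bits_per_token(8192))  # 8192 13.0
```

Because capacity grows only logarithmically, doubling the codebook from 16384 to 32768 adds just one bit per token (14 to 15).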
Extra training tips/ideas
During logging, make sure data_time stays low. data_time is the time spent loading data, during which the GPU is idle, so it is essentially wasted time. The easiest way to lower it is to increase --dataloader_num_workers to a higher number such as 4. Due to a bug in PyTorch, this only works on Linux-based systems. For more details check here.
Second, training can be considered done when both the discriminator and the generator losses have converged.
Third, another low-hanging fruit is using EMA via the --use_ema parameter. This tends to make the output images smoother. The downside is that you have to lower your batch size by 1, but it may be worth it.
Finally, a more experimental low-hanging fruit is replacing VGG19 with a different model for the LPIPS loss using --timm_model_backend. If you do this, I recommend also setting the timm_model_layers parameter to the layer of your model that you think gives the best representation. However, be careful with the feature map norms, since they can easily dominate the loss.
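EMA keeps an exponential moving average of the weights alongside the weights being trained, and that averaged copy tends to give smoother outputs. The sketch below shows the standard EMA update on a dict of scalar weights in pure Python; the function name and the decay of 0.999 are illustrative assumptions, not necessarily what the script's --use_ema uses. In practice the same update runs over every parameter tensor after each optimizer step, and the extra copy of the weights costs memory, which is presumably why the batch size has to drop.

```python
def ema_update(ema_weights, weights, decay=0.999):
    """Blend current weights into the running average, in place:
    ema <- decay * ema + (1 - decay) * weight."""
    for name, value in weights.items():
        ema_weights[name] = decay * ema_weights[name] + (1.0 - decay) * value
    return ema_weights


# Toy example: a single scalar "weight" that jumps from 0.0 to 1.0.
ema = {"w": 0.0}
for _ in range(3):
    ema_update(ema, {"w": 1.0}, decay=0.9)
print(ema["w"])  # approximately 0.271, i.e. 1 - 0.9**3
```

A higher decay averages over more steps, giving a smoother but slower-reacting copy of the model.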