Memory-efficient Diffusion Transformers with Quanto and Diffusers

Published July 30, 2024

Over the past few months, we have seen the emergence of Transformer-based diffusion backbones for high-resolution text-to-image (T2I) generation. These models use the transformer architecture as the building block for the diffusion process, instead of the UNet architecture that was prevalent in many of the initial diffusion models. Thanks to the nature of Transformers, these backbones show good scalability, with models ranging from 0.6B to 8B parameters.

As models become larger, memory requirements increase. The problem intensifies because a diffusion pipeline usually consists of several components: a text encoder, a diffusion backbone, and an image decoder. Furthermore, modern diffusion pipelines use multiple text encoders – for example, there are three in the case of Stable Diffusion 3. It takes 18.765 GB of GPU memory to run SD3 inference using FP16 precision.

These high memory requirements can make it difficult to use these models with consumer GPUs, slowing adoption and making experimentation harder. In this post, we show how to improve the memory efficiency of Transformer-based diffusion pipelines by leveraging Quanto's quantization utilities together with the Diffusers library.

Preliminaries

For a detailed introduction to Quanto, please refer to this post. In short, Quanto is a quantization toolkit built on PyTorch. It's part of Hugging Face Optimum, a set of tools for hardware optimization.

Model quantization is a popular tool among LLM practitioners, but not so much with diffusion models. Quanto can help bridge this gap and provide memory savings with little or no quality degradation.

For benchmarking purposes, we use an H100 GPU.

Unless otherwise specified, we default to performing computations in FP16. We chose not to quantize the VAE to prevent numerical instability issues. Our benchmarking code can be found here.

At the time of this writing, Diffusers ships several Transformer-based diffusion pipelines for text-to-image generation.

We also have Latte, a Transformer-based text-to-video generation pipeline.

For brevity, we keep our study limited to the following three: PixArt-Sigma, Stable Diffusion 3, and Aura Flow, whose diffusion backbones span a wide range of parameter counts.
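
If you want to check the size of a diffusion backbone yourself, you can count its parameters directly. Below is a quick sketch for PixArt-Sigma; the model class and checkpoint are the same ones used later in this post.

from diffusers import PixArtTransformer2DModel

transformer = PixArtTransformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", subfolder="transformer"
)
# Count the parameters of the diffusion backbone.
num_params = sum(p.numel() for p in transformer.parameters())
print(f"{num_params / 1e9:.2f}B parameters")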

It’s worth keeping in mind that this post primarily focuses on memory efficiency, at the cost of a slight (and sometimes negligible) increase in inference latency.

Quantizing a DiffusionPipeline with Quanto

Quantizing a model with Quanto is straightforward.

from optimum.quanto import freeze, qfloat8, quantize
from diffusers import PixArtSigmaPipeline
import torch

pipeline = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

quantize(pipeline.transformer, weights=qfloat8)
freeze(pipeline.transformer)

We call quantize() on the module to be quantized, specifying what we want to quantize. In the above case, we quantize only the parameters, leaving the activations as they are, and we quantize to the FP8 data type. Finally, we call freeze() to replace the original parameters with the quantized ones.
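
As an aside, Quanto can also quantize activations, which requires a short calibration pass before freezing so that activation ranges can be recorded. The snippet below is a rough sketch of that workflow based on Quanto's calibration API; the rest of this post sticks to weight-only quantization.

from optimum.quanto import Calibration, freeze, qfloat8, quantize
from diffusers import PixArtSigmaPipeline
import torch

pipeline = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# Quantize both the weights and the activations this time.
quantize(pipeline.transformer, weights=qfloat8, activations=qfloat8)

# Run a few representative prompts so Quanto can record activation ranges.
with Calibration():
    pipeline("ghibli style, a fantasy landscape with castles", num_inference_steps=10)

freeze(pipeline.transformer)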

We can then call this pipeline normally:

image = pipeline("ghibli style, a fantasy landscape with castles").images[0]
Left: image generated with the diffusion transformer in FP16. Right: image generated with the diffusion transformer quantized to FP8.

We notice the following memory savings when using FP8, with slightly higher latency and almost no quality degradation:

| Batch Size | Quantization | Memory (GB) | Latency (Seconds) |
|---|---|---|---|
| 1 | None | 12.086 | 1.200 |
| 1 | FP8 | 11.547 | 1.540 |
| 4 | None | 12.087 | 4.482 |
| 4 | FP8 | 11.548 | 5.109 |
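
For reference, numbers like these can be collected with standard PyTorch utilities. The sketch below reuses the pipeline object from the snippet above; it is illustrative and not necessarily identical to our benchmarking script.

import torch

torch.cuda.reset_peak_memory_stats()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
image = pipeline("ghibli style, a fantasy landscape with castles").images[0]
end.record()
torch.cuda.synchronize()

# Peak GPU memory in GB and wall-clock latency in seconds.
print(f"Memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")
print(f"Latency: {start.elapsed_time(end) / 1000:.3f} seconds")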

We can quantize the text encoder in the same way:

quantize(pipeline.text_encoder, weights=qfloat8)
freeze(pipeline.text_encoder)

The text encoder is also a transformer model, and we can quantize it too. Quantizing both the text encoder and the diffusion backbone leads to much larger memory improvements:

| Batch Size | Quantization | Quantize TE | Memory (GB) | Latency (Seconds) |
|---|---|---|---|---|
| 1 | FP8 | False | 11.547 | 1.540 |
| 1 | FP8 | True | 5.363 | 1.601 |
| 4 | FP8 | False | 11.548 | 5.109 |
| 4 | FP8 | True | 5.364 | 5.141 |

Quantizing the text encoder produces results very similar to the previous case:

Image generated with both the diffusion transformer and the text encoder quantized to FP8.

Generality of the observations

Quantizing the text encoder together with the diffusion backbone generally works for the models we tried. Stable Diffusion 3 is a special case, as it uses three different text encoders. We found that quantizing the second text encoder does not work well, so we recommend one of the following alternatives instead: quantize only the first text encoder, quantize only the third text encoder, or quantize both the first and third text encoders. These recommended combinations are marked with a ✅ in the table below.
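
As a concrete example, here is a minimal sketch of one recommended setup for Stable Diffusion 3: the diffusion transformer and text encoders 1 and 3 are quantized, while the second text encoder is left untouched. The checkpoint name below assumes the public SD3 Medium repository.

from optimum.quanto import freeze, qfloat8, quantize
from diffusers import StableDiffusion3Pipeline
import torch

pipeline = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Quantize the diffusion transformer and text encoders 1 and 3, skipping text encoder 2.
for module in (pipeline.transformer, pipeline.text_encoder, pipeline.text_encoder_3):
    quantize(module, weights=qfloat8)
    freeze(module)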

The table below gives an idea about the expected memory savings for various text encoder quantization combinations (the diffusion transformer is quantized in all cases):

| Batch Size | Quantization | Quantize TE 1 | Quantize TE 2 | Quantize TE 3 | Memory (GB) | Latency (Seconds) |
|---|---|---|---|---|---|---|
| 1 | FP8 | 1 | 1 | 1 | 8.200 | 2.858 |
| 1 | FP8 ✅ | 0 | 0 | 1 | 8.294 | 2.781 |
| 1 | FP8 | 1 | 1 | 0 | 14.384 | 2.833 |
| 1 | FP8 | 0 | 1 | 0 | 14.475 | 2.818 |
| 1 | FP8 ✅ | 1 | 0 | 0 | 14.384 | 2.730 |
| 1 | FP8 | 0 | 1 | 1 | 8.325 | 2.875 |
| 1 | FP8 ✅ | 1 | 0 | 1 | 8.204 | 2.789 |
| 1 | None | - | - | - | 16.403 | 2.118 |
Left to right: image with text encoder 1 quantized, image with text encoder 3 quantized, and image with text encoders 1 and 3 quantized.

Misc findings

bfloat16 is usually better on H100

Using bfloat16 can be faster for supported GPU architectures, such as H100 or 4090. The table below presents some numbers for PixArt measured on our H100 reference hardware:

| Batch Size | Precision | Quantization | Memory (GB) | Latency (Seconds) | Quantize TE |
|---|---|---|---|---|---|
| 1 | FP16 | INT8 | 5.363 | 1.538 | True |
| 1 | BF16 | INT8 | 5.364 | 1.454 | True |
| 1 | FP16 | FP8 | 5.363 | 1.601 | True |
| 1 | BF16 | FP8 | 5.363 | 1.495 | True |
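
Switching the working precision only changes how the pipeline is loaded. Here is a minimal sketch of the BF16 + INT8 configuration from the table above.

from optimum.quanto import freeze, qint8, quantize
from diffusers import PixArtSigmaPipeline
import torch

# Load the pipeline in bfloat16 instead of float16.
pipeline = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.bfloat16
).to("cuda")

# Quantize both the diffusion transformer and the text encoder to INT8.
for module in (pipeline.transformer, pipeline.text_encoder):
    quantize(module, weights=qint8)
    freeze(module)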

The promise of qint8

We found that quantizing with qint8 (instead of qfloat8) is generally better in terms of inference latency. This effect becomes more pronounced when we horizontally fuse the attention QKV projections (by calling fuse_qkv_projections() in Diffusers), which thickens the dimensions of the int8 kernels and speeds up computation. We present some evidence below for PixArt:

| Batch Size | Quantization | Memory (GB) | Latency (Seconds) | Quantize TE | QKV Projection |
|---|---|---|---|---|---|
| 1 | INT8 | 5.363 | 1.538 | True | False |
| 1 | INT8 | 5.536 | 1.504 | True | True |
| 4 | INT8 | 5.365 | 5.129 | True | False |
| 4 | INT8 | 5.538 | 4.989 | True | True |
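
Below is a minimal sketch of this configuration. It assumes a Diffusers version in which the PixArt transformer exposes fuse_qkv_projections().

from optimum.quanto import freeze, qint8, quantize
from diffusers import PixArtSigmaPipeline
import torch

pipeline = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# Fuse the attention QKV projections before quantizing, so the larger,
# fused projections are what run through the INT8 kernels.
pipeline.transformer.fuse_qkv_projections()

quantize(pipeline.transformer, weights=qint8)
freeze(pipeline.transformer)
quantize(pipeline.text_encoder, weights=qint8)
freeze(pipeline.text_encoder)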

How about INT4?

We additionally experimented with qint4 in combination with bfloat16. This is currently only applicable on the H100 with bfloat16, because other configurations are not supported yet. With qint4, we can expect further improvements in memory consumption at the cost of increased inference latency. The increased latency is expected because there is no native hardware support for int4 computation: the weights are transferred using 4 bits, but the computation is still done in bfloat16. The table below shows our results for PixArt-Sigma:

| Batch Size | Quantize TE | Memory (GB) | Latency (Seconds) |
|---|---|---|---|
| 1 | No | 9.380 | 7.431 |
| 1 | Yes | 3.058 | 7.604 |

Note, however, that due to the aggressive discretization of INT4, the end results can take a hit. This is why, for Transformer-based models in general, we usually leave the final projection layer out of quantization. In Quanto, we do this by:

from optimum.quanto import freeze, qint4, quantize

# Exclude the final projection layer ("proj_out") from INT4 quantization.
quantize(pipeline.transformer, weights=qint4, exclude="proj_out")
freeze(pipeline.transformer)

"proj_out" corresponds to the final layer in pipeline.transformer. The images below compare the results for various settings:

Left to right: no text encoder quantization and no layer exclusion; no text encoder quantization with "proj_out" excluded; text encoder quantization and no layer exclusion; text encoder quantization with "proj_out" excluded.

To recover the lost image quality, a common practice is to perform quantization-aware training, which is also supported in Quanto. This technique is out of scope for this post; feel free to contact us if you're interested!

All the results of our experiments for this post can be found here.

Bonus - saving and loading Diffusers models in Quanto

Quantized Diffusers models can be saved and loaded:

from diffusers import PixArtTransformer2DModel
from optimum.quanto import QuantizedPixArtTransformer2DModel, qfloat8

model = PixArtTransformer2DModel.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", subfolder="transformer")
qmodel = QuantizedPixArtTransformer2DModel.quantize(model, weights=qfloat8)
qmodel.save_pretrained("pixart-sigma-fp8")

The resulting checkpoint is 587MB in size, instead of the original 2.44GB. We can then load it:

from optimum.quanto import QuantizedPixArtTransformer2DModel
import torch

transformer = QuantizedPixArtTransformer2DModel.from_pretrained("pixart-sigma-fp8") 
transformer.to(device="cuda", dtype=torch.float16)

And use it in a DiffusionPipeline:

from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", 
    transformer=None,
    torch_dtype=torch.float16,
).to("cuda")
pipe.transformer = transformer

prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]

In the future, we expect to be able to pass the transformer directly when initializing the pipeline, so that this will work:

pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", 
-    transformer=None,
+    transformer=transformer,
    torch_dtype=torch.float16,
).to("cuda")

QuantizedPixArtTransformer2DModel implementation is available here for reference. If you want more models from Diffusers supported in Quanto for saving and loading, please open an issue here and mention @sayakpaul.

Tips

  • Based on your requirements, you may want to apply different types of quantization to different pipeline modules. For example, you could use FP8 for the text encoder but INT8 for the diffusion transformer. Thanks to the flexibility of Diffusers and Quanto, this can be done seamlessly; see the sketch after these tips.
  • To optimize for your use cases, you can even combine quantization with other memory optimization techniques in Diffusers, such as enable_model_cpu_offload().
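
As an illustration of both tips, here is a minimal sketch that mixes quantization types across modules and combines quantization with model CPU offloading.

from optimum.quanto import freeze, qfloat8, qint8, quantize
from diffusers import PixArtSigmaPipeline
import torch

pipeline = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
)

# FP8 for the text encoder, INT8 for the diffusion transformer.
quantize(pipeline.text_encoder, weights=qfloat8)
freeze(pipeline.text_encoder)
quantize(pipeline.transformer, weights=qint8)
freeze(pipeline.transformer)

# Offload components to the CPU and move them to the GPU only when needed.
pipeline.enable_model_cpu_offload()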

Conclusion

In this post, we showed how to quantize Transformer models from Diffusers and optimize their memory consumption. The effects of quantization become even more visible when we additionally quantize the text encoders. We hope you will apply some of these workflows to your projects and benefit from them 🤗.

Thanks to Pedro Cuenca for his extensive reviews on the post.