Dragunflie-420 committed
Commit a625758 • 1 Parent(s): 797076f
Update README.md
README.md CHANGED
---
license: mit
---

## Scalable Diffusion Models with Transformers (DiT)<br><sub>Official PyTorch Implementation</sub>

### [Paper](http://arxiv.org/abs/2212.09748) | [Project Page](https://www.wpeebles.com/DiT) | Run DiT-XL/2 [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/wpeebles/DiT) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb) <a href="https://replicate.com/arielreplicate/scalable_diffusion_with_transformers"><img src="https://replicate.com/arielreplicate/scalable_diffusion_with_transformers/badge"></a>

![DiT samples](visuals/sample_grid_0.png)

This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring diffusion models with transformers (DiTs). You can find more visualizations on our [project page](https://www.wpeebles.com/DiT).

> [**Scalable Diffusion Models with Transformers**](https://www.wpeebles.com/DiT)<br>
> [William Peebles](https://www.wpeebles.com), [Saining Xie](https://www.sainingxie.com)
> <br>UC Berkeley, New York University<br>

We train latent diffusion models, replacing the commonly-used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass complexity as measured by Gflops. We find that DiTs with higher Gflops---through increased transformer depth/width or increased number of input tokens---consistently have lower FID. In addition to good scalability properties, our DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512×512 and 256×256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.

This repository contains:

* 🪐 A simple PyTorch [implementation](models.py) of DiT
* ⚡️ Pre-trained class-conditional DiT models trained on ImageNet (512x512 and 256x256)
* 💥 A self-contained [Hugging Face Space](https://huggingface.co/spaces/wpeebles/DiT) and [Colab notebook](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb) for running pre-trained DiT-XL/2 models
* 🛸 A DiT [training script](train.py) using PyTorch DDP

An implementation of DiT directly in Hugging Face `diffusers` can also be found [here](https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/pipelines/dit.mdx).
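For instance, here is a minimal sketch of sampling DiT-XL/2 through the `diffusers` `DiTPipeline`, adapted from the `diffusers` documentation (exact APIs may vary across `diffusers` versions):

```python
import torch
from diffusers import DiTPipeline, DPMSolverMultistepScheduler

# Load the ported 256x256 DiT-XL/2 weights from the Hugging Face Hub.
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# DiT is class-conditional: map human-readable ImageNet labels to class ids.
class_ids = pipe.get_label_ids(["white shark", "umbrella"])

generator = torch.manual_seed(0)
images = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator).images
images[0].save("dit_sample.png")
```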
## Setup

First, download and set up the repo:

```bash
git clone https://github.com/facebookresearch/DiT.git
cd DiT
```

We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.

```bash
conda env create -f environment.yml
conda activate DiT
```
## Sampling [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/wpeebles/DiT) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb)

![More DiT samples](visuals/sample_grid_1.png)

**Pre-trained DiT checkpoints.** You can sample from our pre-trained DiT models with [`sample.py`](sample.py). Weights for our pre-trained DiT models will be downloaded automatically depending on the model you use. The script has various arguments to switch between the 256x256 and 512x512 models, adjust sampling steps, change the classifier-free guidance scale, etc. For example, to sample from our 512x512 DiT-XL/2 model, you can use:

```bash
python sample.py --image-size 512 --seed 1
```

For convenience, our pre-trained DiT models can also be downloaded directly here:

| DiT Model | Image Resolution | FID-50K | Inception Score | Gflops |
|-----------|------------------|---------|-----------------|--------|
| [XL/2](https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt) | 256x256 | 2.27 | 278.24 | 119 |
| [XL/2](https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt) | 512x512 | 3.04 | 240.82 | 525 |
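As a quick sanity check, here is a minimal sketch (not part of the repo's documented tooling) for fetching one of these checkpoints and inspecting it, assuming the file is a plain `torch.load`-able state dict:

```python
import urllib.request

import torch

# Assumption: the published checkpoint is a plain state dict of tensors;
# adjust if the layout differs.
url = "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt"
urllib.request.urlretrieve(url, "DiT-XL-2-256x256.pt")

state_dict = torch.load("DiT-XL-2-256x256.pt", map_location="cpu")
num_params = sum(v.numel() for v in state_dict.values())
print(f"{len(state_dict)} tensors, {num_params / 1e6:.1f}M parameters")
```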
**Custom DiT checkpoints.** If you've trained a new DiT model with [`train.py`](train.py) (see [below](#training-dit)), you can add the `--ckpt` argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 256x256 DiT-L/4 model, run:

```bash
python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt
```

## Training DiT

We provide a training script for DiT in [`train.py`](train.py). This script can be used to train class-conditional DiT models, but it can be easily modified to support other types of conditioning. To launch DiT-XL/2 (256x256) training with `N` GPUs on one node:

```bash
torchrun --nnodes=1 --nproc_per_node=N train.py --model DiT-XL/2 --data-path /path/to/imagenet/train
```

### PyTorch Training Results

We've trained DiT-XL/2 and DiT-B/4 models from scratch with the PyTorch training script to verify that it reproduces the original JAX results up to several hundred thousand training iterations. Across our experiments, the PyTorch-trained models give similar (and sometimes slightly better) results compared to the JAX-trained models, up to reasonable random variation. Some data points:

| DiT Model | Train Steps | FID-50K<br>(JAX Training) | FID-50K<br>(PyTorch Training) | PyTorch Global Training Seed |
|-----------|-------------|---------------------------|-------------------------------|------------------------------|
| XL/2 | 400K | 19.5 | **18.1** | 42 |
| B/4 | 400K | **68.4** | 68.9 | 42 |
| B/4 | 400K | 68.4 | **68.3** | 100 |

These models were trained at 256x256 resolution; we used 8x A100s to train XL/2 and 4x A100s to train B/4. Note that FID here is computed with 250 DDPM sampling steps, with the `mse` VAE decoder and without guidance (`cfg-scale=1`).
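For reference, that evaluation setup roughly corresponds to passing sampling flags like the following to [`sample_ddp.py`](sample_ddp.py) (a sketch; the argument names are assumptions and should be checked against the script):

```bash
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL/2 \
  --num-fid-samples 50000 --num-sampling-steps 250 --vae mse --cfg-scale 1.0
```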
**TF32 Note (important for A100 users).** When we ran the above tests, TF32 matmuls were disabled per PyTorch's defaults. We've enabled them at the top of `train.py` and `sample.py` because they make training and sampling significantly faster on A100s (and should on other Ampere GPUs too), but note that the use of TF32 may lead to some differences compared to the above results.
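Concretely, these are the standard PyTorch TF32 switches; a minimal sketch of what toggling them looks like:

```python
import torch

# Allow TF32 for matmuls and cuDNN convolutions (much faster on A100/Ampere GPUs,
# at slightly reduced precision). Set both to False to match the FID numbers above.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```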
### Enhancements

Training (and sampling) could likely be sped up significantly by:

- [ ] using [Flash Attention](https://github.com/HazyResearch/flash-attention) in the DiT model
- [ ] using `torch.compile` in PyTorch 2.0

Basic features that would be nice to add:

- [ ] Monitor FID and other metrics
- [ ] Generate and save samples from the EMA model periodically
- [ ] Resume training from a checkpoint
- [ ] AMP/bfloat16 support

**🔥 Feature update.** Check out [fast-DiT](https://github.com/chuanyangjin/fast-DiT) to preview a selection of training speed-ups and memory-saving features, including gradient checkpointing, mixed-precision training, and pre-extracted VAE features. With these improvements, we have achieved a training speed of 0.84 steps/sec for DiT-XL/2 using just a single A100 GPU.

## Evaluation (FID, Inception Score, etc.)

We include a [`sample_ddp.py`](sample_ddp.py) script which samples a large number of images from a DiT model in parallel. This script generates a folder of samples as well as a `.npz` file which can be used directly with [ADM's TensorFlow evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations) to compute FID, Inception Score and other metrics. For example, to sample 50K images from our pre-trained DiT-XL/2 model over `N` GPUs, run:

```bash
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL/2 --num-fid-samples 50000
```

There are several additional options; see [`sample_ddp.py`](sample_ddp.py) for details.
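Once the `.npz` file exists, the metrics are computed with ADM's evaluator; a rough sketch of that final step (the reference-batch filename here is illustrative, see the evaluation suite's README for the actual files and usage):

```bash
# Run from a checkout of openai/guided-diffusion/evaluations, with its requirements installed.
# Args: reference batch .npz, then the sample .npz produced by sample_ddp.py above.
python evaluator.py VIRTUAL_imagenet256_labeled.npz /path/to/your/samples.npz
```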
## Differences from JAX

Our models were originally trained in JAX on TPUs. The weights in this repo are ported directly from the JAX models. There may be minor differences in results stemming from sampling with different floating-point precisions. We re-evaluated our ported PyTorch weights at FP32, and they actually perform marginally better than sampling in JAX (2.21 FID versus 2.27 in the paper).

## BibTeX

```bibtex
@article{Peebles2022DiT,
  title={Scalable Diffusion Models with Transformers},
  author={William Peebles and Saining Xie},
  year={2022},
  journal={arXiv preprint arXiv:2212.09748},
}
```

## Acknowledgments

We thank Kaiming He, Ronghang Hu, Alexander Berg, Shoubhik Debnath, Tim Brooks, Ilija Radosavovic and Tete Xiao for helpful discussions. William Peebles is supported by the NSF Graduate Research Fellowship.

This codebase borrows from OpenAI's diffusion repos, most notably [ADM](https://github.com/openai/guided-diffusion).

## License

The code and model weights are licensed under CC-BY-NC. See [`LICENSE.txt`](LICENSE.txt) for details.