jiangqin commited on
Commit
200d6ae
·
verified ·
1 Parent(s): 3bf556e

End of training

Browse files
README.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: openrail++
3
+ library_name: diffusers
4
+ tags:
5
+ - text-to-image
6
+ - text-to-image
7
+ - diffusers-training
8
+ - diffusers
9
+ - dora
10
+ - template:sd-lora
11
+ - stable-diffusion-xl
12
+ - stable-diffusion-xl-diffusers
13
+ base_model: stabilityai/stable-diffusion-xl-base-1.0
14
+ instance_prompt: a photo of TOK screw icon
15
+ widget: []
16
+ ---
17
+
18
+ <!-- This model card has been generated automatically according to the information the training script had access to. You
19
+ should probably proofread and complete it, then remove this comment. -->
20
+
21
+
22
+ # SDXL LoRA DreamBooth - jiangqin/3d-icon-sdxl-lora
23
+
24
+ <Gallery />
25
+
26
+ ## Model description
27
+
28
+ These are jiangqin/3d-icon-sdxl-lora LoRA adaption weights for stabilityai/stable-diffusion-xl-base-1.0.
29
+
30
+ The weights were trained using [DreamBooth](https://dreambooth.github.io/).
31
+
32
+ LoRA for the text encoder was enabled: False.
33
+
34
+ Special VAE used for training: madebyollin/sdxl-vae-fp16-fix.
35
+
36
+ ## Trigger words
37
+
38
+ You should use a photo of TOK screw icon to trigger the image generation.
39
+
40
+ ## Download model
41
+
42
+ Weights for this model are available in Safetensors format.
43
+
44
+ [Download](jiangqin/3d-icon-sdxl-lora/tree/main) them in the Files & versions tab.
45
+
46
+
47
+
48
+ ## Intended uses & limitations
49
+
50
+ #### How to use
51
+
52
+ ```python
53
+ # TODO: add an example code snippet for running this diffusion pipeline
54
+ ```
55
+
56
+ #### Limitations and bias
57
+
58
+ [TODO: provide examples of latent issues and potential remediations]
59
+
60
+ ## Training details
61
+
62
+ [TODO: describe the data used to train the model]
__pycache__/train_dreambooth_lora_sdxl.cpython-310.pyc ADDED
Binary file (44.4 kB). View file
 
logs/dreambooth-lora-sd-xl/1716833427.9533217/events.out.tfevents.1716833427.6e875dbe58bb.346.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d54233d50bed8768d766bd9cc52ddb65e896391b3912cf7f4045c59c00522454
3
+ size 3316
logs/dreambooth-lora-sd-xl/1716833427.957748/hparams.yml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adam_beta1: 0.9
2
+ adam_beta2: 0.999
3
+ adam_epsilon: 1.0e-08
4
+ adam_weight_decay: 0.0001
5
+ adam_weight_decay_text_encoder: 0.001
6
+ allow_tf32: false
7
+ cache_dir: null
8
+ caption_column: prompt
9
+ center_crop: false
10
+ checkpointing_steps: 717
11
+ checkpoints_total_limit: null
12
+ class_data_dir: null
13
+ class_prompt: null
14
+ dataloader_num_workers: 0
15
+ dataset_config_name: null
16
+ dataset_name: /kaggle/input/screwsss
17
+ do_edm_style_training: false
18
+ enable_xformers_memory_efficient_attention: false
19
+ gradient_accumulation_steps: 3
20
+ gradient_checkpointing: true
21
+ hub_model_id: null
22
+ hub_token: null
23
+ image_column: image
24
+ instance_data_dir: null
25
+ instance_prompt: a photo of TOK dog
26
+ learning_rate: 0.0001
27
+ local_rank: 0
28
+ logging_dir: logs
29
+ lr_num_cycles: 1
30
+ lr_power: 1.0
31
+ lr_scheduler: constant
32
+ lr_warmup_steps: 0
33
+ max_grad_norm: 1.0
34
+ max_train_steps: 500
35
+ mixed_precision: fp16
36
+ num_class_images: 100
37
+ num_train_epochs: 39
38
+ num_validation_images: 4
39
+ optimizer: AdamW
40
+ output_dir: /kaggle/working/
41
+ output_kohya_format: false
42
+ pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0
43
+ pretrained_vae_model_name_or_path: madebyollin/sdxl-vae-fp16-fix
44
+ prior_generation_precision: null
45
+ prior_loss_weight: 1.0
46
+ prodigy_beta3: null
47
+ prodigy_decouple: true
48
+ prodigy_safeguard_warmup: true
49
+ prodigy_use_bias_correction: true
50
+ push_to_hub: false
51
+ random_flip: false
52
+ rank: 4
53
+ repeats: 1
54
+ report_to: tensorboard
55
+ resolution: 1024
56
+ resume_from_checkpoint: null
57
+ revision: null
58
+ sample_batch_size: 4
59
+ scale_lr: false
60
+ seed: 0
61
+ snr_gamma: 5.0
62
+ text_encoder_lr: 5.0e-06
63
+ train_batch_size: 1
64
+ train_text_encoder: false
65
+ use_8bit_adam: true
66
+ use_dora: false
67
+ validation_epochs: 50
68
+ validation_prompt: null
69
+ variant: null
70
+ with_prior_preservation: false
logs/dreambooth-lora-sd-xl/1716833649.9581614/events.out.tfevents.1716833649.6e875dbe58bb.446.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d98ac573304c58fd3f97b577c81e2384f13bd1a027467525bf1a76f60953d88
3
+ size 3324
logs/dreambooth-lora-sd-xl/1716833649.9637702/hparams.yml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ adam_beta1: 0.9
2
+ adam_beta2: 0.999
3
+ adam_epsilon: 1.0e-08
4
+ adam_weight_decay: 0.0001
5
+ adam_weight_decay_text_encoder: 0.001
6
+ allow_tf32: false
7
+ cache_dir: null
8
+ caption_column: prompt
9
+ center_crop: false
10
+ checkpointing_steps: 717
11
+ checkpoints_total_limit: null
12
+ class_data_dir: null
13
+ class_prompt: null
14
+ dataloader_num_workers: 0
15
+ dataset_config_name: null
16
+ dataset_name: /kaggle/input/screwsss
17
+ do_edm_style_training: false
18
+ enable_xformers_memory_efficient_attention: false
19
+ gradient_accumulation_steps: 3
20
+ gradient_checkpointing: true
21
+ hub_model_id: null
22
+ hub_token: null
23
+ image_column: image
24
+ instance_data_dir: null
25
+ instance_prompt: 'a photo of TOK screw icon '
26
+ learning_rate: 0.0001
27
+ local_rank: 0
28
+ logging_dir: logs
29
+ lr_num_cycles: 1
30
+ lr_power: 1.0
31
+ lr_scheduler: constant
32
+ lr_warmup_steps: 0
33
+ max_grad_norm: 1.0
34
+ max_train_steps: 500
35
+ mixed_precision: fp16
36
+ num_class_images: 100
37
+ num_train_epochs: 39
38
+ num_validation_images: 4
39
+ optimizer: AdamW
40
+ output_dir: /kaggle/working/
41
+ output_kohya_format: false
42
+ pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0
43
+ pretrained_vae_model_name_or_path: madebyollin/sdxl-vae-fp16-fix
44
+ prior_generation_precision: null
45
+ prior_loss_weight: 1.0
46
+ prodigy_beta3: null
47
+ prodigy_decouple: true
48
+ prodigy_safeguard_warmup: true
49
+ prodigy_use_bias_correction: true
50
+ push_to_hub: false
51
+ random_flip: false
52
+ rank: 4
53
+ repeats: 1
54
+ report_to: tensorboard
55
+ resolution: 1024
56
+ resume_from_checkpoint: null
57
+ revision: null
58
+ sample_batch_size: 4
59
+ scale_lr: false
60
+ seed: 0
61
+ snr_gamma: 5.0
62
+ text_encoder_lr: 5.0e-06
63
+ train_batch_size: 1
64
+ train_text_encoder: false
65
+ use_8bit_adam: true
66
+ use_dora: false
67
+ validation_epochs: 50
68
+ validation_prompt: null
69
+ variant: null
70
+ with_prior_preservation: false
logs/dreambooth-lora-sd-xl/events.out.tfevents.1716833427.6e875dbe58bb.346.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d835996c3dc411d35b5475d7774134c95a0ecf955d99cdf2fdbe41a6c077f9c1
3
+ size 1720
logs/dreambooth-lora-sd-xl/events.out.tfevents.1716833649.6e875dbe58bb.446.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44c9408b0f69bff2185ef010c862997658ce1dc2d1bf059d74e2d2764cebc8ed
3
+ size 125314
metadata.jsonl ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-modified-truss-4295189333-v1.png", "prompt": "a photo of TOK screw icon, screw screw screw vector icon design"}
2
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-thumb-4295189265-v1.png", "prompt": "a photo of TOK screw icon, a screw screw screw screw screw screw screw screw screw screw screw screw screw screw screw screw screw screw screw screw"}
3
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-pin-in-head-hex-4296094551-v1.png", "prompt": "a photo of TOK screw icon, a black and white circle with an orange dot"}
4
+ {"file_name": "hardware-fasteners-screws-562860-features-587081-coated-4295291747-v1.png", "prompt": "a photo of TOK screw icon, a hand holding a spray gun with orange dots"}
5
+ {"file_name": "hardware-fasteners-screws-562860-fastener-type--screw-cover--v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a metal cone"}
6
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-tri-wing-4296094256-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a arrow"}
7
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-countersunk-head-4296088423-v1.png", "prompt": "a photo of TOK screw icon, a screw screw with a long thread icon"}
8
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-external-hex-4296094345-v1.png", "prompt": "a photo of TOK screw icon, a hexagon icon"}
9
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-wafer-head-4295189324-v1.png", "prompt": "a photo of TOK screw icon, a microphone icon"}
10
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-button-head-4295189376-v1.png", "prompt": "a photo of TOK screw icon, a screw screw with a flat head"}
11
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-washer-head-4295189252-v1.png", "prompt": "a photo of TOK screw icon, a screw screw with a screwdriveer on top"}
12
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-fillister-head-4295189409-v1.png", "prompt": "a photo of TOK screw icon, a screw screw with a flat head"}
13
+ {"file_name": "hardware-fasteners-screws-562860-features-587081-serrated-edges-4295291735-v1.png", "prompt": "a photo of TOK screw icon, a black and white drawing of a letter b"}
14
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-socket-head-4295324600-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a brush"}
15
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-flange-hex-4295189152-v1.png", "prompt": "a photo of TOK screw icon, a screw screw with a screwdriveer on it"}
16
+ {"file_name": "hardware-fasteners-screws-562860-indoor-outdoor-508147-indoor-4294843918-v1.png", "prompt": "a photo of TOK screw icon, a house with an arrow pointing up to the right"}
17
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-round-head-4295189169-v1.png", "prompt": "a photo of TOK screw icon, a screw screw with a flat head and a flat head"}
18
+ {"file_name": "hardware-fasteners-screws-562860-features-587081-self-drilling-4295290986-v1.png", "prompt": "a photo of TOK screw icon, a pencil icon"}
19
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-square-head-4295189291-v1.png", "prompt": "a photo of TOK screw icon, a screw screw with a screwdriveer on top"}
20
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-cylinder-head-4295325859-v1.png", "prompt": "a photo of TOK screw icon, a screw screw with a screwdriveer on top"}
21
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-external-square-4296094427-v1.png", "prompt": "a photo of TOK screw icon, a black and white square frame"}
22
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-phillips-square-4296094400-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a circular arrow"}
23
+ {"file_name": "hardware-fasteners-screws-562860-size-562860-screw-size--v1.png", "prompt": "a photo of TOK screw icon, a white and orange icon with the letter o"}
24
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-flat-undercut-4295189380-v1.png", "prompt": "a photo of TOK screw icon, a screw screw with a long thread"}
25
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-pin-in-star-4296094312-v1.png", "prompt": "a photo of TOK screw icon, a black and white circle with an orange dot"}
26
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-truss-4295189213-v1.png", "prompt": "a photo of TOK screw icon, screw screw screw vector icon design"}
27
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-socket-cap-head-4295189259-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a brush"}
28
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-pan-head-4295189323-v1.png", "prompt": "a photo of TOK screw icon, a screw screw with a flat head"}
29
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-hex-4295189384-v1.png", "prompt": "a photo of TOK screw icon, screw screw screw vector icon design"}
30
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-pancake-head-4295213029-v1.png", "prompt": "a photo of TOK screw icon, screw screw screw vector icon design"}
31
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-slotted-4296094277-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a no entry sign"}
32
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-torx-4296094695-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a circle"}
33
+ {"file_name": "hardware-fasteners-screws-562860-features-587081-tamper-resistant-4295291155-v1.png", "prompt": "a photo of TOK screw icon, a white and orange egg with a black outline"}
34
+ {"file_name": "hardware-fasteners-screws-562860-interior-exterior-584860-exterior-4295122957-v1.png", "prompt": "a photo of TOK screw icon, a house with an arrow pointing up to the roof"}
35
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-thumbscrew-4296094403-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a circle"}
36
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-bugle-head-4295189239-v1.png", "prompt": "a photo of TOK screw icon, a black and white letter t"}
37
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-oval-4295189149-v1.png", "prompt": "a photo of TOK screw icon, a light bulb icon"}
38
+ {"file_name": "hardware-fasteners-screws-562860-features-587081-self-tapping-4295289557-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a lightning"}
39
+ {"file_name": "hardware-fasteners-screws-562860-product-weight-lb-544538-weight--v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a purse"}
40
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-lox-4296094477-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a cross"}
41
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-t-star-plus-4296094689-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a star"}
42
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-12-point-triple-square-4296094318-v1.png", "prompt": "a photo of TOK screw icon, a black and white circle with a star inside"}
43
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-cross-4296094710-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a cross"}
44
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-headless-4295295552-v1.png", "prompt": "a photo of TOK screw icon, a stack of books icon"}
45
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-oval-head-4295189257-v1.png", "prompt": "a photo of TOK screw icon, a light bulb icon"}
46
+ {"file_name": "hardware-fasteners-screws-562860-screw-length-571675-screw-length--v1.png", "prompt": "a photo of TOK screw icon, two screws with a long, pointed end"}
47
+ {"file_name": "hardware-fasteners-screws-562860-indoor-outdoor-508147-indoor-outdoor-4294844044-v1.png", "prompt": "a photo of TOK screw icon, a house with an arrow pointing up to the right"}
48
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-flat-head-4295189272-v1.png", "prompt": "a photo of TOK screw icon, a white and black icon of a funnel"}
49
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-phillips-slotted-4296094468-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a cross"}
50
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-510344-phillips-4294966012-v1.png", "prompt": "a photo of TOK screw icon, a cross in a circle"}
51
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-scavenger-head-4295373652-v1.png", "prompt": "a photo of TOK screw icon, a screw screw with a long thread icon"}
52
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-one-way-4296094379-v1.png", "prompt": "a photo of TOK screw icon, a black and white circle with a white cross in the middle"}
53
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-spanner-4296094604-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a smiley face"}
54
+ {"file_name": "hardware-fasteners-screws-562860-package-quantity-510300-package-quantity--v1.png", "prompt": "a photo of TOK screw icon, a cube with a cross on it"}
55
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-spade-4295189406-v1.png", "prompt": "a photo of TOK screw icon, screw screw screw vector icon design"}
56
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-internal-hex-4296094585-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a pentagon"}
57
+ {"file_name": "hardware-fasteners-screws-562861-head-style--serrated-flange-hex--v1.png", "prompt": "a photo of TOK screw icon, a light bulb with a long tail icon"}
58
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-allen-4296094304-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a pentagon"}
59
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-double-ended-4296094441-v1.png", "prompt": "a photo of TOK screw icon, a stack of books icon"}
60
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-truss-head-4295189302-v1.png", "prompt": "a photo of TOK screw icon, screw screw screw vector icon design"}
61
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-low-profile-head-4296110807-v1.png", "prompt": "a photo of TOK screw icon, screw screw screw vector icon design"}
62
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-star-4296094439-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a circle"}
63
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-6-lobe-4296098857-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a circle"}
64
+ {"file_name": "hardware-fasteners-screws-562860-features-587081-self-piercing-4296057728-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a lightning"}
65
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-low-profile-washer-4295189351-v1.png", "prompt": "a photo of TOK screw icon, a podium icon"}
66
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-internal-square-4296094329-v1.png", "prompt": "a photo of TOK screw icon, a black and white circle with a square in the middle"}
67
+ {"file_name": "hardware-fasteners-screws-562860-head-style-571657-flat-4294425208-v1.png", "prompt": "a photo of TOK screw icon, a white and black icon of a funnel"}
68
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-trim-4295189215-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a funnel"}
69
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-dowel-4295189316-v1.png", "prompt": "a photo of TOK screw icon, a stack of books icon"}
70
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-security-torx-4296094680-v1.png", "prompt": "a photo of TOK screw icon, a black and white circle with an orange dot"}
71
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-knurled-4295189411-v1.png", "prompt": "a photo of TOK screw icon, a screw screw screw screw screw screw screw screw screw screw screw screw screw screw screw screw screw screw screw screw"}
72
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-phillips-4296094358-v1.png", "prompt": "a photo of TOK screw icon, a cross in a circle"}
73
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-hex-4296094521-v1.png", "prompt": "a photo of TOK screw icon, a hexagon icon"}
74
+ {"file_name": "hardware-fasteners-screws-562860-interior-exterior-584860-interior-4295122971-v1.png", "prompt": "a photo of TOK screw icon, a house with an arrow pointing up to the right"}
75
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-square-4296094651-v1.png", "prompt": "a photo of TOK screw icon, a black and white circle with a square in the middle"}
76
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-binding-4295189210-v1.png", "prompt": "a photo of TOK screw icon, screw screw screw vector icon design"}
77
+ {"file_name": "hardware-fasteners-screws-562860-head-style-587571-wing-4295189166-v1.png", "prompt": "a photo of TOK screw icon, a light bulb with a light bulb inside"}
78
+ {"file_name": "hardware-fasteners-screws-562860-drive-style-5295328677-combo-4296094687-v1.png", "prompt": "a photo of TOK screw icon, a black and white icon of a cross"}
pytorch_lora_weights.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c941dd5fdd8cc1bbe2fd6662cce9ee2049ee8007e221123c3a50bb51d5a9dbb0
3
+ size 23390424
train_dreambooth_lora_sdxl.py ADDED
@@ -0,0 +1,1984 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ import argparse
17
+ import gc
18
+ import itertools
19
+ import json
20
+ import logging
21
+ import math
22
+ import os
23
+ import random
24
+ import shutil
25
+ import warnings
26
+ from contextlib import nullcontext
27
+ from pathlib import Path
28
+
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn.functional as F
32
+ import torch.utils.checkpoint
33
+ import transformers
34
+ from accelerate import Accelerator
35
+ from accelerate.logging import get_logger
36
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
37
+ from huggingface_hub import create_repo, hf_hub_download, upload_folder
38
+ from huggingface_hub.utils import insecure_hashlib
39
+ from packaging import version
40
+ from peft import LoraConfig, set_peft_model_state_dict
41
+ from peft.utils import get_peft_model_state_dict
42
+ from PIL import Image
43
+ from PIL.ImageOps import exif_transpose
44
+ from safetensors.torch import load_file, save_file
45
+ from torch.utils.data import Dataset
46
+ from torchvision import transforms
47
+ from torchvision.transforms.functional import crop
48
+ from tqdm.auto import tqdm
49
+ from transformers import AutoTokenizer, PretrainedConfig
50
+
51
+ import diffusers
52
+ from diffusers import (
53
+ AutoencoderKL,
54
+ DDPMScheduler,
55
+ DPMSolverMultistepScheduler,
56
+ EDMEulerScheduler,
57
+ EulerDiscreteScheduler,
58
+ StableDiffusionXLPipeline,
59
+ UNet2DConditionModel,
60
+ )
61
+ from diffusers.loaders import LoraLoaderMixin
62
+ from diffusers.optimization import get_scheduler
63
+ from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
64
+ from diffusers.utils import (
65
+ check_min_version,
66
+ convert_all_state_dict_to_peft,
67
+ convert_state_dict_to_diffusers,
68
+ convert_state_dict_to_kohya,
69
+ convert_unet_state_dict_to_peft,
70
+ is_wandb_available,
71
+ )
72
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
73
+ from diffusers.utils.import_utils import is_xformers_available
74
+ from diffusers.utils.torch_utils import is_compiled_module
75
+
76
+
77
+ if is_wandb_available():
78
+ import wandb
79
+
80
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
81
+ check_min_version("0.28.0.dev0")
82
+
83
+ logger = get_logger(__name__)
84
+
85
+
86
+ def determine_scheduler_type(pretrained_model_name_or_path, revision):
87
+ model_index_filename = "model_index.json"
88
+ if os.path.isdir(pretrained_model_name_or_path):
89
+ model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)
90
+ else:
91
+ model_index = hf_hub_download(
92
+ repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision
93
+ )
94
+
95
+ with open(model_index, "r") as f:
96
+ scheduler_type = json.load(f)["scheduler"][1]
97
+ return scheduler_type
98
+
99
+
100
+ def save_model_card(
101
+ repo_id: str,
102
+ use_dora: bool,
103
+ images=None,
104
+ base_model: str = None,
105
+ train_text_encoder=False,
106
+ instance_prompt=None,
107
+ validation_prompt=None,
108
+ repo_folder=None,
109
+ vae_path=None,
110
+ ):
111
+ widget_dict = []
112
+ if images is not None:
113
+ for i, image in enumerate(images):
114
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
115
+ widget_dict.append(
116
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
117
+ )
118
+
119
+ model_description = f"""
120
+ # {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
121
+
122
+ <Gallery />
123
+
124
+ ## Model description
125
+
126
+ These are {repo_id} LoRA adaption weights for {base_model}.
127
+
128
+ The weights were trained using [DreamBooth](https://dreambooth.github.io/).
129
+
130
+ LoRA for the text encoder was enabled: {train_text_encoder}.
131
+
132
+ Special VAE used for training: {vae_path}.
133
+
134
+ ## Trigger words
135
+
136
+ You should use {instance_prompt} to trigger the image generation.
137
+
138
+ ## Download model
139
+
140
+ Weights for this model are available in Safetensors format.
141
+
142
+ [Download]({repo_id}/tree/main) them in the Files & versions tab.
143
+
144
+ """
145
+ if "playground" in base_model:
146
+ model_description += """\n
147
+ ## License
148
+
149
+ Please adhere to the licensing terms as described [here](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md).
150
+ """
151
+ model_card = load_or_create_model_card(
152
+ repo_id_or_path=repo_id,
153
+ from_training=True,
154
+ license="openrail++" if "playground" not in base_model else "playground-v2dot5-community",
155
+ base_model=base_model,
156
+ prompt=instance_prompt,
157
+ model_description=model_description,
158
+ widget=widget_dict,
159
+ )
160
+ tags = [
161
+ "text-to-image",
162
+ "text-to-image",
163
+ "diffusers-training",
164
+ "diffusers",
165
+ "lora" if not use_dora else "dora",
166
+ "template:sd-lora",
167
+ ]
168
+ if "playground" in base_model:
169
+ tags.extend(["playground", "playground-diffusers"])
170
+ else:
171
+ tags.extend(["stable-diffusion-xl", "stable-diffusion-xl-diffusers"])
172
+
173
+ model_card = populate_model_card(model_card, tags=tags)
174
+ model_card.save(os.path.join(repo_folder, "README.md"))
175
+
176
+
177
+ def log_validation(
178
+ pipeline,
179
+ args,
180
+ accelerator,
181
+ pipeline_args,
182
+ epoch,
183
+ is_final_validation=False,
184
+ ):
185
+ logger.info(
186
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
187
+ f" {args.validation_prompt}."
188
+ )
189
+
190
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
191
+ scheduler_args = {}
192
+
193
+ if not args.do_edm_style_training:
194
+ if "variance_type" in pipeline.scheduler.config:
195
+ variance_type = pipeline.scheduler.config.variance_type
196
+
197
+ if variance_type in ["learned", "learned_range"]:
198
+ variance_type = "fixed_small"
199
+
200
+ scheduler_args["variance_type"] = variance_type
201
+
202
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
203
+
204
+ pipeline = pipeline.to(accelerator.device)
205
+ pipeline.set_progress_bar_config(disable=True)
206
+
207
+ # run inference
208
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
209
+ # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
210
+ # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
211
+ if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
212
+ autocast_ctx = nullcontext()
213
+ else:
214
+ autocast_ctx = torch.autocast(accelerator.device.type)
215
+
216
+ with autocast_ctx:
217
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
218
+
219
+ for tracker in accelerator.trackers:
220
+ phase_name = "test" if is_final_validation else "validation"
221
+ if tracker.name == "tensorboard":
222
+ np_images = np.stack([np.asarray(img) for img in images])
223
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
224
+ if tracker.name == "wandb":
225
+ tracker.log(
226
+ {
227
+ phase_name: [
228
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
229
+ ]
230
+ }
231
+ )
232
+
233
+ del pipeline
234
+ if torch.cuda.is_available():
235
+ torch.cuda.empty_cache()
236
+
237
+ return images
238
+
239
+
240
+ def import_model_class_from_model_name_or_path(
241
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
242
+ ):
243
+ text_encoder_config = PretrainedConfig.from_pretrained(
244
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
245
+ )
246
+ model_class = text_encoder_config.architectures[0]
247
+
248
+ if model_class == "CLIPTextModel":
249
+ from transformers import CLIPTextModel
250
+
251
+ return CLIPTextModel
252
+ elif model_class == "CLIPTextModelWithProjection":
253
+ from transformers import CLIPTextModelWithProjection
254
+
255
+ return CLIPTextModelWithProjection
256
+ else:
257
+ raise ValueError(f"{model_class} is not supported.")
258
+
259
+
260
+ def parse_args(input_args=None):
261
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
262
+ parser.add_argument(
263
+ "--pretrained_model_name_or_path",
264
+ type=str,
265
+ default=None,
266
+ required=True,
267
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
268
+ )
269
+ parser.add_argument(
270
+ "--pretrained_vae_model_name_or_path",
271
+ type=str,
272
+ default=None,
273
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
274
+ )
275
+ parser.add_argument(
276
+ "--revision",
277
+ type=str,
278
+ default=None,
279
+ required=False,
280
+ help="Revision of pretrained model identifier from huggingface.co/models.",
281
+ )
282
+ parser.add_argument(
283
+ "--variant",
284
+ type=str,
285
+ default=None,
286
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
287
+ )
288
+ parser.add_argument(
289
+ "--dataset_name",
290
+ type=str,
291
+ default=None,
292
+ help=(
293
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
294
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
295
+ " or to a folder containing files that 🤗 Datasets can understand."
296
+ ),
297
+ )
298
+ parser.add_argument(
299
+ "--dataset_config_name",
300
+ type=str,
301
+ default=None,
302
+ help="The config of the Dataset, leave as None if there's only one config.",
303
+ )
304
+ parser.add_argument(
305
+ "--instance_data_dir",
306
+ type=str,
307
+ default=None,
308
+ help=("A folder containing the training data. "),
309
+ )
310
+
311
+ parser.add_argument(
312
+ "--cache_dir",
313
+ type=str,
314
+ default=None,
315
+ help="The directory where the downloaded models and datasets will be stored.",
316
+ )
317
+
318
+ parser.add_argument(
319
+ "--image_column",
320
+ type=str,
321
+ default="image",
322
+ help="The column of the dataset containing the target image. By "
323
+ "default, the standard Image Dataset maps out 'file_name' "
324
+ "to 'image'.",
325
+ )
326
+ parser.add_argument(
327
+ "--caption_column",
328
+ type=str,
329
+ default=None,
330
+ help="The column of the dataset containing the instance prompt for each image",
331
+ )
332
+
333
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
334
+
335
+ parser.add_argument(
336
+ "--class_data_dir",
337
+ type=str,
338
+ default=None,
339
+ required=False,
340
+ help="A folder containing the training data of class images.",
341
+ )
342
+ parser.add_argument(
343
+ "--instance_prompt",
344
+ type=str,
345
+ default=None,
346
+ required=True,
347
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
348
+ )
349
+ parser.add_argument(
350
+ "--class_prompt",
351
+ type=str,
352
+ default=None,
353
+ help="The prompt to specify images in the same class as provided instance images.",
354
+ )
355
+ parser.add_argument(
356
+ "--validation_prompt",
357
+ type=str,
358
+ default=None,
359
+ help="A prompt that is used during validation to verify that the model is learning.",
360
+ )
361
+ parser.add_argument(
362
+ "--num_validation_images",
363
+ type=int,
364
+ default=4,
365
+ help="Number of images that should be generated during validation with `validation_prompt`.",
366
+ )
367
+ parser.add_argument(
368
+ "--validation_epochs",
369
+ type=int,
370
+ default=50,
371
+ help=(
372
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
373
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
374
+ ),
375
+ )
376
+ parser.add_argument(
377
+ "--do_edm_style_training",
378
+ default=False,
379
+ action="store_true",
380
+ help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
381
+ )
382
+ parser.add_argument(
383
+ "--with_prior_preservation",
384
+ default=False,
385
+ action="store_true",
386
+ help="Flag to add prior preservation loss.",
387
+ )
388
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
389
+ parser.add_argument(
390
+ "--num_class_images",
391
+ type=int,
392
+ default=100,
393
+ help=(
394
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
395
+ " class_data_dir, additional images will be sampled with class_prompt."
396
+ ),
397
+ )
398
+ parser.add_argument(
399
+ "--output_dir",
400
+ type=str,
401
+ default="lora-dreambooth-model",
402
+ help="The output directory where the model predictions and checkpoints will be written.",
403
+ )
404
+ parser.add_argument(
405
+ "--output_kohya_format",
406
+ action="store_true",
407
+ help="Flag to additionally generate final state dict in the Kohya format so that it becomes compatible with A111, Comfy, Kohya, etc.",
408
+ )
409
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
410
+ parser.add_argument(
411
+ "--resolution",
412
+ type=int,
413
+ default=1024,
414
+ help=(
415
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
416
+ " resolution"
417
+ ),
418
+ )
419
+ parser.add_argument(
420
+ "--center_crop",
421
+ default=False,
422
+ action="store_true",
423
+ help=(
424
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
425
+ " cropped. The images will be resized to the resolution first before cropping."
426
+ ),
427
+ )
428
+ parser.add_argument(
429
+ "--random_flip",
430
+ action="store_true",
431
+ help="whether to randomly flip images horizontally",
432
+ )
433
+ parser.add_argument(
434
+ "--train_text_encoder",
435
+ action="store_true",
436
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
437
+ )
438
+ parser.add_argument(
439
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
440
+ )
441
+ parser.add_argument(
442
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
443
+ )
444
+ parser.add_argument("--num_train_epochs", type=int, default=1)
445
+ parser.add_argument(
446
+ "--max_train_steps",
447
+ type=int,
448
+ default=None,
449
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
450
+ )
451
+ parser.add_argument(
452
+ "--checkpointing_steps",
453
+ type=int,
454
+ default=500,
455
+ help=(
456
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
457
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
458
+ " training using `--resume_from_checkpoint`."
459
+ ),
460
+ )
461
+ parser.add_argument(
462
+ "--checkpoints_total_limit",
463
+ type=int,
464
+ default=None,
465
+ help=("Max number of checkpoints to store."),
466
+ )
467
+ parser.add_argument(
468
+ "--resume_from_checkpoint",
469
+ type=str,
470
+ default=None,
471
+ help=(
472
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
473
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
474
+ ),
475
+ )
476
+ parser.add_argument(
477
+ "--gradient_accumulation_steps",
478
+ type=int,
479
+ default=1,
480
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
481
+ )
482
+ parser.add_argument(
483
+ "--gradient_checkpointing",
484
+ action="store_true",
485
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
486
+ )
487
+ parser.add_argument(
488
+ "--learning_rate",
489
+ type=float,
490
+ default=1e-4,
491
+ help="Initial learning rate (after the potential warmup period) to use.",
492
+ )
493
+
494
+ parser.add_argument(
495
+ "--text_encoder_lr",
496
+ type=float,
497
+ default=5e-6,
498
+ help="Text encoder learning rate to use.",
499
+ )
500
+ parser.add_argument(
501
+ "--scale_lr",
502
+ action="store_true",
503
+ default=False,
504
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
505
+ )
506
+ parser.add_argument(
507
+ "--lr_scheduler",
508
+ type=str,
509
+ default="constant",
510
+ help=(
511
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
512
+ ' "constant", "constant_with_warmup"]'
513
+ ),
514
+ )
515
+
516
+ parser.add_argument(
517
+ "--snr_gamma",
518
+ type=float,
519
+ default=None,
520
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
521
+ "More details here: https://arxiv.org/abs/2303.09556.",
522
+ )
523
+ parser.add_argument(
524
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
525
+ )
526
+ parser.add_argument(
527
+ "--lr_num_cycles",
528
+ type=int,
529
+ default=1,
530
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
531
+ )
532
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
533
+ parser.add_argument(
534
+ "--dataloader_num_workers",
535
+ type=int,
536
+ default=0,
537
+ help=(
538
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
539
+ ),
540
+ )
541
+
542
+ parser.add_argument(
543
+ "--optimizer",
544
+ type=str,
545
+ default="AdamW",
546
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
547
+ )
548
+
549
+ parser.add_argument(
550
+ "--use_8bit_adam",
551
+ action="store_true",
552
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
553
+ )
554
+
555
+ parser.add_argument(
556
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
557
+ )
558
+ parser.add_argument(
559
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
560
+ )
561
+ parser.add_argument(
562
+ "--prodigy_beta3",
563
+ type=float,
564
+ default=None,
565
+ help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
566
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
567
+ )
568
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
569
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
570
+ parser.add_argument(
571
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
572
+ )
573
+
574
+ parser.add_argument(
575
+ "--adam_epsilon",
576
+ type=float,
577
+ default=1e-08,
578
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
579
+ )
580
+
581
+ parser.add_argument(
582
+ "--prodigy_use_bias_correction",
583
+ type=bool,
584
+ default=True,
585
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
586
+ )
587
+ parser.add_argument(
588
+ "--prodigy_safeguard_warmup",
589
+ type=bool,
590
+ default=True,
591
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
592
+ "Ignored if optimizer is adamW",
593
+ )
594
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
595
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
596
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
597
+ parser.add_argument(
598
+ "--hub_model_id",
599
+ type=str,
600
+ default=None,
601
+ help="The name of the repository to keep in sync with the local `output_dir`.",
602
+ )
603
+ parser.add_argument(
604
+ "--logging_dir",
605
+ type=str,
606
+ default="logs",
607
+ help=(
608
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
609
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
610
+ ),
611
+ )
612
+ parser.add_argument(
613
+ "--allow_tf32",
614
+ action="store_true",
615
+ help=(
616
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
617
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
618
+ ),
619
+ )
620
+ parser.add_argument(
621
+ "--report_to",
622
+ type=str,
623
+ default="tensorboard",
624
+ help=(
625
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
626
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
627
+ ),
628
+ )
629
+ parser.add_argument(
630
+ "--mixed_precision",
631
+ type=str,
632
+ default=None,
633
+ choices=["no", "fp16", "bf16"],
634
+ help=(
635
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
636
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
637
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
638
+ ),
639
+ )
640
+ parser.add_argument(
641
+ "--prior_generation_precision",
642
+ type=str,
643
+ default=None,
644
+ choices=["no", "fp32", "fp16", "bf16"],
645
+ help=(
646
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
647
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
648
+ ),
649
+ )
650
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
651
+ parser.add_argument(
652
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
653
+ )
654
+ parser.add_argument(
655
+ "--rank",
656
+ type=int,
657
+ default=4,
658
+ help=("The dimension of the LoRA update matrices."),
659
+ )
660
+ parser.add_argument(
661
+ "--use_dora",
662
+ action="store_true",
663
+ default=False,
664
+ help=(
665
+ "Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
666
+ "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
667
+ ),
668
+ )
669
+
670
+ if input_args is not None:
671
+ args = parser.parse_args(input_args)
672
+ else:
673
+ args = parser.parse_args()
674
+
675
+ if args.dataset_name is None and args.instance_data_dir is None:
676
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
677
+
678
+ if args.dataset_name is not None and args.instance_data_dir is not None:
679
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
680
+
681
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
682
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
683
+ args.local_rank = env_local_rank
684
+
685
+ if args.with_prior_preservation:
686
+ if args.class_data_dir is None:
687
+ raise ValueError("You must specify a data directory for class images.")
688
+ if args.class_prompt is None:
689
+ raise ValueError("You must specify prompt for class images.")
690
+ else:
691
+ # logger is not available yet
692
+ if args.class_data_dir is not None:
693
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
694
+ if args.class_prompt is not None:
695
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
696
+
697
+ return args
698
+
699
+
700
+ class DreamBoothDataset(Dataset):
701
+ """
702
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
703
+ It pre-processes the images.
704
+ """
705
+
706
+ def __init__(
707
+ self,
708
+ instance_data_root,
709
+ instance_prompt,
710
+ class_prompt,
711
+ class_data_root=None,
712
+ class_num=None,
713
+ size=1024,
714
+ repeats=1,
715
+ center_crop=False,
716
+ ):
717
+ self.size = size
718
+ self.center_crop = center_crop
719
+
720
+ self.instance_prompt = instance_prompt
721
+ self.custom_instance_prompts = None
722
+ self.class_prompt = class_prompt
723
+
724
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
725
+ # we load the training data using load_dataset
726
+ if args.dataset_name is not None:
727
+ try:
728
+ from datasets import load_dataset
729
+ except ImportError:
730
+ raise ImportError(
731
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
732
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
733
+ "local folder containing images only, specify --instance_data_dir instead."
734
+ )
735
+ # Downloading and loading a dataset from the hub.
736
+ # See more about loading custom images at
737
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
738
+ dataset = load_dataset(
739
+ args.dataset_name,
740
+ args.dataset_config_name,
741
+ cache_dir=args.cache_dir,
742
+ )
743
+ # Preprocessing the datasets.
744
+ column_names = dataset["train"].column_names
745
+
746
+ # 6. Get the column names for input/target.
747
+ if args.image_column is None:
748
+ image_column = column_names[0]
749
+ logger.info(f"image column defaulting to {image_column}")
750
+ else:
751
+ image_column = args.image_column
752
+ if image_column not in column_names:
753
+ raise ValueError(
754
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
755
+ )
756
+ instance_images = dataset["train"][image_column]
757
+
758
+ if args.caption_column is None:
759
+ logger.info(
760
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
761
+ "contains captions/prompts for the images, make sure to specify the "
762
+ "column as --caption_column"
763
+ )
764
+ self.custom_instance_prompts = None
765
+ else:
766
+ if args.caption_column not in column_names:
767
+ raise ValueError(
768
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
769
+ )
770
+ custom_instance_prompts = dataset["train"][args.caption_column]
771
+ # create final list of captions according to --repeats
772
+ self.custom_instance_prompts = []
773
+ for caption in custom_instance_prompts:
774
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
775
+ else:
776
+ self.instance_data_root = Path(instance_data_root)
777
+ if not self.instance_data_root.exists():
778
+ raise ValueError("Instance images root doesn't exists.")
779
+
780
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
781
+ self.custom_instance_prompts = None
782
+
783
+ self.instance_images = []
784
+ for img in instance_images:
785
+ self.instance_images.extend(itertools.repeat(img, repeats))
786
+
787
+ # image processing to prepare for using SD-XL micro-conditioning
788
+ self.original_sizes = []
789
+ self.crop_top_lefts = []
790
+ self.pixel_values = []
791
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
792
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
793
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
794
+ train_transforms = transforms.Compose(
795
+ [
796
+ transforms.ToTensor(),
797
+ transforms.Normalize([0.5], [0.5]),
798
+ ]
799
+ )
800
+ for image in self.instance_images:
801
+ image = exif_transpose(image)
802
+ if not image.mode == "RGB":
803
+ image = image.convert("RGB")
804
+ self.original_sizes.append((image.height, image.width))
805
+ image = train_resize(image)
806
+ if args.random_flip and random.random() < 0.5:
807
+ # flip
808
+ image = train_flip(image)
809
+ if args.center_crop:
810
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
811
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
812
+ image = train_crop(image)
813
+ else:
814
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
815
+ image = crop(image, y1, x1, h, w)
816
+ crop_top_left = (y1, x1)
817
+ self.crop_top_lefts.append(crop_top_left)
818
+ image = train_transforms(image)
819
+ self.pixel_values.append(image)
820
+
821
+ self.num_instance_images = len(self.instance_images)
822
+ self._length = self.num_instance_images
823
+
824
+ if class_data_root is not None:
825
+ self.class_data_root = Path(class_data_root)
826
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
827
+ self.class_images_path = list(self.class_data_root.iterdir())
828
+ if class_num is not None:
829
+ self.num_class_images = min(len(self.class_images_path), class_num)
830
+ else:
831
+ self.num_class_images = len(self.class_images_path)
832
+ self._length = max(self.num_class_images, self.num_instance_images)
833
+ else:
834
+ self.class_data_root = None
835
+
836
+ self.image_transforms = transforms.Compose(
837
+ [
838
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
839
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
840
+ transforms.ToTensor(),
841
+ transforms.Normalize([0.5], [0.5]),
842
+ ]
843
+ )
844
+
845
+ def __len__(self):
846
+ return self._length
847
+
848
+ def __getitem__(self, index):
849
+ example = {}
850
+ instance_image = self.pixel_values[index % self.num_instance_images]
851
+ original_size = self.original_sizes[index % self.num_instance_images]
852
+ crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
853
+ example["instance_images"] = instance_image
854
+ example["original_size"] = original_size
855
+ example["crop_top_left"] = crop_top_left
856
+
857
+ if self.custom_instance_prompts:
858
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
859
+ if caption:
860
+ example["instance_prompt"] = caption
861
+ else:
862
+ example["instance_prompt"] = self.instance_prompt
863
+
864
+ else: # costum prompts were provided, but length does not match size of image dataset
865
+ example["instance_prompt"] = self.instance_prompt
866
+
867
+ if self.class_data_root:
868
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
869
+ class_image = exif_transpose(class_image)
870
+
871
+ if not class_image.mode == "RGB":
872
+ class_image = class_image.convert("RGB")
873
+ example["class_images"] = self.image_transforms(class_image)
874
+ example["class_prompt"] = self.class_prompt
875
+
876
+ return example
877
+
878
+
879
+ def collate_fn(examples, with_prior_preservation=False):
880
+ pixel_values = [example["instance_images"] for example in examples]
881
+ prompts = [example["instance_prompt"] for example in examples]
882
+ original_sizes = [example["original_size"] for example in examples]
883
+ crop_top_lefts = [example["crop_top_left"] for example in examples]
884
+
885
+ # Concat class and instance examples for prior preservation.
886
+ # We do this to avoid doing two forward passes.
887
+ if with_prior_preservation:
888
+ pixel_values += [example["class_images"] for example in examples]
889
+ prompts += [example["class_prompt"] for example in examples]
890
+ original_sizes += [example["original_size"] for example in examples]
891
+ crop_top_lefts += [example["crop_top_left"] for example in examples]
892
+
893
+ pixel_values = torch.stack(pixel_values)
894
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
895
+
896
+ batch = {
897
+ "pixel_values": pixel_values,
898
+ "prompts": prompts,
899
+ "original_sizes": original_sizes,
900
+ "crop_top_lefts": crop_top_lefts,
901
+ }
902
+ return batch
903
+
904
+
905
+ class PromptDataset(Dataset):
906
+ """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
907
+
908
+ def __init__(self, prompt, num_samples):
909
+ self.prompt = prompt
910
+ self.num_samples = num_samples
911
+
912
+ def __len__(self):
913
+ return self.num_samples
914
+
915
+ def __getitem__(self, index):
916
+ example = {}
917
+ example["prompt"] = self.prompt
918
+ example["index"] = index
919
+ return example
920
+
921
+
922
+ def tokenize_prompt(tokenizer, prompt):
923
+ text_inputs = tokenizer(
924
+ prompt,
925
+ padding="max_length",
926
+ max_length=tokenizer.model_max_length,
927
+ truncation=True,
928
+ return_tensors="pt",
929
+ )
930
+ text_input_ids = text_inputs.input_ids
931
+ return text_input_ids
932
+
933
+
934
+ # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
935
+ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
936
+ prompt_embeds_list = []
937
+
938
+ for i, text_encoder in enumerate(text_encoders):
939
+ if tokenizers is not None:
940
+ tokenizer = tokenizers[i]
941
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
942
+ else:
943
+ assert text_input_ids_list is not None
944
+ text_input_ids = text_input_ids_list[i]
945
+
946
+ prompt_embeds = text_encoder(
947
+ text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
948
+ )
949
+
950
+ # We are only ALWAYS interested in the pooled output of the final text encoder
951
+ pooled_prompt_embeds = prompt_embeds[0]
952
+ prompt_embeds = prompt_embeds[-1][-2]
953
+ bs_embed, seq_len, _ = prompt_embeds.shape
954
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
955
+ prompt_embeds_list.append(prompt_embeds)
956
+
957
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
958
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
959
+ return prompt_embeds, pooled_prompt_embeds
960
+
961
+
962
+ def main(args):
963
+ if args.report_to == "wandb" and args.hub_token is not None:
964
+ raise ValueError(
965
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
966
+ " Please use `huggingface-cli login` to authenticate with the Hub."
967
+ )
968
+
969
+ if args.do_edm_style_training and args.snr_gamma is not None:
970
+ raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
971
+
972
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
973
+ # due to pytorch#99272, MPS does not yet support bfloat16.
974
+ raise ValueError(
975
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
976
+ )
977
+
978
+ logging_dir = Path(args.output_dir, args.logging_dir)
979
+
980
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
981
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
982
+ accelerator = Accelerator(
983
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
984
+ mixed_precision=args.mixed_precision,
985
+ log_with=args.report_to,
986
+ project_config=accelerator_project_config,
987
+ kwargs_handlers=[kwargs],
988
+ )
989
+
990
+ # Disable AMP for MPS.
991
+ if torch.backends.mps.is_available():
992
+ accelerator.native_amp = False
993
+
994
+ if args.report_to == "wandb":
995
+ if not is_wandb_available():
996
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
997
+
998
+ # Make one log on every process with the configuration for debugging.
999
+ logging.basicConfig(
1000
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
1001
+ datefmt="%m/%d/%Y %H:%M:%S",
1002
+ level=logging.INFO,
1003
+ )
1004
+ logger.info(accelerator.state, main_process_only=False)
1005
+ if accelerator.is_local_main_process:
1006
+ transformers.utils.logging.set_verbosity_warning()
1007
+ diffusers.utils.logging.set_verbosity_info()
1008
+ else:
1009
+ transformers.utils.logging.set_verbosity_error()
1010
+ diffusers.utils.logging.set_verbosity_error()
1011
+
1012
+ # If passed along, set the training seed now.
1013
+ if args.seed is not None:
1014
+ set_seed(args.seed)
1015
+
1016
+ # Generate class images if prior preservation is enabled.
1017
+ if args.with_prior_preservation:
1018
+ class_images_dir = Path(args.class_data_dir)
1019
+ if not class_images_dir.exists():
1020
+ class_images_dir.mkdir(parents=True)
1021
+ cur_class_images = len(list(class_images_dir.iterdir()))
1022
+
1023
+ if cur_class_images < args.num_class_images:
1024
+ has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
1025
+ torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
1026
+ if args.prior_generation_precision == "fp32":
1027
+ torch_dtype = torch.float32
1028
+ elif args.prior_generation_precision == "fp16":
1029
+ torch_dtype = torch.float16
1030
+ elif args.prior_generation_precision == "bf16":
1031
+ torch_dtype = torch.bfloat16
1032
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
1033
+ args.pretrained_model_name_or_path,
1034
+ torch_dtype=torch_dtype,
1035
+ revision=args.revision,
1036
+ variant=args.variant,
1037
+ )
1038
+ pipeline.set_progress_bar_config(disable=True)
1039
+
1040
+ num_new_images = args.num_class_images - cur_class_images
1041
+ logger.info(f"Number of class images to sample: {num_new_images}.")
1042
+
1043
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
1044
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
1045
+
1046
+ sample_dataloader = accelerator.prepare(sample_dataloader)
1047
+ pipeline.to(accelerator.device)
1048
+
1049
+ for example in tqdm(
1050
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
1051
+ ):
1052
+ images = pipeline(example["prompt"]).images
1053
+
1054
+ for i, image in enumerate(images):
1055
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
1056
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
1057
+ image.save(image_filename)
1058
+
1059
+ del pipeline
1060
+ if torch.cuda.is_available():
1061
+ torch.cuda.empty_cache()
1062
+
1063
+ # Handle the repository creation
1064
+ if accelerator.is_main_process:
1065
+ if args.output_dir is not None:
1066
+ os.makedirs(args.output_dir, exist_ok=True)
1067
+
1068
+ if args.push_to_hub:
1069
+ repo_id = create_repo(
1070
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
1071
+ ).repo_id
1072
+
1073
+ # Load the tokenizers
1074
+ tokenizer_one = AutoTokenizer.from_pretrained(
1075
+ args.pretrained_model_name_or_path,
1076
+ subfolder="tokenizer",
1077
+ revision=args.revision,
1078
+ use_fast=False,
1079
+ )
1080
+ tokenizer_two = AutoTokenizer.from_pretrained(
1081
+ args.pretrained_model_name_or_path,
1082
+ subfolder="tokenizer_2",
1083
+ revision=args.revision,
1084
+ use_fast=False,
1085
+ )
1086
+
1087
+ # import correct text encoder classes
1088
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
1089
+ args.pretrained_model_name_or_path, args.revision
1090
+ )
1091
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
1092
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
1093
+ )
1094
+
1095
+ # Load scheduler and models
1096
+ scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)
1097
+ if "EDM" in scheduler_type:
1098
+ args.do_edm_style_training = True
1099
+ noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
1100
+ logger.info("Performing EDM-style training!")
1101
+ elif args.do_edm_style_training:
1102
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(
1103
+ args.pretrained_model_name_or_path, subfolder="scheduler"
1104
+ )
1105
+ logger.info("Performing EDM-style training!")
1106
+ else:
1107
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
1108
+
1109
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
1110
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
1111
+ )
1112
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
1113
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
1114
+ )
1115
+ vae_path = (
1116
+ args.pretrained_model_name_or_path
1117
+ if args.pretrained_vae_model_name_or_path is None
1118
+ else args.pretrained_vae_model_name_or_path
1119
+ )
1120
+ vae = AutoencoderKL.from_pretrained(
1121
+ vae_path,
1122
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
1123
+ revision=args.revision,
1124
+ variant=args.variant,
1125
+ )
1126
+ latents_mean = latents_std = None
1127
+ if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
1128
+ latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
1129
+ if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
1130
+ latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
1131
+
1132
+ unet = UNet2DConditionModel.from_pretrained(
1133
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
1134
+ )
1135
+
1136
+ # We only train the additional adapter LoRA layers
1137
+ vae.requires_grad_(False)
1138
+ text_encoder_one.requires_grad_(False)
1139
+ text_encoder_two.requires_grad_(False)
1140
+ unet.requires_grad_(False)
1141
+
1142
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
1143
+ # as these weights are only used for inference, keeping weights in full precision is not required.
1144
+ weight_dtype = torch.float32
1145
+ if accelerator.mixed_precision == "fp16":
1146
+ weight_dtype = torch.float16
1147
+ elif accelerator.mixed_precision == "bf16":
1148
+ weight_dtype = torch.bfloat16
1149
+
1150
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
1151
+ # due to pytorch#99272, MPS does not yet support bfloat16.
1152
+ raise ValueError(
1153
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
1154
+ )
1155
+
1156
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
1157
+ unet.to(accelerator.device, dtype=weight_dtype)
1158
+
1159
+ # The VAE is always in float32 to avoid NaN losses.
1160
+ vae.to(accelerator.device, dtype=torch.float32)
1161
+
1162
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
1163
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
1164
+
1165
+ if args.enable_xformers_memory_efficient_attention:
1166
+ if is_xformers_available():
1167
+ import xformers
1168
+
1169
+ xformers_version = version.parse(xformers.__version__)
1170
+ if xformers_version == version.parse("0.0.16"):
1171
+ logger.warning(
1172
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
1173
+ "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
1174
+ )
1175
+ unet.enable_xformers_memory_efficient_attention()
1176
+ else:
1177
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
1178
+
1179
+ if args.gradient_checkpointing:
1180
+ unet.enable_gradient_checkpointing()
1181
+ if args.train_text_encoder:
1182
+ text_encoder_one.gradient_checkpointing_enable()
1183
+ text_encoder_two.gradient_checkpointing_enable()
1184
+
1185
+ # now we will add new LoRA weights to the attention layers
1186
+ unet_lora_config = LoraConfig(
1187
+ r=args.rank,
1188
+ use_dora=args.use_dora,
1189
+ lora_alpha=args.rank,
1190
+ init_lora_weights="gaussian",
1191
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
1192
+ )
1193
+ unet.add_adapter(unet_lora_config)
1194
+
1195
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
1196
+ # So, instead, we monkey-patch the forward calls of its attention-blocks.
1197
+ if args.train_text_encoder:
1198
+ text_lora_config = LoraConfig(
1199
+ r=args.rank,
1200
+ use_dora=args.use_dora,
1201
+ lora_alpha=args.rank,
1202
+ init_lora_weights="gaussian",
1203
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
1204
+ )
1205
+ text_encoder_one.add_adapter(text_lora_config)
1206
+ text_encoder_two.add_adapter(text_lora_config)
1207
+
1208
+ def unwrap_model(model):
1209
+ model = accelerator.unwrap_model(model)
1210
+ model = model._orig_mod if is_compiled_module(model) else model
1211
+ return model
1212
+
1213
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
1214
+ def save_model_hook(models, weights, output_dir):
1215
+ if accelerator.is_main_process:
1216
+ # there are only two options here. Either are just the unet attn processor layers
1217
+ # or there are the unet and text encoder atten layers
1218
+ unet_lora_layers_to_save = None
1219
+ text_encoder_one_lora_layers_to_save = None
1220
+ text_encoder_two_lora_layers_to_save = None
1221
+
1222
+ for model in models:
1223
+ if isinstance(model, type(unwrap_model(unet))):
1224
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
1225
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
1226
+ text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
1227
+ get_peft_model_state_dict(model)
1228
+ )
1229
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
1230
+ text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
1231
+ get_peft_model_state_dict(model)
1232
+ )
1233
+ else:
1234
+ raise ValueError(f"unexpected save model: {model.__class__}")
1235
+
1236
+ # make sure to pop weight so that corresponding model is not saved again
1237
+ weights.pop()
1238
+
1239
+ StableDiffusionXLPipeline.save_lora_weights(
1240
+ output_dir,
1241
+ unet_lora_layers=unet_lora_layers_to_save,
1242
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
1243
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
1244
+ )
1245
+
1246
+ def load_model_hook(models, input_dir):
1247
+ unet_ = None
1248
+ text_encoder_one_ = None
1249
+ text_encoder_two_ = None
1250
+
1251
+ while len(models) > 0:
1252
+ model = models.pop()
1253
+
1254
+ if isinstance(model, type(unwrap_model(unet))):
1255
+ unet_ = model
1256
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
1257
+ text_encoder_one_ = model
1258
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
1259
+ text_encoder_two_ = model
1260
+ else:
1261
+ raise ValueError(f"unexpected save model: {model.__class__}")
1262
+
1263
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
1264
+
1265
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
1266
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
1267
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
1268
+ if incompatible_keys is not None:
1269
+ # check only for unexpected keys
1270
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
1271
+ if unexpected_keys:
1272
+ logger.warning(
1273
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
1274
+ f" {unexpected_keys}. "
1275
+ )
1276
+
1277
+ if args.train_text_encoder:
1278
+ # Do we need to call `scale_lora_layers()` here?
1279
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
1280
+
1281
+ _set_state_dict_into_text_encoder(
1282
+ lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
1283
+ )
1284
+
1285
+ # Make sure the trainable params are in float32. This is again needed since the base models
1286
+ # are in `weight_dtype`. More details:
1287
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
1288
+ if args.mixed_precision == "fp16":
1289
+ models = [unet_]
1290
+ if args.train_text_encoder:
1291
+ models.extend([text_encoder_one_, text_encoder_two_])
1292
+ # only upcast trainable parameters (LoRA) into fp32
1293
+ cast_training_params(models)
1294
+
1295
+ accelerator.register_save_state_pre_hook(save_model_hook)
1296
+ accelerator.register_load_state_pre_hook(load_model_hook)
1297
+
1298
+ # Enable TF32 for faster training on Ampere GPUs,
1299
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1300
+ if args.allow_tf32 and torch.cuda.is_available():
1301
+ torch.backends.cuda.matmul.allow_tf32 = True
1302
+
1303
+ if args.scale_lr:
1304
+ args.learning_rate = (
1305
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
1306
+ )
1307
+
1308
+ # Make sure the trainable params are in float32.
1309
+ if args.mixed_precision == "fp16":
1310
+ models = [unet]
1311
+ if args.train_text_encoder:
1312
+ models.extend([text_encoder_one, text_encoder_two])
1313
+
1314
+ # only upcast trainable parameters (LoRA) into fp32
1315
+ cast_training_params(models, dtype=torch.float32)
1316
+
1317
+ unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
1318
+
1319
+ if args.train_text_encoder:
1320
+ text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
1321
+ text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
1322
+
1323
+ # Optimization parameters
1324
+ unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
1325
+ if args.train_text_encoder:
1326
+ # different learning rate for text encoder and unet
1327
+ text_lora_parameters_one_with_lr = {
1328
+ "params": text_lora_parameters_one,
1329
+ "weight_decay": args.adam_weight_decay_text_encoder,
1330
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
1331
+ }
1332
+ text_lora_parameters_two_with_lr = {
1333
+ "params": text_lora_parameters_two,
1334
+ "weight_decay": args.adam_weight_decay_text_encoder,
1335
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
1336
+ }
1337
+ params_to_optimize = [
1338
+ unet_lora_parameters_with_lr,
1339
+ text_lora_parameters_one_with_lr,
1340
+ text_lora_parameters_two_with_lr,
1341
+ ]
1342
+ else:
1343
+ params_to_optimize = [unet_lora_parameters_with_lr]
1344
+
1345
+ # Optimizer creation
1346
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
1347
+ logger.warning(
1348
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
1349
+ "Defaulting to adamW"
1350
+ )
1351
+ args.optimizer = "adamw"
1352
+
1353
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
1354
+ logger.warning(
1355
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
1356
+ f"set to {args.optimizer.lower()}"
1357
+ )
1358
+
1359
+ if args.optimizer.lower() == "adamw":
1360
+ if args.use_8bit_adam:
1361
+ try:
1362
+ import bitsandbytes as bnb
1363
+ except ImportError:
1364
+ raise ImportError(
1365
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
1366
+ )
1367
+
1368
+ optimizer_class = bnb.optim.AdamW8bit
1369
+ else:
1370
+ optimizer_class = torch.optim.AdamW
1371
+
1372
+ optimizer = optimizer_class(
1373
+ params_to_optimize,
1374
+ betas=(args.adam_beta1, args.adam_beta2),
1375
+ weight_decay=args.adam_weight_decay,
1376
+ eps=args.adam_epsilon,
1377
+ )
1378
+
1379
+ if args.optimizer.lower() == "prodigy":
1380
+ try:
1381
+ import prodigyopt
1382
+ except ImportError:
1383
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
1384
+
1385
+ optimizer_class = prodigyopt.Prodigy
1386
+
1387
+ if args.learning_rate <= 0.1:
1388
+ logger.warning(
1389
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
1390
+ )
1391
+ if args.train_text_encoder and args.text_encoder_lr:
1392
+ logger.warning(
1393
+ f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
1394
+ f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
1395
+ f"When using prodigy only learning_rate is used as the initial learning rate."
1396
+ )
1397
+ # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
1398
+ # --learning_rate
1399
+ params_to_optimize[1]["lr"] = args.learning_rate
1400
+ params_to_optimize[2]["lr"] = args.learning_rate
1401
+
1402
+ optimizer = optimizer_class(
1403
+ params_to_optimize,
1404
+ lr=args.learning_rate,
1405
+ betas=(args.adam_beta1, args.adam_beta2),
1406
+ beta3=args.prodigy_beta3,
1407
+ weight_decay=args.adam_weight_decay,
1408
+ eps=args.adam_epsilon,
1409
+ decouple=args.prodigy_decouple,
1410
+ use_bias_correction=args.prodigy_use_bias_correction,
1411
+ safeguard_warmup=args.prodigy_safeguard_warmup,
1412
+ )
1413
+
1414
+ # Dataset and DataLoaders creation:
1415
+ train_dataset = DreamBoothDataset(
1416
+ instance_data_root=args.instance_data_dir,
1417
+ instance_prompt=args.instance_prompt,
1418
+ class_prompt=args.class_prompt,
1419
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
1420
+ class_num=args.num_class_images,
1421
+ size=args.resolution,
1422
+ repeats=args.repeats,
1423
+ center_crop=args.center_crop,
1424
+ )
1425
+
1426
+ train_dataloader = torch.utils.data.DataLoader(
1427
+ train_dataset,
1428
+ batch_size=args.train_batch_size,
1429
+ shuffle=True,
1430
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
1431
+ num_workers=args.dataloader_num_workers,
1432
+ )
1433
+
1434
+ # Computes additional embeddings/ids required by the SDXL UNet.
1435
+ # regular text embeddings (when `train_text_encoder` is not True)
1436
+ # pooled text embeddings
1437
+ # time ids
1438
+
1439
+ def compute_time_ids(original_size, crops_coords_top_left):
1440
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
1441
+ target_size = (args.resolution, args.resolution)
1442
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
1443
+ add_time_ids = torch.tensor([add_time_ids])
1444
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
1445
+ return add_time_ids
1446
+
1447
+ if not args.train_text_encoder:
1448
+ tokenizers = [tokenizer_one, tokenizer_two]
1449
+ text_encoders = [text_encoder_one, text_encoder_two]
1450
+
1451
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
1452
+ with torch.no_grad():
1453
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
1454
+ prompt_embeds = prompt_embeds.to(accelerator.device)
1455
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
1456
+ return prompt_embeds, pooled_prompt_embeds
1457
+
1458
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
1459
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
1460
+ # the redundant encoding.
1461
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
1462
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
1463
+ args.instance_prompt, text_encoders, tokenizers
1464
+ )
1465
+
1466
+ # Handle class prompt for prior-preservation.
1467
+ if args.with_prior_preservation:
1468
+ if not args.train_text_encoder:
1469
+ class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
1470
+ args.class_prompt, text_encoders, tokenizers
1471
+ )
1472
+
1473
+ # Clear the memory here
1474
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
1475
+ del tokenizers, text_encoders
1476
+ gc.collect()
1477
+ if torch.cuda.is_available():
1478
+ torch.cuda.empty_cache()
1479
+
1480
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
1481
+ # pack the statically computed variables appropriately here. This is so that we don't
1482
+ # have to pass them to the dataloader.
1483
+
1484
+ if not train_dataset.custom_instance_prompts:
1485
+ if not args.train_text_encoder:
1486
+ prompt_embeds = instance_prompt_hidden_states
1487
+ unet_add_text_embeds = instance_pooled_prompt_embeds
1488
+ if args.with_prior_preservation:
1489
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
1490
+ unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
1491
+ # if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
1492
+ # batch prompts on all training steps
1493
+ else:
1494
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
1495
+ tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
1496
+ if args.with_prior_preservation:
1497
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
1498
+ class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
1499
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
1500
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
1501
+
1502
+ # Scheduler and math around the number of training steps.
1503
+ overrode_max_train_steps = False
1504
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1505
+ if args.max_train_steps is None:
1506
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1507
+ overrode_max_train_steps = True
1508
+
1509
+ lr_scheduler = get_scheduler(
1510
+ args.lr_scheduler,
1511
+ optimizer=optimizer,
1512
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1513
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
1514
+ num_cycles=args.lr_num_cycles,
1515
+ power=args.lr_power,
1516
+ )
1517
+
1518
+ # Prepare everything with our `accelerator`.
1519
+ if args.train_text_encoder:
1520
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1521
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
1522
+ )
1523
+ else:
1524
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1525
+ unet, optimizer, train_dataloader, lr_scheduler
1526
+ )
1527
+
1528
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1529
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1530
+ if overrode_max_train_steps:
1531
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1532
+ # Afterwards we recalculate our number of training epochs
1533
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1534
+
1535
+ # We need to initialize the trackers we use, and also store our configuration.
1536
+ # The trackers initializes automatically on the main process.
1537
+ if accelerator.is_main_process:
1538
+ tracker_name = (
1539
+ "dreambooth-lora-sd-xl"
1540
+ if "playground" not in args.pretrained_model_name_or_path
1541
+ else "dreambooth-lora-playground"
1542
+ )
1543
+ accelerator.init_trackers(tracker_name, config=vars(args))
1544
+
1545
+ # Train!
1546
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1547
+
1548
+ logger.info("***** Running training *****")
1549
+ logger.info(f" Num examples = {len(train_dataset)}")
1550
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1551
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1552
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1553
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1554
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1555
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1556
+ global_step = 0
1557
+ first_epoch = 0
1558
+
1559
+ # Potentially load in the weights and states from a previous save
1560
+ if args.resume_from_checkpoint:
1561
+ if args.resume_from_checkpoint != "latest":
1562
+ path = os.path.basename(args.resume_from_checkpoint)
1563
+ else:
1564
+ # Get the mos recent checkpoint
1565
+ dirs = os.listdir(args.output_dir)
1566
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1567
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1568
+ path = dirs[-1] if len(dirs) > 0 else None
1569
+
1570
+ if path is None:
1571
+ accelerator.print(
1572
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1573
+ )
1574
+ args.resume_from_checkpoint = None
1575
+ initial_global_step = 0
1576
+ else:
1577
+ accelerator.print(f"Resuming from checkpoint {path}")
1578
+ accelerator.load_state(os.path.join(args.output_dir, path))
1579
+ global_step = int(path.split("-")[1])
1580
+
1581
+ initial_global_step = global_step
1582
+ first_epoch = global_step // num_update_steps_per_epoch
1583
+
1584
+ else:
1585
+ initial_global_step = 0
1586
+
1587
+ progress_bar = tqdm(
1588
+ range(0, args.max_train_steps),
1589
+ initial=initial_global_step,
1590
+ desc="Steps",
1591
+ # Only show the progress bar once on each machine.
1592
+ disable=not accelerator.is_local_main_process,
1593
+ )
1594
+
1595
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1596
+ sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
1597
+ schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
1598
+ timesteps = timesteps.to(accelerator.device)
1599
+
1600
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
1601
+
1602
+ sigma = sigmas[step_indices].flatten()
1603
+ while len(sigma.shape) < n_dim:
1604
+ sigma = sigma.unsqueeze(-1)
1605
+ return sigma
1606
+
1607
+ for epoch in range(first_epoch, args.num_train_epochs):
1608
+ unet.train()
1609
+ if args.train_text_encoder:
1610
+ text_encoder_one.train()
1611
+ text_encoder_two.train()
1612
+
1613
+ # set top parameter requires_grad = True for gradient checkpointing works
1614
+ accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
1615
+ accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
1616
+
1617
+ for step, batch in enumerate(train_dataloader):
1618
+ with accelerator.accumulate(unet):
1619
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1620
+ prompts = batch["prompts"]
1621
+
1622
+ # encode batch prompts when custom prompts are provided for each image -
1623
+ if train_dataset.custom_instance_prompts:
1624
+ if not args.train_text_encoder:
1625
+ prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
1626
+ prompts, text_encoders, tokenizers
1627
+ )
1628
+ else:
1629
+ tokens_one = tokenize_prompt(tokenizer_one, prompts)
1630
+ tokens_two = tokenize_prompt(tokenizer_two, prompts)
1631
+
1632
+ # Convert images to latent space
1633
+ model_input = vae.encode(pixel_values).latent_dist.sample()
1634
+
1635
+ if latents_mean is None and latents_std is None:
1636
+ model_input = model_input * vae.config.scaling_factor
1637
+ if args.pretrained_vae_model_name_or_path is None:
1638
+ model_input = model_input.to(weight_dtype)
1639
+ else:
1640
+ latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
1641
+ latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
1642
+ model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
1643
+ model_input = model_input.to(dtype=weight_dtype)
1644
+
1645
+ # Sample noise that we'll add to the latents
1646
+ noise = torch.randn_like(model_input)
1647
+ bsz = model_input.shape[0]
1648
+
1649
+ # Sample a random timestep for each image
1650
+ if not args.do_edm_style_training:
1651
+ timesteps = torch.randint(
1652
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
1653
+ )
1654
+ timesteps = timesteps.long()
1655
+ else:
1656
+ # in EDM formulation, the model is conditioned on the pre-conditioned noise levels
1657
+ # instead of discrete timesteps, so here we sample indices to get the noise levels
1658
+ # from `scheduler.timesteps`
1659
+ indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
1660
+ timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
1661
+
1662
+ # Add noise to the model input according to the noise magnitude at each timestep
1663
+ # (this is the forward diffusion process)
1664
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
1665
+ # For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
1666
+ # We then precondition the final model inputs based on these sigmas instead of the timesteps.
1667
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
1668
+ if args.do_edm_style_training:
1669
+ sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
1670
+ if "EDM" in scheduler_type:
1671
+ inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)
1672
+ else:
1673
+ inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)
1674
+
1675
+ # time ids
1676
+ add_time_ids = torch.cat(
1677
+ [
1678
+ compute_time_ids(original_size=s, crops_coords_top_left=c)
1679
+ for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
1680
+ ]
1681
+ )
1682
+
1683
+ # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
1684
+ if not train_dataset.custom_instance_prompts:
1685
+ elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
1686
+ else:
1687
+ elems_to_repeat_text_embeds = 1
1688
+
1689
+ # Predict the noise residual
1690
+ if not args.train_text_encoder:
1691
+ unet_added_conditions = {
1692
+ "time_ids": add_time_ids,
1693
+ "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
1694
+ }
1695
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
1696
+ model_pred = unet(
1697
+ inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
1698
+ timesteps,
1699
+ prompt_embeds_input,
1700
+ added_cond_kwargs=unet_added_conditions,
1701
+ return_dict=False,
1702
+ )[0]
1703
+ else:
1704
+ unet_added_conditions = {"time_ids": add_time_ids}
1705
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
1706
+ text_encoders=[text_encoder_one, text_encoder_two],
1707
+ tokenizers=None,
1708
+ prompt=None,
1709
+ text_input_ids_list=[tokens_one, tokens_two],
1710
+ )
1711
+ unet_added_conditions.update(
1712
+ {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
1713
+ )
1714
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
1715
+ model_pred = unet(
1716
+ inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
1717
+ timesteps,
1718
+ prompt_embeds_input,
1719
+ added_cond_kwargs=unet_added_conditions,
1720
+ return_dict=False,
1721
+ )[0]
1722
+
1723
+ weighting = None
1724
+ if args.do_edm_style_training:
1725
+ # Similar to the input preconditioning, the model predictions are also preconditioned
1726
+ # on noised model inputs (before preconditioning) and the sigmas.
1727
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
1728
+ if "EDM" in scheduler_type:
1729
+ model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
1730
+ else:
1731
+ if noise_scheduler.config.prediction_type == "epsilon":
1732
+ model_pred = model_pred * (-sigmas) + noisy_model_input
1733
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1734
+ model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (
1735
+ noisy_model_input / (sigmas**2 + 1)
1736
+ )
1737
+ # We are not doing weighting here because it tends result in numerical problems.
1738
+ # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
1739
+ # There might be other alternatives for weighting as well:
1740
+ # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
1741
+ if "EDM" not in scheduler_type:
1742
+ weighting = (sigmas**-2.0).float()
1743
+
1744
+ # Get the target for loss depending on the prediction type
1745
+ if noise_scheduler.config.prediction_type == "epsilon":
1746
+ target = model_input if args.do_edm_style_training else noise
1747
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1748
+ target = (
1749
+ model_input
1750
+ if args.do_edm_style_training
1751
+ else noise_scheduler.get_velocity(model_input, noise, timesteps)
1752
+ )
1753
+ else:
1754
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1755
+
1756
+ if args.with_prior_preservation:
1757
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
1758
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
1759
+ target, target_prior = torch.chunk(target, 2, dim=0)
1760
+
1761
+ # Compute prior loss
1762
+ if weighting is not None:
1763
+ prior_loss = torch.mean(
1764
+ (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
1765
+ target_prior.shape[0], -1
1766
+ ),
1767
+ 1,
1768
+ )
1769
+ prior_loss = prior_loss.mean()
1770
+ else:
1771
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
1772
+
1773
+ if args.snr_gamma is None:
1774
+ if weighting is not None:
1775
+ loss = torch.mean(
1776
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(
1777
+ target.shape[0], -1
1778
+ ),
1779
+ 1,
1780
+ )
1781
+ loss = loss.mean()
1782
+ else:
1783
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1784
+ else:
1785
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
1786
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
1787
+ # This is discussed in Section 4.2 of the same paper.
1788
+ snr = compute_snr(noise_scheduler, timesteps)
1789
+ base_weight = (
1790
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
1791
+ )
1792
+
1793
+ if noise_scheduler.config.prediction_type == "v_prediction":
1794
+ # Velocity objective needs to be floored to an SNR weight of one.
1795
+ mse_loss_weights = base_weight + 1
1796
+ else:
1797
+ # Epsilon and sample both use the same loss weights.
1798
+ mse_loss_weights = base_weight
1799
+
1800
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
1801
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
1802
+ loss = loss.mean()
1803
+
1804
+ if args.with_prior_preservation:
1805
+ # Add the prior loss to the instance loss.
1806
+ loss = loss + args.prior_loss_weight * prior_loss
1807
+
1808
+ accelerator.backward(loss)
1809
+ if accelerator.sync_gradients:
1810
+ params_to_clip = (
1811
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
1812
+ if args.train_text_encoder
1813
+ else unet_lora_parameters
1814
+ )
1815
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1816
+
1817
+ optimizer.step()
1818
+ lr_scheduler.step()
1819
+ optimizer.zero_grad()
1820
+
1821
+ # Checks if the accelerator has performed an optimization step behind the scenes
1822
+ if accelerator.sync_gradients:
1823
+ progress_bar.update(1)
1824
+ global_step += 1
1825
+
1826
+ if accelerator.is_main_process:
1827
+ if global_step % args.checkpointing_steps == 0:
1828
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1829
+ if args.checkpoints_total_limit is not None:
1830
+ checkpoints = os.listdir(args.output_dir)
1831
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1832
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1833
+
1834
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1835
+ if len(checkpoints) >= args.checkpoints_total_limit:
1836
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1837
+ removing_checkpoints = checkpoints[0:num_to_remove]
1838
+
1839
+ logger.info(
1840
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1841
+ )
1842
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1843
+
1844
+ for removing_checkpoint in removing_checkpoints:
1845
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1846
+ shutil.rmtree(removing_checkpoint)
1847
+
1848
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1849
+ accelerator.save_state(save_path)
1850
+ logger.info(f"Saved state to {save_path}")
1851
+
1852
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1853
+ progress_bar.set_postfix(**logs)
1854
+ accelerator.log(logs, step=global_step)
1855
+
1856
+ if global_step >= args.max_train_steps:
1857
+ break
1858
+
1859
+ if accelerator.is_main_process:
1860
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
1861
+ # create pipeline
1862
+ if not args.train_text_encoder:
1863
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
1864
+ args.pretrained_model_name_or_path,
1865
+ subfolder="text_encoder",
1866
+ revision=args.revision,
1867
+ variant=args.variant,
1868
+ )
1869
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
1870
+ args.pretrained_model_name_or_path,
1871
+ subfolder="text_encoder_2",
1872
+ revision=args.revision,
1873
+ variant=args.variant,
1874
+ )
1875
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
1876
+ args.pretrained_model_name_or_path,
1877
+ vae=vae,
1878
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
1879
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
1880
+ unet=accelerator.unwrap_model(unet),
1881
+ revision=args.revision,
1882
+ variant=args.variant,
1883
+ torch_dtype=weight_dtype,
1884
+ )
1885
+ pipeline_args = {"prompt": args.validation_prompt}
1886
+
1887
+ images = log_validation(
1888
+ pipeline,
1889
+ args,
1890
+ accelerator,
1891
+ pipeline_args,
1892
+ epoch,
1893
+ )
1894
+
1895
+ # Save the lora layers
1896
+ accelerator.wait_for_everyone()
1897
+ if accelerator.is_main_process:
1898
+ unet = unwrap_model(unet)
1899
+ unet = unet.to(torch.float32)
1900
+ unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
1901
+
1902
+ if args.train_text_encoder:
1903
+ text_encoder_one = unwrap_model(text_encoder_one)
1904
+ text_encoder_lora_layers = convert_state_dict_to_diffusers(
1905
+ get_peft_model_state_dict(text_encoder_one.to(torch.float32))
1906
+ )
1907
+ text_encoder_two = unwrap_model(text_encoder_two)
1908
+ text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
1909
+ get_peft_model_state_dict(text_encoder_two.to(torch.float32))
1910
+ )
1911
+ else:
1912
+ text_encoder_lora_layers = None
1913
+ text_encoder_2_lora_layers = None
1914
+
1915
+ StableDiffusionXLPipeline.save_lora_weights(
1916
+ save_directory=args.output_dir,
1917
+ unet_lora_layers=unet_lora_layers,
1918
+ text_encoder_lora_layers=text_encoder_lora_layers,
1919
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
1920
+ )
1921
+ if args.output_kohya_format:
1922
+ lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
1923
+ peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
1924
+ kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
1925
+ save_file(kohya_state_dict, f"{args.output_dir}/pytorch_lora_weights_kohya.safetensors")
1926
+
1927
+ # Final inference
1928
+ # Load previous pipeline
1929
+ vae = AutoencoderKL.from_pretrained(
1930
+ vae_path,
1931
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
1932
+ revision=args.revision,
1933
+ variant=args.variant,
1934
+ torch_dtype=weight_dtype,
1935
+ )
1936
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
1937
+ args.pretrained_model_name_or_path,
1938
+ vae=vae,
1939
+ revision=args.revision,
1940
+ variant=args.variant,
1941
+ torch_dtype=weight_dtype,
1942
+ )
1943
+
1944
+ # load attention processors
1945
+ pipeline.load_lora_weights(args.output_dir)
1946
+
1947
+ # run inference
1948
+ images = []
1949
+ if args.validation_prompt and args.num_validation_images > 0:
1950
+ pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
1951
+ images = log_validation(
1952
+ pipeline,
1953
+ args,
1954
+ accelerator,
1955
+ pipeline_args,
1956
+ epoch,
1957
+ is_final_validation=True,
1958
+ )
1959
+
1960
+ if args.push_to_hub:
1961
+ save_model_card(
1962
+ repo_id,
1963
+ use_dora=args.use_dora,
1964
+ images=images,
1965
+ base_model=args.pretrained_model_name_or_path,
1966
+ train_text_encoder=args.train_text_encoder,
1967
+ instance_prompt=args.instance_prompt,
1968
+ validation_prompt=args.validation_prompt,
1969
+ repo_folder=args.output_dir,
1970
+ vae_path=args.pretrained_vae_model_name_or_path,
1971
+ )
1972
+ upload_folder(
1973
+ repo_id=repo_id,
1974
+ folder_path=args.output_dir,
1975
+ commit_message="End of training",
1976
+ ignore_patterns=["step_*", "epoch_*"],
1977
+ )
1978
+
1979
+ accelerator.end_training()
1980
+
1981
+
1982
+ if __name__ == "__main__":
1983
+ args = parse_args()
1984
+ main(args)