WIP
- .gitignore +2 -0
- app.py +113 -11
- lora_diffusion/FOR-cloneofsimo-LoRA +6 -0
- lora_diffusion/__init__.py +5 -0
- lora_diffusion/cli_lora_add.py +187 -0
- lora_diffusion/cli_lora_pti.py +1040 -0
- lora_diffusion/cli_pt_to_safetensors.py +85 -0
- lora_diffusion/cli_svd.py +146 -0
- lora_diffusion/dataset.py +311 -0
- lora_diffusion/lora.py +1110 -0
- lora_diffusion/lora_manager.py +144 -0
- lora_diffusion/preprocess_files.py +327 -0
- lora_diffusion/safe_open.py +68 -0
- lora_diffusion/to_ckpt_v2.py +232 -0
- lora_diffusion/utils.py +214 -0
- lora_diffusion/xformers_utils.py +70 -0
- requirements.txt +3 -0
- train_dreambooth_cloneofsimo_lora.py +1008 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
*/__pycache__/
*/*.pyc
app.py
CHANGED
@@ -1,16 +1,118 @@
 import gradio as gr
+import shutil
+import zipfile
+import tensorflow as tf
 import pandas as pd
+import pathlib
+import PIL.Image
+import os
+import subprocess
 
-def
-
-
+def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
+    w, h = image.size
+    if w == h:
+        return image
+    elif w > h:
+        new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
+        new_image.paste(image, (0, (w - h) // 2))
+        return new_image
+    else:
+        new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
+        new_image.paste(image, ((h - w) // 2, 0))
+        return new_image
 
-iface = gr.Interface(
-    fn=load_csv,
-    inputs="file",
-    outputs="dataframe",
-    title="CSV Loader",
-    description="Load a CSV file and display its contents.",
-)
 
-
+class ModelTrainer:
+    def __init__(self):
+        self.training_pictures = []
+        self.training_model = None
+
+    def unzip_file(self, zip_file_path):
+        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
+            extracted_path = zip_file_path.replace('.zip', '')
+            zip_ref.extractall(extracted_path)
+            file_names = zip_ref.namelist()
+            for file_name in file_names:
+                if file_name.endswith(('.jpeg', '.jpg', '.png')):
+                    self.training_pictures.append(f'{extracted_path}/{file_name}')
+
+    def train(self, pretrained_model_name_or_path: str, instance_images: list | None):
+        output_model_name = 'a-xyz-model'
+        resolution = 512
+        repo_dir = pathlib.Path(__file__).parent
+        subdirs = ['train-instance', 'train-class', 'experiments']
+        dir_paths = []
+
+        for subdir in subdirs:
+            dir_path = repo_dir / subdir / output_model_name
+            dir_paths.append(dir_path)
+            shutil.rmtree(dir_path, ignore_errors=True)
+            os.makedirs(dir_path, exist_ok=True)
+
+        instance_data_dir, class_data_dir, output_dir = dir_paths
+
+        for i, temp_path in enumerate(instance_images):
+            image = PIL.Image.open(temp_path.name)
+            image = pad_image(image)
+            image = image.resize((resolution, resolution))
+            image = image.convert('RGB')
+            out_path = instance_data_dir / f'{i:03d}.jpg'
+            image.save(out_path, format='JPEG', quality=100)
+
+        command = [
+            'python', '-u',
+            'train_dreambooth_cloneofsimo_lora.py',
+            '--pretrained_model_name_or_path', pretrained_model_name_or_path,
+            '--instance_data_dir', instance_data_dir,
+            '--class_data_dir', class_data_dir,
+            '--resolution', '768',
+            '--output_dir', output_dir,
+            '--instance_prompt', 'a photo of a pwsm dog',
+            '--with_prior_preservation',
+            '--class_prompt', 'a dog',
+            '--prior_loss_weight', '1.0',
+            '--num_class_images', '100',
+            '--learning_rate', '0.0004',
+            '--train_batch_size', '1',
+            '--sample_batch_size', '1',
+            '--max_train_steps', '400',
+            '--gradient_accumulation_steps', '1',
+            '--gradient_checkpointing',
+            '--train_text_encoder',
+            '--learning_rate_text', '5e-6',
+            '--save_steps', '100',
+            '--seed', '1337',
+            '--lr_scheduler', 'constant',
+            '--lr_warmup_steps', '0'
+        ]
+
+        result = subprocess.run(command)
+        return result
+
+    def generate_picture(self, row):
+        num_of_training_steps, learning_rate, checkpoint_steps, abc = row
+        return f'Picture generated for num_of_training_steps: {num_of_training_steps}, learning_rate: {learning_rate}, checkpoint_steps: {checkpoint_steps}'
+
+    def generate_pictures(self, csv_input):
+        csv = pd.read_csv(csv_input.name)
+        result = []
+        for index, row in csv.iterrows():
+            result.append(self.generate_picture(row))
+        return "\n".join(str(item) for item in result)
+
+loader = ModelTrainer()
+
+with gr.Blocks() as demo:
+    with gr.Box():
+        instance_images = gr.Files(label='Instance images')
+        pretrained_model_name_or_path = gr.inputs.Textbox(lines=1, label='pretrained_model_name_or_path', default='stabilityai/stable-diffusion-2-1')
+        output_message = gr.Markdown()
+        train_button = gr.Button('Train')
+        train_button.click(fn=loader.train, inputs=[pretrained_model_name_or_path, instance_images], outputs=[output_message])
+    with gr.Box():
+        csv_input = gr.inputs.File(label='CSV File')
+        output_message2 = gr.Markdown()
+        generate_button = gr.Button('Generate Pictures from CSV')
+        generate_button.click(fn=loader.generate_pictures, inputs=[csv_input], outputs=[output_message2])
+
+demo.launch()
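
For reference, here is a minimal standalone sketch of the preprocessing the new train() method applies to each uploaded instance image (square-pad with black, resize to the training resolution, convert to RGB, save as JPEG). The input path 'example.png' and the 'out' directory are placeholders for illustration, not files from this Space.

import pathlib
import PIL.Image

def pad_to_square(image: PIL.Image.Image) -> PIL.Image.Image:
    # Same logic as pad_image() in app.py: letterbox the short side with black.
    w, h = image.size
    if w == h:
        return image
    side = max(w, h)
    canvas = PIL.Image.new(image.mode, (side, side), (0, 0, 0))
    canvas.paste(image, ((side - w) // 2, (side - h) // 2))
    return canvas

if __name__ == "__main__":
    resolution = 512  # app.py resizes instance images to 512x512 before training
    out_dir = pathlib.Path("out")  # placeholder output directory
    out_dir.mkdir(exist_ok=True)
    img = PIL.Image.open("example.png")  # placeholder input path
    img = pad_to_square(img).resize((resolution, resolution)).convert("RGB")
    img.save(out_dir / "000.jpg", format="JPEG", quality=100)
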
lora_diffusion/FOR-cloneofsimo-LoRA
ADDED
@@ -0,0 +1,6 @@
This 'lora_diffusion' library in this subdirectory is required by the
'train_dreambooth_cloneofsimo_lora.py' script and is the underlying library of the
https://github.com/cloneofsimo/lora project.

The 'train_dreambooth_cloneofsimo_lora.py' script, in turn, is merely a renamed copy
of 'training_scripts/train_lora_dreambooth.py' from that same project.
lora_diffusion/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .lora import *
from .dataset import *
from .utils import *
from .preprocess_files import *
from .lora_manager import *
lora_diffusion/cli_lora_add.py
ADDED
@@ -0,0 +1,187 @@
from typing import Literal, Union, Dict
import os
import shutil
import fire
from diffusers import StableDiffusionPipeline
from safetensors.torch import safe_open, save_file

import torch
from .lora import (
    tune_lora_scale,
    patch_pipe,
    collapse_lora,
    monkeypatch_remove_lora,
)
from .lora_manager import lora_join
from .to_ckpt_v2 import convert_to_ckpt


def _text_lora_path(path: str) -> str:
    assert path.endswith(".pt"), "Only .pt files are supported"
    return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])


def add(
    path_1: str,
    path_2: str,
    output_path: str,
    alpha_1: float = 0.5,
    alpha_2: float = 0.5,
    mode: Literal[
        "lpl",
        "upl",
        "upl-ckpt-v2",
    ] = "lpl",
    with_text_lora: bool = False,
):
    print("Lora Add, mode " + mode)
    if mode == "lpl":
        if path_1.endswith(".pt") and path_2.endswith(".pt"):
            for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
                [(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
                if with_text_lora
                else []
            ):
                print("Loading", _path_1, _path_2)
                out_list = []
                if opt == "text_encoder":
                    if not os.path.exists(_path_1):
                        print(f"No text encoder found in {_path_1}, skipping...")
                        continue
                    if not os.path.exists(_path_2):
                        print(f"No text encoder found in {_path_2}, skipping...")
                        continue

                l1 = torch.load(_path_1)
                l2 = torch.load(_path_2)

                l1pairs = zip(l1[::2], l1[1::2])
                l2pairs = zip(l2[::2], l2[1::2])

                for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
                    # print("Merging", x1.shape, y1.shape, x2.shape, y2.shape)
                    x1.data = alpha_1 * x1.data + alpha_2 * x2.data
                    y1.data = alpha_1 * y1.data + alpha_2 * y2.data

                    out_list.append(x1)
                    out_list.append(y1)

                if opt == "unet":

                    print("Saving merged UNET to", output_path)
                    torch.save(out_list, output_path)

                elif opt == "text_encoder":
                    print("Saving merged text encoder to", _text_lora_path(output_path))
                    torch.save(
                        out_list,
                        _text_lora_path(output_path),
                    )

        elif path_1.endswith(".safetensors") and path_2.endswith(".safetensors"):
            safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
            safeloras_2 = safe_open(path_2, framework="pt", device="cpu")

            metadata = dict(safeloras_1.metadata())
            metadata.update(dict(safeloras_2.metadata()))

            ret_tensor = {}

            for keys in set(list(safeloras_1.keys()) + list(safeloras_2.keys())):
                if keys.startswith("text_encoder") or keys.startswith("unet"):

                    tens1 = safeloras_1.get_tensor(keys)
                    tens2 = safeloras_2.get_tensor(keys)

                    tens = alpha_1 * tens1 + alpha_2 * tens2
                    ret_tensor[keys] = tens
                else:
                    if keys in safeloras_1.keys():

                        tens1 = safeloras_1.get_tensor(keys)
                    else:
                        tens1 = safeloras_2.get_tensor(keys)

                    ret_tensor[keys] = tens1

            save_file(ret_tensor, output_path, metadata)

    elif mode == "upl":

        print(
            f"Merging UNET/CLIP from {path_1} with LoRA from {path_2} to {output_path}. Merging ratio : {alpha_1}."
        )

        loaded_pipeline = StableDiffusionPipeline.from_pretrained(
            path_1,
        ).to("cpu")

        patch_pipe(loaded_pipeline, path_2)

        collapse_lora(loaded_pipeline.unet, alpha_1)
        collapse_lora(loaded_pipeline.text_encoder, alpha_1)

        monkeypatch_remove_lora(loaded_pipeline.unet)
        monkeypatch_remove_lora(loaded_pipeline.text_encoder)

        loaded_pipeline.save_pretrained(output_path)

    elif mode == "upl-ckpt-v2":

        assert output_path.endswith(".ckpt"), "Only .ckpt files are supported"
        name = os.path.basename(output_path)[0:-5]

        print(
            f"You will be using {name} as the token in A1111 webui. Make sure {name} is unique enough token."
        )

        loaded_pipeline = StableDiffusionPipeline.from_pretrained(
            path_1,
        ).to("cpu")

        tok_dict = patch_pipe(loaded_pipeline, path_2, patch_ti=False)

        collapse_lora(loaded_pipeline.unet, alpha_1)
        collapse_lora(loaded_pipeline.text_encoder, alpha_1)

        monkeypatch_remove_lora(loaded_pipeline.unet)
        monkeypatch_remove_lora(loaded_pipeline.text_encoder)

        _tmp_output = output_path + ".tmp"

        loaded_pipeline.save_pretrained(_tmp_output)
        convert_to_ckpt(_tmp_output, output_path, as_half=True)
        # remove the tmp_output folder
        shutil.rmtree(_tmp_output)

        keys = sorted(tok_dict.keys())
        tok_catted = torch.stack([tok_dict[k] for k in keys])
        ret = {
            "string_to_token": {"*": torch.tensor(265)},
            "string_to_param": {"*": tok_catted},
            "name": name,
        }

        torch.save(ret, output_path[:-5] + ".pt")
        print(
            f"Textual embedding saved as {output_path[:-5]}.pt, put it in the embedding folder and use it as {name} in A1111 repo, "
        )
    elif mode == "ljl":
        print("Using Join mode : alpha will not have an effect here.")
        assert path_1.endswith(".safetensors") and path_2.endswith(
            ".safetensors"
        ), "Only .safetensors files are supported"

        safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
        safeloras_2 = safe_open(path_2, framework="pt", device="cpu")

        total_tensor, total_metadata, _, _ = lora_join([safeloras_1, safeloras_2])
        save_file(total_tensor, output_path, total_metadata)

    else:
        print("Unknown mode", mode)
        raise ValueError(f"Unknown mode {mode}")


def main():
    fire.Fire(add)
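
For orientation, a minimal sketch of calling the merge entry point defined above directly from Python rather than through fire; the three .pt paths are placeholders and assume two LoRA weight files saved in the legacy .pt format.

from lora_diffusion.cli_lora_add import add

# Blend two .pt LoRAs ("lpl" mode) 50/50; all paths are placeholders.
add(
    path_1="lora_a.pt",
    path_2="lora_b.pt",
    output_path="lora_merged.pt",
    alpha_1=0.5,
    alpha_2=0.5,
    mode="lpl",
    with_text_lora=False,
)
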
lora_diffusion/cli_lora_pti.py
ADDED
@@ -0,0 +1,1040 @@
# Bootstrapped from:
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py

import argparse
import hashlib
import inspect
import itertools
import math
import os
import random
import re
from pathlib import Path
from typing import Optional, List, Literal

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.checkpoint
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
import wandb
import fire

from lora_diffusion import (
    PivotalTuningDatasetCapation,
    extract_lora_ups_down,
    inject_trainable_lora,
    inject_trainable_lora_extended,
    inspect_lora,
    save_lora_weight,
    save_all,
    prepare_clip_model_sets,
    evaluate_pipe,
    UNET_EXTENDED_TARGET_REPLACE,
)


def get_models(
    pretrained_model_name_or_path,
    pretrained_vae_name_or_path,
    revision,
    placeholder_tokens: List[str],
    initializer_tokens: List[str],
    device="cuda:0",
):

    tokenizer = CLIPTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=revision,
    )

    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )

    placeholder_token_ids = []

    for token, init_tok in zip(placeholder_tokens, initializer_tokens):
        num_added_tokens = tokenizer.add_tokens(token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )

        placeholder_token_id = tokenizer.convert_tokens_to_ids(token)

        placeholder_token_ids.append(placeholder_token_id)

        # Load models and create wrapper for stable diffusion

        text_encoder.resize_token_embeddings(len(tokenizer))
        token_embeds = text_encoder.get_input_embeddings().weight.data
        if init_tok.startswith("<rand"):
            # <rand-"sigma">, e.g. <rand-0.5>
            sigma_val = float(re.findall(r"<rand-(.*)>", init_tok)[0])

            token_embeds[placeholder_token_id] = (
                torch.randn_like(token_embeds[0]) * sigma_val
            )
            print(
                f"Initialized {token} with random noise (sigma={sigma_val}), empirically {token_embeds[placeholder_token_id].mean().item():.3f} +- {token_embeds[placeholder_token_id].std().item():.3f}"
            )
            print(f"Norm : {token_embeds[placeholder_token_id].norm():.4f}")

        elif init_tok == "<zero>":
            token_embeds[placeholder_token_id] = torch.zeros_like(token_embeds[0])
        else:
            token_ids = tokenizer.encode(init_tok, add_special_tokens=False)
            # Check if initializer_token is a single token or a sequence of tokens
            if len(token_ids) > 1:
                raise ValueError("The initializer token must be a single token.")

            initializer_token_id = token_ids[0]
            token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

    vae = AutoencoderKL.from_pretrained(
        pretrained_vae_name_or_path or pretrained_model_name_or_path,
        subfolder=None if pretrained_vae_name_or_path else "vae",
        revision=None if pretrained_vae_name_or_path else revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        revision=revision,
    )

    return (
        text_encoder.to(device),
        vae.to(device),
        unet.to(device),
        tokenizer,
        placeholder_token_ids,
    )


@torch.no_grad()
def text2img_dataloader(
    train_dataset,
    train_batch_size,
    tokenizer,
    vae,
    text_encoder,
    cached_latents: bool = False,
):

    if cached_latents:
        cached_latents_dataset = []
        for idx in tqdm(range(len(train_dataset))):
            batch = train_dataset[idx]
            # print(batch)
            latents = vae.encode(
                batch["instance_images"].unsqueeze(0).to(dtype=vae.dtype).to(vae.device)
            ).latent_dist.sample()
            latents = latents * 0.18215
            batch["instance_images"] = latents.squeeze(0)
            cached_latents_dataset.append(batch)

    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }

        if examples[0].get("mask", None) is not None:
            batch["mask"] = torch.stack([example["mask"] for example in examples])

        return batch

    if cached_latents:

        train_dataloader = torch.utils.data.DataLoader(
            cached_latents_dataset,
            batch_size=train_batch_size,
            shuffle=True,
            collate_fn=collate_fn,
        )

        print("PTI : Using cached latent.")

    else:
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=train_batch_size,
            shuffle=True,
            collate_fn=collate_fn,
        )

    return train_dataloader


def inpainting_dataloader(
    train_dataset, train_batch_size, tokenizer, vae, text_encoder
):
    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
        mask_values = [example["instance_masks"] for example in examples]
        masked_image_values = [
            example["instance_masked_images"] for example in examples
        ]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if examples[0].get("class_prompt_ids", None) is not None:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]
            mask_values += [example["class_masks"] for example in examples]
            masked_image_values += [
                example["class_masked_images"] for example in examples
            ]

        pixel_values = (
            torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
        )
        mask_values = (
            torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
        )
        masked_image_values = (
            torch.stack(masked_image_values)
            .to(memory_format=torch.contiguous_format)
            .float()
        )

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "mask_values": mask_values,
            "masked_image_values": masked_image_values,
        }

        if examples[0].get("mask", None) is not None:
            batch["mask"] = torch.stack([example["mask"] for example in examples])

        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )

    return train_dataloader


def loss_step(
    batch,
    unet,
    vae,
    text_encoder,
    scheduler,
    train_inpainting=False,
    t_mutliplier=1.0,
    mixed_precision=False,
    mask_temperature=1.0,
    cached_latents: bool = False,
):
    weight_dtype = torch.float32
    if not cached_latents:
        latents = vae.encode(
            batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
        ).latent_dist.sample()
        latents = latents * 0.18215

        if train_inpainting:
            masked_image_latents = vae.encode(
                batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
            ).latent_dist.sample()
            masked_image_latents = masked_image_latents * 0.18215
            mask = F.interpolate(
                batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
                scale_factor=1 / 8,
            )
    else:
        latents = batch["pixel_values"]

        if train_inpainting:
            masked_image_latents = batch["masked_image_latents"]
            mask = batch["mask_values"]

    noise = torch.randn_like(latents)
    bsz = latents.shape[0]

    timesteps = torch.randint(
        0,
        int(scheduler.config.num_train_timesteps * t_mutliplier),
        (bsz,),
        device=latents.device,
    )
    timesteps = timesteps.long()

    noisy_latents = scheduler.add_noise(latents, noise, timesteps)

    if train_inpainting:
        latent_model_input = torch.cat(
            [noisy_latents, mask, masked_image_latents], dim=1
        )
    else:
        latent_model_input = noisy_latents

    if mixed_precision:
        with torch.cuda.amp.autocast():

            encoder_hidden_states = text_encoder(
                batch["input_ids"].to(text_encoder.device)
            )[0]

            model_pred = unet(
                latent_model_input, timesteps, encoder_hidden_states
            ).sample
    else:

        encoder_hidden_states = text_encoder(
            batch["input_ids"].to(text_encoder.device)
        )[0]

        model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample

    if scheduler.config.prediction_type == "epsilon":
        target = noise
    elif scheduler.config.prediction_type == "v_prediction":
        target = scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")

    if batch.get("mask", None) is not None:

        mask = (
            batch["mask"]
            .to(model_pred.device)
            .reshape(
                model_pred.shape[0], 1, model_pred.shape[2] * 8, model_pred.shape[3] * 8
            )
        )
        # resize to match model_pred
        mask = F.interpolate(
            mask.float(),
            size=model_pred.shape[-2:],
            mode="nearest",
        )

        mask = (mask + 0.01).pow(mask_temperature)

        mask = mask / mask.max()

        model_pred = model_pred * mask

        target = target * mask

    loss = (
        F.mse_loss(model_pred.float(), target.float(), reduction="none")
        .mean([1, 2, 3])
        .mean()
    )

    return loss


def train_inversion(
    unet,
    vae,
    text_encoder,
    dataloader,
    num_steps: int,
    scheduler,
    index_no_updates,
    optimizer,
    save_steps: int,
    placeholder_token_ids,
    placeholder_tokens,
    save_path: str,
    tokenizer,
    lr_scheduler,
    test_image_path: str,
    cached_latents: bool,
    accum_iter: int = 1,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    class_token: str = "person",
    train_inpainting: bool = False,
    mixed_precision: bool = False,
    clip_ti_decay: bool = True,
):

    progress_bar = tqdm(range(num_steps))
    progress_bar.set_description("Steps")
    global_step = 0

    # Original Emb for TI
    orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()

    if log_wandb:
        preped_clip = prepare_clip_model_sets()

    index_updates = ~index_no_updates
    loss_sum = 0.0

    for epoch in range(math.ceil(num_steps / len(dataloader))):
        unet.eval()
        text_encoder.train()
        for batch in dataloader:

            lr_scheduler.step()

            with torch.set_grad_enabled(True):
                loss = (
                    loss_step(
                        batch,
                        unet,
                        vae,
                        text_encoder,
                        scheduler,
                        train_inpainting=train_inpainting,
                        mixed_precision=mixed_precision,
                        cached_latents=cached_latents,
                    )
                    / accum_iter
                )

                loss.backward()
                loss_sum += loss.detach().item()

                if global_step % accum_iter == 0:
                    # print gradient of text encoder embedding
                    print(
                        text_encoder.get_input_embeddings()
                        .weight.grad[index_updates, :]
                        .norm(dim=-1)
                        .mean()
                    )
                    optimizer.step()
                    optimizer.zero_grad()

                    with torch.no_grad():

                        # normalize embeddings
                        if clip_ti_decay:
                            pre_norm = (
                                text_encoder.get_input_embeddings()
                                .weight[index_updates, :]
                                .norm(dim=-1, keepdim=True)
                            )

                            lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0])
                            text_encoder.get_input_embeddings().weight[
                                index_updates
                            ] = F.normalize(
                                text_encoder.get_input_embeddings().weight[
                                    index_updates, :
                                ],
                                dim=-1,
                            ) * (
                                pre_norm + lambda_ * (0.4 - pre_norm)
                            )
                            print(pre_norm)

                        current_norm = (
                            text_encoder.get_input_embeddings()
                            .weight[index_updates, :]
                            .norm(dim=-1)
                        )

                        text_encoder.get_input_embeddings().weight[
                            index_no_updates
                        ] = orig_embeds_params[index_no_updates]

                        print(f"Current Norm : {current_norm}")

                global_step += 1
                progress_bar.update(1)

                logs = {
                    "loss": loss.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                }
                progress_bar.set_postfix(**logs)

            if global_step % save_steps == 0:
                save_all(
                    unet=unet,
                    text_encoder=text_encoder,
                    placeholder_token_ids=placeholder_token_ids,
                    placeholder_tokens=placeholder_tokens,
                    save_path=os.path.join(
                        save_path, f"step_inv_{global_step}.safetensors"
                    ),
                    save_lora=False,
                )
                if log_wandb:
                    with torch.no_grad():
                        pipe = StableDiffusionPipeline(
                            vae=vae,
                            text_encoder=text_encoder,
                            tokenizer=tokenizer,
                            unet=unet,
                            scheduler=scheduler,
                            safety_checker=None,
                            feature_extractor=None,
                        )

                        # open all images in test_image_path
                        images = []
                        for file in os.listdir(test_image_path):
                            if (
                                file.lower().endswith(".png")
                                or file.lower().endswith(".jpg")
                                or file.lower().endswith(".jpeg")
                            ):
                                images.append(
                                    Image.open(os.path.join(test_image_path, file))
                                )

                        wandb.log({"loss": loss_sum / save_steps})
                        loss_sum = 0.0
                        wandb.log(
                            evaluate_pipe(
                                pipe,
                                target_images=images,
                                class_token=class_token,
                                learnt_token="".join(placeholder_tokens),
                                n_test=wandb_log_prompt_cnt,
                                n_step=50,
                                clip_model_sets=preped_clip,
                            )
                        )

            if global_step >= num_steps:
                return


def perform_tuning(
    unet,
    vae,
    text_encoder,
    dataloader,
    num_steps,
    scheduler,
    optimizer,
    save_steps: int,
    placeholder_token_ids,
    placeholder_tokens,
    save_path,
    lr_scheduler_lora,
    lora_unet_target_modules,
    lora_clip_target_modules,
    mask_temperature,
    out_name: str,
    tokenizer,
    test_image_path: str,
    cached_latents: bool,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    class_token: str = "person",
    train_inpainting: bool = False,
):

    progress_bar = tqdm(range(num_steps))
    progress_bar.set_description("Steps")
    global_step = 0

    weight_dtype = torch.float16

    unet.train()
    text_encoder.train()

    if log_wandb:
        preped_clip = prepare_clip_model_sets()

    loss_sum = 0.0

    for epoch in range(math.ceil(num_steps / len(dataloader))):
        for batch in dataloader:
            lr_scheduler_lora.step()

            optimizer.zero_grad()

            loss = loss_step(
                batch,
                unet,
                vae,
                text_encoder,
                scheduler,
                train_inpainting=train_inpainting,
                t_mutliplier=0.8,
                mixed_precision=True,
                mask_temperature=mask_temperature,
                cached_latents=cached_latents,
            )
            loss_sum += loss.detach().item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
            )
            optimizer.step()
            progress_bar.update(1)
            logs = {
                "loss": loss.detach().item(),
                "lr": lr_scheduler_lora.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            global_step += 1

            if global_step % save_steps == 0:
                save_all(
                    unet,
                    text_encoder,
                    placeholder_token_ids=placeholder_token_ids,
                    placeholder_tokens=placeholder_tokens,
                    save_path=os.path.join(
                        save_path, f"step_{global_step}.safetensors"
                    ),
                    target_replace_module_text=lora_clip_target_modules,
                    target_replace_module_unet=lora_unet_target_modules,
                )
                moved = (
                    torch.tensor(list(itertools.chain(*inspect_lora(unet).values())))
                    .mean()
                    .item()
                )

                print("LORA Unet Moved", moved)
                moved = (
                    torch.tensor(
                        list(itertools.chain(*inspect_lora(text_encoder).values()))
                    )
                    .mean()
                    .item()
                )

                print("LORA CLIP Moved", moved)

                if log_wandb:
                    with torch.no_grad():
                        pipe = StableDiffusionPipeline(
                            vae=vae,
                            text_encoder=text_encoder,
                            tokenizer=tokenizer,
                            unet=unet,
                            scheduler=scheduler,
                            safety_checker=None,
                            feature_extractor=None,
                        )

                        # open all images in test_image_path
                        images = []
                        for file in os.listdir(test_image_path):
                            if file.endswith(".png") or file.endswith(".jpg"):
                                images.append(
                                    Image.open(os.path.join(test_image_path, file))
                                )

                        wandb.log({"loss": loss_sum / save_steps})
                        loss_sum = 0.0
                        wandb.log(
                            evaluate_pipe(
                                pipe,
                                target_images=images,
                                class_token=class_token,
                                learnt_token="".join(placeholder_tokens),
                                n_test=wandb_log_prompt_cnt,
                                n_step=50,
                                clip_model_sets=preped_clip,
                            )
                        )

            if global_step >= num_steps:
                break

    save_all(
        unet,
        text_encoder,
        placeholder_token_ids=placeholder_token_ids,
        placeholder_tokens=placeholder_tokens,
        save_path=os.path.join(save_path, f"{out_name}.safetensors"),
        target_replace_module_text=lora_clip_target_modules,
        target_replace_module_unet=lora_unet_target_modules,
    )


def train(
    instance_data_dir: str,
    pretrained_model_name_or_path: str,
    output_dir: str,
    train_text_encoder: bool = True,
    pretrained_vae_name_or_path: str = None,
    revision: Optional[str] = None,
    perform_inversion: bool = True,
    use_template: Literal[None, "object", "style"] = None,
    train_inpainting: bool = False,
    placeholder_tokens: str = "",
    placeholder_token_at_data: Optional[str] = None,
    initializer_tokens: Optional[str] = None,
    seed: int = 42,
    resolution: int = 512,
    color_jitter: bool = True,
    train_batch_size: int = 1,
    sample_batch_size: int = 1,
    max_train_steps_tuning: int = 1000,
    max_train_steps_ti: int = 1000,
    save_steps: int = 100,
    gradient_accumulation_steps: int = 4,
    gradient_checkpointing: bool = False,
    lora_rank: int = 4,
    lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
    lora_clip_target_modules={"CLIPAttention"},
    lora_dropout_p: float = 0.0,
    lora_scale: float = 1.0,
    use_extended_lora: bool = False,
    clip_ti_decay: bool = True,
    learning_rate_unet: float = 1e-4,
    learning_rate_text: float = 1e-5,
    learning_rate_ti: float = 5e-4,
    continue_inversion: bool = False,
    continue_inversion_lr: Optional[float] = None,
    use_face_segmentation_condition: bool = False,
    cached_latents: bool = True,
    use_mask_captioned_data: bool = False,
    mask_temperature: float = 1.0,
    scale_lr: bool = False,
    lr_scheduler: str = "linear",
    lr_warmup_steps: int = 0,
    lr_scheduler_lora: str = "linear",
    lr_warmup_steps_lora: int = 0,
    weight_decay_ti: float = 0.00,
    weight_decay_lora: float = 0.001,
    use_8bit_adam: bool = False,
    device="cuda:0",
    extra_args: Optional[dict] = None,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    wandb_project_name: str = "new_pti_project",
    wandb_entity: str = "new_pti_entity",
    proxy_token: str = "person",
    enable_xformers_memory_efficient_attention: bool = False,
    out_name: str = "final_lora",
):
    torch.manual_seed(seed)

    if log_wandb:
        wandb.init(
            project=wandb_project_name,
            entity=wandb_entity,
            name=f"steps_{max_train_steps_ti}_lr_{learning_rate_ti}_{instance_data_dir.split('/')[-1]}",
            reinit=True,
            config={
                **(extra_args if extra_args is not None else {}),
            },
        )

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
    # print(placeholder_tokens, initializer_tokens)
    if len(placeholder_tokens) == 0:
        placeholder_tokens = []
        print("PTI : Placeholder Tokens not given, using null token")
    else:
        placeholder_tokens = placeholder_tokens.split("|")

        assert (
            sorted(placeholder_tokens) == placeholder_tokens
        ), f"Placeholder tokens should be sorted. Use something like {'|'.join(sorted(placeholder_tokens))}'"

    if initializer_tokens is None:
        print("PTI : Initializer Tokens not given, doing random inits")
        initializer_tokens = ["<rand-0.017>"] * len(placeholder_tokens)
    else:
        initializer_tokens = initializer_tokens.split("|")

    assert len(initializer_tokens) == len(
        placeholder_tokens
    ), "Unequal Initializer token for Placeholder tokens."

    if proxy_token is not None:
        class_token = proxy_token
    class_token = "".join(initializer_tokens)

    if placeholder_token_at_data is not None:
        tok, pat = placeholder_token_at_data.split("|")
        token_map = {tok: pat}

    else:
        token_map = {"DUMMY": "".join(placeholder_tokens)}

    print("PTI : Placeholder Tokens", placeholder_tokens)
    print("PTI : Initializer Tokens", initializer_tokens)

    # get the models
    text_encoder, vae, unet, tokenizer, placeholder_token_ids = get_models(
        pretrained_model_name_or_path,
        pretrained_vae_name_or_path,
        revision,
        placeholder_tokens,
        initializer_tokens,
        device=device,
    )

    noise_scheduler = DDPMScheduler.from_config(
        pretrained_model_name_or_path, subfolder="scheduler"
    )

    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if enable_xformers_memory_efficient_attention:
        from diffusers.utils.import_utils import is_xformers_available

        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    if scale_lr:
        unet_lr = learning_rate_unet * gradient_accumulation_steps * train_batch_size
        text_encoder_lr = (
            learning_rate_text * gradient_accumulation_steps * train_batch_size
        )
        ti_lr = learning_rate_ti * gradient_accumulation_steps * train_batch_size
    else:
        unet_lr = learning_rate_unet
        text_encoder_lr = learning_rate_text
        ti_lr = learning_rate_ti

    train_dataset = PivotalTuningDatasetCapation(
        instance_data_root=instance_data_dir,
        token_map=token_map,
        use_template=use_template,
        tokenizer=tokenizer,
        size=resolution,
        color_jitter=color_jitter,
        use_face_segmentation_condition=use_face_segmentation_condition,
        use_mask_captioned_data=use_mask_captioned_data,
        train_inpainting=train_inpainting,
    )

    train_dataset.blur_amount = 200

    if train_inpainting:
        assert not cached_latents, "Cached latents not supported for inpainting"

        train_dataloader = inpainting_dataloader(
            train_dataset, train_batch_size, tokenizer, vae, text_encoder
        )
    else:
        train_dataloader = text2img_dataloader(
            train_dataset,
            train_batch_size,
            tokenizer,
            vae,
            text_encoder,
            cached_latents=cached_latents,
        )

    index_no_updates = torch.arange(len(tokenizer)) != -1

    for tok_id in placeholder_token_ids:
        index_no_updates[tok_id] = False

    unet.requires_grad_(False)
    vae.requires_grad_(False)

    params_to_freeze = itertools.chain(
        text_encoder.text_model.encoder.parameters(),
        text_encoder.text_model.final_layer_norm.parameters(),
        text_encoder.text_model.embeddings.position_embedding.parameters(),
    )
    for param in params_to_freeze:
        param.requires_grad = False

    if cached_latents:
        vae = None
    # STEP 1 : Perform Inversion
    if perform_inversion:
        ti_optimizer = optim.AdamW(
            text_encoder.get_input_embeddings().parameters(),
            lr=ti_lr,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=weight_decay_ti,
        )

        lr_scheduler = get_scheduler(
            lr_scheduler,
            optimizer=ti_optimizer,
            num_warmup_steps=lr_warmup_steps,
            num_training_steps=max_train_steps_ti,
        )

        train_inversion(
            unet,
            vae,
            text_encoder,
            train_dataloader,
            max_train_steps_ti,
            cached_latents=cached_latents,
            accum_iter=gradient_accumulation_steps,
            scheduler=noise_scheduler,
            index_no_updates=index_no_updates,
            optimizer=ti_optimizer,
            lr_scheduler=lr_scheduler,
            save_steps=save_steps,
            placeholder_tokens=placeholder_tokens,
            placeholder_token_ids=placeholder_token_ids,
            save_path=output_dir,
            test_image_path=instance_data_dir,
            log_wandb=log_wandb,
            wandb_log_prompt_cnt=wandb_log_prompt_cnt,
            class_token=class_token,
            train_inpainting=train_inpainting,
            mixed_precision=False,
            tokenizer=tokenizer,
            clip_ti_decay=clip_ti_decay,
        )

        del ti_optimizer

    # Next perform Tuning with LoRA:
    if not use_extended_lora:
        unet_lora_params, _ = inject_trainable_lora(
            unet,
            r=lora_rank,
            target_replace_module=lora_unet_target_modules,
            dropout_p=lora_dropout_p,
            scale=lora_scale,
        )
    else:
        print("PTI : USING EXTENDED UNET!!!")
        lora_unet_target_modules = (
            lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE
        )
        print("PTI : Will replace modules: ", lora_unet_target_modules)

        unet_lora_params, _ = inject_trainable_lora_extended(
            unet, r=lora_rank, target_replace_module=lora_unet_target_modules
        )
    print(f"PTI : has {len(unet_lora_params)} lora")

    print("PTI : Before training:")
    inspect_lora(unet)

    params_to_optimize = [
        {"params": itertools.chain(*unet_lora_params), "lr": unet_lr},
    ]

    text_encoder.requires_grad_(False)

    if continue_inversion:
        params_to_optimize += [
            {
                "params": text_encoder.get_input_embeddings().parameters(),
                "lr": continue_inversion_lr
                if continue_inversion_lr is not None
                else ti_lr,
            }
        ]
        text_encoder.requires_grad_(True)
        params_to_freeze = itertools.chain(
            text_encoder.text_model.encoder.parameters(),
            text_encoder.text_model.final_layer_norm.parameters(),
            text_encoder.text_model.embeddings.position_embedding.parameters(),
        )
        for param in params_to_freeze:
            param.requires_grad = False
    else:
        text_encoder.requires_grad_(False)
    if train_text_encoder:
        text_encoder_lora_params, _ = inject_trainable_lora(
            text_encoder,
            target_replace_module=lora_clip_target_modules,
            r=lora_rank,
        )
        params_to_optimize += [
            {
                "params": itertools.chain(*text_encoder_lora_params),
                "lr": text_encoder_lr,
            }
        ]
        inspect_lora(text_encoder)

    lora_optimizers = optim.AdamW(params_to_optimize, weight_decay=weight_decay_lora)

    unet.train()
    if train_text_encoder:
        text_encoder.train()

    train_dataset.blur_amount = 70

    lr_scheduler_lora = get_scheduler(
        lr_scheduler_lora,
        optimizer=lora_optimizers,
        num_warmup_steps=lr_warmup_steps_lora,
        num_training_steps=max_train_steps_tuning,
    )

    perform_tuning(
        unet,
        vae,
        text_encoder,
        train_dataloader,
        max_train_steps_tuning,
        cached_latents=cached_latents,
        scheduler=noise_scheduler,
        optimizer=lora_optimizers,
        save_steps=save_steps,
        placeholder_tokens=placeholder_tokens,
        placeholder_token_ids=placeholder_token_ids,
        save_path=output_dir,
        lr_scheduler_lora=lr_scheduler_lora,
        lora_unet_target_modules=lora_unet_target_modules,
        lora_clip_target_modules=lora_clip_target_modules,
        mask_temperature=mask_temperature,
        tokenizer=tokenizer,
        out_name=out_name,
        test_image_path=instance_data_dir,
        log_wandb=log_wandb,
        wandb_log_prompt_cnt=wandb_log_prompt_cnt,
        class_token=class_token,
        train_inpainting=train_inpainting,
    )


def main():
    fire.Fire(train)
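
As a usage sketch (not part of the commit), the train() entry point above can also be called directly from Python instead of through fire; all paths and token values below are placeholders, chosen to match the directory layout used in app.py.

from lora_diffusion.cli_lora_pti import train

# Placeholder directories and tokens; arguments mirror the signature of train() above.
train(
    instance_data_dir="./train-instance/a-xyz-model",
    pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1",
    output_dir="./experiments/a-xyz-model",
    placeholder_tokens="<s1>",
    initializer_tokens="dog",
    use_template="object",
    max_train_steps_ti=500,
    max_train_steps_tuning=500,
    save_steps=100,
    lora_rank=4,
)
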
lora_diffusion/cli_pt_to_safetensors.py
ADDED
@@ -0,0 +1,85 @@
import os

import fire
import torch
from lora_diffusion import (
    DEFAULT_TARGET_REPLACE,
    TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
    UNET_DEFAULT_TARGET_REPLACE,
    convert_loras_to_safeloras_with_embeds,
    safetensors_available,
)

_target_by_name = {
    "unet": UNET_DEFAULT_TARGET_REPLACE,
    "text_encoder": TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
}


def convert(*paths, outpath, overwrite=False, **settings):
    """
    Converts one or more pytorch Lora and/or Textual Embedding pytorch files
    into a safetensor file.

    Pass all the input paths as arguments. Whether they are Textual Embedding
    or Lora models will be auto-detected.

    For Lora models, their name will be taken from the path, i.e.
        "lora_weight.pt" => unet
        "lora_weight.text_encoder.pt" => text_encoder

    You can also set target_modules and/or rank by providing an argument prefixed
    by the name.

    So a complete example might be something like:

    ```
    python -m lora_diffusion.cli_pt_to_safetensors lora_weight.* --outpath lora_weight.safetensor --unet.rank 8
    ```
    """
    modelmap = {}
    embeds = {}

    if os.path.exists(outpath) and not overwrite:
        raise ValueError(
            f"Output path {outpath} already exists, and overwrite is not True"
        )

    for path in paths:
        data = torch.load(path)

        if isinstance(data, dict):
            print(f"Loading textual inversion embeds {data.keys()} from {path}")
            embeds.update(data)

        else:
            name_parts = os.path.split(path)[1].split(".")
            name = name_parts[-2] if len(name_parts) > 2 else "unet"

            model_settings = {
                "target_modules": _target_by_name.get(name, DEFAULT_TARGET_REPLACE),
                "rank": 4,
            }

            prefix = f"{name}."

            arg_settings = { k[len(prefix) :]: v for k, v in settings.items() if k.startswith(prefix) }
            model_settings = { **model_settings, **arg_settings }

            print(f"Loading Lora for {name} from {path} with settings {model_settings}")

            modelmap[name] = (
                path,
                model_settings["target_modules"],
                model_settings["rank"],
            )

    convert_loras_to_safeloras_with_embeds(modelmap, embeds, outpath)


def main():
    fire.Fire(convert)


if __name__ == "__main__":
    main()
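
Equivalently to the CLI example in the docstring, convert() can be invoked from Python; the file names below are placeholders (a UNet LoRA and its text-encoder LoRA saved as .pt), and per-model options such as --unet.rank arrive through the **settings parameter when the script runs under fire.

from lora_diffusion.cli_pt_to_safetensors import convert

convert(
    "lora_weight.pt",               # placeholder; auto-detected as the "unet" LoRA
    "lora_weight.text_encoder.pt",  # placeholder; auto-detected as the "text_encoder" LoRA
    outpath="lora_weight.safetensors",
)
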
lora_diffusion/cli_svd.py
ADDED
@@ -0,0 +1,146 @@
import fire
from diffusers import StableDiffusionPipeline
import torch
import torch.nn as nn

from .lora import (
    save_all,
    _find_modules,
    LoraInjectedConv2d,
    LoraInjectedLinear,
    inject_trainable_lora,
    inject_trainable_lora_extended,
)


def _iter_lora(model):
    for module in model.modules():
        if isinstance(module, LoraInjectedConv2d) or isinstance(
            module, LoraInjectedLinear
        ):
            yield module


def overwrite_base(base_model, tuned_model, rank, clamp_quantile):
    device = base_model.device
    dtype = base_model.dtype

    for lor_base, lor_tune in zip(_iter_lora(base_model), _iter_lora(tuned_model)):

        if isinstance(lor_base, LoraInjectedLinear):
            residual = lor_tune.linear.weight.data - lor_base.linear.weight.data
            # SVD on residual
            print("Distill Linear shape ", residual.shape)
            residual = residual.float()
            U, S, Vh = torch.linalg.svd(residual)
            U = U[:, :rank]
            S = S[:rank]
            U = U @ torch.diag(S)

            Vh = Vh[:rank, :]

            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, clamp_quantile)
            low_val = -hi_val

            U = U.clamp(low_val, hi_val)
            Vh = Vh.clamp(low_val, hi_val)

            assert lor_base.lora_up.weight.shape == U.shape
            assert lor_base.lora_down.weight.shape == Vh.shape

            lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype)
            lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype)

        if isinstance(lor_base, LoraInjectedConv2d):
            residual = lor_tune.conv.weight.data - lor_base.conv.weight.data
            print("Distill Conv shape ", residual.shape)

            residual = residual.float()
            residual = residual.flatten(start_dim=1)

            # SVD on residual
            U, S, Vh = torch.linalg.svd(residual)
            U = U[:, :rank]
            S = S[:rank]
            U = U @ torch.diag(S)

            Vh = Vh[:rank, :]

            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, clamp_quantile)
            low_val = -hi_val

            U = U.clamp(low_val, hi_val)
            Vh = Vh.clamp(low_val, hi_val)

            # U is (out_channels, rank) with 1x1 conv. So,
            U = U.reshape(U.shape[0], U.shape[1], 1, 1)
            # V is (rank, in_channels * kernel_size1 * kernel_size2)
            # now reshape:
            Vh = Vh.reshape(
                Vh.shape[0],
                lor_base.conv.in_channels,
                lor_base.conv.kernel_size[0],
                lor_base.conv.kernel_size[1],
            )

            assert lor_base.lora_up.weight.shape == U.shape
            assert lor_base.lora_down.weight.shape == Vh.shape

            lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype)
            lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype)


def svd_distill(
    target_model: str,
    base_model: str,
    rank: int = 4,
    clamp_quantile: float = 0.99,
    device: str = "cuda:0",
    save_path: str = "svd_distill.safetensors",
):
    pipe_base = StableDiffusionPipeline.from_pretrained(
        base_model, torch_dtype=torch.float16
    ).to(device)

    pipe_tuned = StableDiffusionPipeline.from_pretrained(
        target_model, torch_dtype=torch.float16
    ).to(device)

    # Inject unet
    _ = inject_trainable_lora_extended(pipe_base.unet, r=rank)
    _ = inject_trainable_lora_extended(pipe_tuned.unet, r=rank)

    overwrite_base(
        pipe_base.unet, pipe_tuned.unet, rank=rank, clamp_quantile=clamp_quantile
    )

    # Inject text encoder
    _ = inject_trainable_lora(
        pipe_base.text_encoder, r=rank, target_replace_module={"CLIPAttention"}
    )
    _ = inject_trainable_lora(
        pipe_tuned.text_encoder, r=rank, target_replace_module={"CLIPAttention"}
    )

    overwrite_base(
        pipe_base.text_encoder,
        pipe_tuned.text_encoder,
        rank=rank,
        clamp_quantile=clamp_quantile,
    )

    save_all(
        unet=pipe_base.unet,
        text_encoder=pipe_base.text_encoder,
        placeholder_token_ids=None,
        placeholder_tokens=None,
        save_path=save_path,
        save_lora=True,
        save_ti=False,
    )


def main():
    fire.Fire(svd_distill)
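A minimal sketch of driving `svd_distill` from Python; the model identifiers and output path below are placeholders, not defaults of this repo:

```python
# Minimal sketch; model ids and the output path are placeholders.
from lora_diffusion.cli_svd import svd_distill

svd_distill(
    target_model="./my-fine-tuned-model",         # a full fine-tune (e.g. DreamBooth output)
    base_model="runwayml/stable-diffusion-v1-5",  # the checkpoint it started from
    rank=8,
    clamp_quantile=0.99,
    save_path="svd_distill.safetensors",
)
```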
lora_diffusion/dataset.py
ADDED
@@ -0,0 +1,311 @@
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

from PIL import Image
from torch import zeros_like
from torch.utils.data import Dataset
from torchvision import transforms
import glob
from .preprocess_files import face_mask_google_mediapipe

OBJECT_TEMPLATE = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

STYLE_TEMPLATE = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]

NULL_TEMPLATE = ["{}"]

TEMPLATE_MAP = {
    "object": OBJECT_TEMPLATE,
    "style": STYLE_TEMPLATE,
    "null": NULL_TEMPLATE,
}


def _randomset(lis):
    ret = []
    for i in range(len(lis)):
        if random.random() < 0.5:
            ret.append(lis[i])
    return ret


def _shuffle(lis):

    return random.sample(lis, len(lis))


def _get_cutout_holes(
    height,
    width,
    min_holes=8,
    max_holes=32,
    min_height=16,
    max_height=128,
    min_width=16,
    max_width=128,
):
    holes = []
    for _n in range(random.randint(min_holes, max_holes)):
        hole_height = random.randint(min_height, max_height)
        hole_width = random.randint(min_width, max_width)
        y1 = random.randint(0, height - hole_height)
        x1 = random.randint(0, width - hole_width)
        y2 = y1 + hole_height
        x2 = x1 + hole_width
        holes.append((x1, y1, x2, y2))
    return holes


def _generate_random_mask(image):
    mask = zeros_like(image[:1])
    holes = _get_cutout_holes(mask.shape[1], mask.shape[2])
    for (x1, y1, x2, y2) in holes:
        mask[:, y1:y2, x1:x2] = 1.0
    if random.uniform(0, 1) < 0.25:
        mask.fill_(1.0)
    masked_image = image * (mask < 0.5)
    return mask, masked_image


class PivotalTuningDatasetCapation(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and the tokenizes prompts.
    """

    def __init__(
        self,
        instance_data_root,
        tokenizer,
        token_map: Optional[dict] = None,
        use_template: Optional[str] = None,
        size=512,
        h_flip=True,
        color_jitter=False,
        resize=True,
        use_mask_captioned_data=False,
        use_face_segmentation_condition=False,
        train_inpainting=False,
        blur_amount: int = 70,
    ):
        self.size = size
        self.tokenizer = tokenizer
        self.resize = resize
        self.train_inpainting = train_inpainting

        instance_data_root = Path(instance_data_root)
        if not instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        self.instance_images_path = []
        self.mask_path = []

        assert not (
            use_mask_captioned_data and use_template
        ), "Can't use both mask caption data and template."

        # Prepare the instance images
        if use_mask_captioned_data:
            src_imgs = glob.glob(str(instance_data_root) + "/*src.jpg")
            for f in src_imgs:
                idx = int(str(Path(f).stem).split(".")[0])
                mask_path = f"{instance_data_root}/{idx}.mask.png"

                if Path(mask_path).exists():
                    self.instance_images_path.append(f)
                    self.mask_path.append(mask_path)
                else:
                    print(f"Mask not found for {f}")

            self.captions = open(f"{instance_data_root}/caption.txt").readlines()

        else:
            possibily_src_images = (
                glob.glob(str(instance_data_root) + "/*.jpg")
                + glob.glob(str(instance_data_root) + "/*.png")
                + glob.glob(str(instance_data_root) + "/*.jpeg")
            )
            possibily_src_images = (
                set(possibily_src_images)
                - set(glob.glob(str(instance_data_root) + "/*mask.png"))
                - set([str(instance_data_root) + "/caption.txt"])
            )

            self.instance_images_path = list(set(possibily_src_images))
            self.captions = [
                x.split("/")[-1].split(".")[0] for x in self.instance_images_path
            ]

        assert (
            len(self.instance_images_path) > 0
        ), "No images found in the instance data root."

        self.instance_images_path = sorted(self.instance_images_path)

        self.use_mask = use_face_segmentation_condition or use_mask_captioned_data
        self.use_mask_captioned_data = use_mask_captioned_data

        if use_face_segmentation_condition:

            for idx in range(len(self.instance_images_path)):
                targ = f"{instance_data_root}/{idx}.mask.png"
                # see if the mask exists
                if not Path(targ).exists():
                    print(f"Mask not found for {targ}")

                    print(
                        "Warning : this will pre-process all the images in the instance data root."
                    )

                    if len(self.mask_path) > 0:
                        print(
                            "Warning : masks already exists, but will be overwritten."
                        )

                    masks = face_mask_google_mediapipe(
                        [
                            Image.open(f).convert("RGB")
                            for f in self.instance_images_path
                        ]
                    )
                    for idx, mask in enumerate(masks):
                        mask.save(f"{instance_data_root}/{idx}.mask.png")

                    break

            for idx in range(len(self.instance_images_path)):
                self.mask_path.append(f"{instance_data_root}/{idx}.mask.png")

        self.num_instance_images = len(self.instance_images_path)
        self.token_map = token_map

        self.use_template = use_template
        if use_template is not None:
            self.templates = TEMPLATE_MAP[use_template]

        self._length = self.num_instance_images

        self.h_flip = h_flip
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                )
                if resize
                else transforms.Lambda(lambda x: x),
                transforms.ColorJitter(0.1, 0.1)
                if color_jitter
                else transforms.Lambda(lambda x: x),
                transforms.CenterCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        self.blur_amount = blur_amount

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(
            self.instance_images_path[index % self.num_instance_images]
        )
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        if self.train_inpainting:
            (
                example["instance_masks"],
                example["instance_masked_images"],
            ) = _generate_random_mask(example["instance_images"])

        if self.use_template:
            assert self.token_map is not None
            input_tok = list(self.token_map.values())[0]

            text = random.choice(self.templates).format(input_tok)
        else:
            text = self.captions[index % self.num_instance_images].strip()

            if self.token_map is not None:
                for token, value in self.token_map.items():
                    text = text.replace(token, value)

        print(text)

        if self.use_mask:
            example["mask"] = (
                self.image_transforms(
                    Image.open(self.mask_path[index % self.num_instance_images])
                )
                * 0.5
                + 1.0
            )

        if self.h_flip and random.random() > 0.5:
            hflip = transforms.RandomHorizontalFlip(p=1)

            example["instance_images"] = hflip(example["instance_images"])
            if self.use_mask:
                example["mask"] = hflip(example["mask"])

        example["instance_prompt_ids"] = self.tokenizer(
            text,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        return example
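A minimal sketch of constructing this dataset for template-based captioning; the tokenizer checkpoint, data directory and placeholder token below are assumptions for illustration, not values fixed by this file:

```python
# Minimal sketch; checkpoint id, directory and token are placeholders.
from transformers import CLIPTokenizer
from lora_diffusion.dataset import PivotalTuningDatasetCapation

tokenizer = CLIPTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
)

train_dataset = PivotalTuningDatasetCapation(
    instance_data_root="./instance_images",
    tokenizer=tokenizer,
    token_map={"TOKEN": "<s1>"},  # occurrences of TOKEN become <s1> in the prompt
    use_template="object",        # one of "object", "style", "null"
    size=512,
)
example = train_dataset[0]  # dict with "instance_images" and "instance_prompt_ids"
```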
lora_diffusion/lora.py
ADDED
@@ -0,0 +1,1110 @@
1 |
+
import json
|
2 |
+
import math
|
3 |
+
from itertools import groupby
|
4 |
+
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import PIL
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
try:
|
13 |
+
from safetensors.torch import safe_open
|
14 |
+
from safetensors.torch import save_file as safe_save
|
15 |
+
|
16 |
+
safetensors_available = True
|
17 |
+
except ImportError:
|
18 |
+
from .safe_open import safe_open
|
19 |
+
|
20 |
+
def safe_save(
|
21 |
+
tensors: Dict[str, torch.Tensor],
|
22 |
+
filename: str,
|
23 |
+
metadata: Optional[Dict[str, str]] = None,
|
24 |
+
) -> None:
|
25 |
+
raise EnvironmentError(
|
26 |
+
"Saving safetensors requires the safetensors library. Please install with pip or similar."
|
27 |
+
)
|
28 |
+
|
29 |
+
safetensors_available = False
|
30 |
+
|
31 |
+
|
32 |
+
class LoraInjectedLinear(nn.Module):
|
33 |
+
def __init__(
|
34 |
+
self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
|
35 |
+
):
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
if r > min(in_features, out_features):
|
39 |
+
raise ValueError(
|
40 |
+
f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
|
41 |
+
)
|
42 |
+
self.r = r
|
43 |
+
self.linear = nn.Linear(in_features, out_features, bias)
|
44 |
+
self.lora_down = nn.Linear(in_features, r, bias=False)
|
45 |
+
self.dropout = nn.Dropout(dropout_p)
|
46 |
+
self.lora_up = nn.Linear(r, out_features, bias=False)
|
47 |
+
self.scale = scale
|
48 |
+
self.selector = nn.Identity()
|
49 |
+
|
50 |
+
nn.init.normal_(self.lora_down.weight, std=1 / r)
|
51 |
+
nn.init.zeros_(self.lora_up.weight)
|
52 |
+
|
53 |
+
def forward(self, input):
|
54 |
+
return (
|
55 |
+
self.linear(input)
|
56 |
+
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
|
57 |
+
* self.scale
|
58 |
+
)
|
59 |
+
|
60 |
+
def realize_as_lora(self):
|
61 |
+
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
|
62 |
+
|
63 |
+
def set_selector_from_diag(self, diag: torch.Tensor):
|
64 |
+
# diag is a 1D tensor of size (r,)
|
65 |
+
assert diag.shape == (self.r,)
|
66 |
+
self.selector = nn.Linear(self.r, self.r, bias=False)
|
67 |
+
self.selector.weight.data = torch.diag(diag)
|
68 |
+
self.selector.weight.data = self.selector.weight.data.to(
|
69 |
+
self.lora_up.weight.device
|
70 |
+
).to(self.lora_up.weight.dtype)
|
71 |
+
|
72 |
+
|
73 |
+
class LoraInjectedConv2d(nn.Module):
|
74 |
+
def __init__(
|
75 |
+
self,
|
76 |
+
in_channels: int,
|
77 |
+
out_channels: int,
|
78 |
+
kernel_size,
|
79 |
+
stride=1,
|
80 |
+
padding=0,
|
81 |
+
dilation=1,
|
82 |
+
groups: int = 1,
|
83 |
+
bias: bool = True,
|
84 |
+
r: int = 4,
|
85 |
+
dropout_p: float = 0.1,
|
86 |
+
scale: float = 1.0,
|
87 |
+
):
|
88 |
+
super().__init__()
|
89 |
+
if r > min(in_channels, out_channels):
|
90 |
+
raise ValueError(
|
91 |
+
f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}"
|
92 |
+
)
|
93 |
+
self.r = r
|
94 |
+
self.conv = nn.Conv2d(
|
95 |
+
in_channels=in_channels,
|
96 |
+
out_channels=out_channels,
|
97 |
+
kernel_size=kernel_size,
|
98 |
+
stride=stride,
|
99 |
+
padding=padding,
|
100 |
+
dilation=dilation,
|
101 |
+
groups=groups,
|
102 |
+
bias=bias,
|
103 |
+
)
|
104 |
+
|
105 |
+
self.lora_down = nn.Conv2d(
|
106 |
+
in_channels=in_channels,
|
107 |
+
out_channels=r,
|
108 |
+
kernel_size=kernel_size,
|
109 |
+
stride=stride,
|
110 |
+
padding=padding,
|
111 |
+
dilation=dilation,
|
112 |
+
groups=groups,
|
113 |
+
bias=False,
|
114 |
+
)
|
115 |
+
self.dropout = nn.Dropout(dropout_p)
|
116 |
+
self.lora_up = nn.Conv2d(
|
117 |
+
in_channels=r,
|
118 |
+
out_channels=out_channels,
|
119 |
+
kernel_size=1,
|
120 |
+
stride=1,
|
121 |
+
padding=0,
|
122 |
+
bias=False,
|
123 |
+
)
|
124 |
+
self.selector = nn.Identity()
|
125 |
+
self.scale = scale
|
126 |
+
|
127 |
+
nn.init.normal_(self.lora_down.weight, std=1 / r)
|
128 |
+
nn.init.zeros_(self.lora_up.weight)
|
129 |
+
|
130 |
+
def forward(self, input):
|
131 |
+
return (
|
132 |
+
self.conv(input)
|
133 |
+
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
|
134 |
+
* self.scale
|
135 |
+
)
|
136 |
+
|
137 |
+
def realize_as_lora(self):
|
138 |
+
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
|
139 |
+
|
140 |
+
def set_selector_from_diag(self, diag: torch.Tensor):
|
141 |
+
# diag is a 1D tensor of size (r,)
|
142 |
+
assert diag.shape == (self.r,)
|
143 |
+
self.selector = nn.Conv2d(
|
144 |
+
in_channels=self.r,
|
145 |
+
out_channels=self.r,
|
146 |
+
kernel_size=1,
|
147 |
+
stride=1,
|
148 |
+
padding=0,
|
149 |
+
bias=False,
|
150 |
+
)
|
151 |
+
self.selector.weight.data = torch.diag(diag)
|
152 |
+
|
153 |
+
# same device + dtype as lora_up
|
154 |
+
self.selector.weight.data = self.selector.weight.data.to(
|
155 |
+
self.lora_up.weight.device
|
156 |
+
).to(self.lora_up.weight.dtype)
|
157 |
+
|
158 |
+
|
159 |
+
UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
|
160 |
+
|
161 |
+
UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
|
162 |
+
|
163 |
+
TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
|
164 |
+
|
165 |
+
TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}
|
166 |
+
|
167 |
+
DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
|
168 |
+
|
169 |
+
EMBED_FLAG = "<embed>"
|
170 |
+
|
171 |
+
|
172 |
+
def _find_children(
|
173 |
+
model,
|
174 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
175 |
+
):
|
176 |
+
"""
|
177 |
+
Find all modules of a certain class (or union of classes).
|
178 |
+
|
179 |
+
Returns all matching modules, along with the parent of those moduless and the
|
180 |
+
names they are referenced by.
|
181 |
+
"""
|
182 |
+
# For each target find every linear_class module that isn't a child of a LoraInjectedLinear
|
183 |
+
for parent in model.modules():
|
184 |
+
for name, module in parent.named_children():
|
185 |
+
if any([isinstance(module, _class) for _class in search_class]):
|
186 |
+
yield parent, name, module
|
187 |
+
|
188 |
+
|
189 |
+
def _find_modules_v2(
|
190 |
+
model,
|
191 |
+
ancestor_class: Optional[Set[str]] = None,
|
192 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
193 |
+
exclude_children_of: Optional[List[Type[nn.Module]]] = [
|
194 |
+
LoraInjectedLinear,
|
195 |
+
LoraInjectedConv2d,
|
196 |
+
],
|
197 |
+
):
|
198 |
+
"""
|
199 |
+
Find all modules of a certain class (or union of classes) that are direct or
|
200 |
+
indirect descendants of other modules of a certain class (or union of classes).
|
201 |
+
|
202 |
+
Returns all matching modules, along with the parent of those moduless and the
|
203 |
+
names they are referenced by.
|
204 |
+
"""
|
205 |
+
|
206 |
+
# Get the targets we should replace all linears under
|
207 |
+
if ancestor_class is not None:
|
208 |
+
ancestors = (
|
209 |
+
module
|
210 |
+
for module in model.modules()
|
211 |
+
if module.__class__.__name__ in ancestor_class
|
212 |
+
)
|
213 |
+
else:
|
214 |
+
# this, incase you want to naively iterate over all modules.
|
215 |
+
ancestors = [module for module in model.modules()]
|
216 |
+
|
217 |
+
# For each target find every linear_class module that isn't a child of a LoraInjectedLinear
|
218 |
+
for ancestor in ancestors:
|
219 |
+
for fullname, module in ancestor.named_modules():
|
220 |
+
if any([isinstance(module, _class) for _class in search_class]):
|
221 |
+
# Find the direct parent if this is a descendant, not a child, of target
|
222 |
+
*path, name = fullname.split(".")
|
223 |
+
parent = ancestor
|
224 |
+
while path:
|
225 |
+
parent = parent.get_submodule(path.pop(0))
|
226 |
+
# Skip this linear if it's a child of a LoraInjectedLinear
|
227 |
+
if exclude_children_of and any(
|
228 |
+
[isinstance(parent, _class) for _class in exclude_children_of]
|
229 |
+
):
|
230 |
+
continue
|
231 |
+
# Otherwise, yield it
|
232 |
+
yield parent, name, module
|
233 |
+
|
234 |
+
|
235 |
+
def _find_modules_old(
|
236 |
+
model,
|
237 |
+
ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
|
238 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
239 |
+
exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
|
240 |
+
):
|
241 |
+
ret = []
|
242 |
+
for _module in model.modules():
|
243 |
+
if _module.__class__.__name__ in ancestor_class:
|
244 |
+
|
245 |
+
for name, _child_module in _module.named_modules():
|
246 |
+
if _child_module.__class__ in search_class:
|
247 |
+
ret.append((_module, name, _child_module))
|
248 |
+
print(ret)
|
249 |
+
return ret
|
250 |
+
|
251 |
+
|
252 |
+
_find_modules = _find_modules_v2
|
253 |
+
|
254 |
+
|
255 |
+
def inject_trainable_lora(
|
256 |
+
model: nn.Module,
|
257 |
+
target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
|
258 |
+
r: int = 4,
|
259 |
+
loras=None, # path to lora .pt
|
260 |
+
verbose: bool = False,
|
261 |
+
dropout_p: float = 0.0,
|
262 |
+
scale: float = 1.0,
|
263 |
+
):
|
264 |
+
"""
|
265 |
+
inject lora into model, and returns lora parameter groups.
|
266 |
+
"""
|
267 |
+
|
268 |
+
require_grad_params = []
|
269 |
+
names = []
|
270 |
+
|
271 |
+
if loras != None:
|
272 |
+
loras = torch.load(loras)
|
273 |
+
|
274 |
+
for _module, name, _child_module in _find_modules(
|
275 |
+
model, target_replace_module, search_class=[nn.Linear]
|
276 |
+
):
|
277 |
+
weight = _child_module.weight
|
278 |
+
bias = _child_module.bias
|
279 |
+
if verbose:
|
280 |
+
print("LoRA Injection : injecting lora into ", name)
|
281 |
+
print("LoRA Injection : weight shape", weight.shape)
|
282 |
+
_tmp = LoraInjectedLinear(
|
283 |
+
_child_module.in_features,
|
284 |
+
_child_module.out_features,
|
285 |
+
_child_module.bias is not None,
|
286 |
+
r=r,
|
287 |
+
dropout_p=dropout_p,
|
288 |
+
scale=scale,
|
289 |
+
)
|
290 |
+
_tmp.linear.weight = weight
|
291 |
+
if bias is not None:
|
292 |
+
_tmp.linear.bias = bias
|
293 |
+
|
294 |
+
# switch the module
|
295 |
+
_tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
|
296 |
+
_module._modules[name] = _tmp
|
297 |
+
|
298 |
+
require_grad_params.append(_module._modules[name].lora_up.parameters())
|
299 |
+
require_grad_params.append(_module._modules[name].lora_down.parameters())
|
300 |
+
|
301 |
+
if loras != None:
|
302 |
+
_module._modules[name].lora_up.weight = loras.pop(0)
|
303 |
+
_module._modules[name].lora_down.weight = loras.pop(0)
|
304 |
+
|
305 |
+
_module._modules[name].lora_up.weight.requires_grad = True
|
306 |
+
_module._modules[name].lora_down.weight.requires_grad = True
|
307 |
+
names.append(name)
|
308 |
+
|
309 |
+
return require_grad_params, names
|
310 |
+
|
311 |
+
|
312 |
+
def inject_trainable_lora_extended(
|
313 |
+
model: nn.Module,
|
314 |
+
target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE,
|
315 |
+
r: int = 4,
|
316 |
+
loras=None, # path to lora .pt
|
317 |
+
):
|
318 |
+
"""
|
319 |
+
inject lora into model, and returns lora parameter groups.
|
320 |
+
"""
|
321 |
+
|
322 |
+
require_grad_params = []
|
323 |
+
names = []
|
324 |
+
|
325 |
+
if loras != None:
|
326 |
+
loras = torch.load(loras)
|
327 |
+
|
328 |
+
for _module, name, _child_module in _find_modules(
|
329 |
+
model, target_replace_module, search_class=[nn.Linear, nn.Conv2d]
|
330 |
+
):
|
331 |
+
if _child_module.__class__ == nn.Linear:
|
332 |
+
weight = _child_module.weight
|
333 |
+
bias = _child_module.bias
|
334 |
+
_tmp = LoraInjectedLinear(
|
335 |
+
_child_module.in_features,
|
336 |
+
_child_module.out_features,
|
337 |
+
_child_module.bias is not None,
|
338 |
+
r=r,
|
339 |
+
)
|
340 |
+
_tmp.linear.weight = weight
|
341 |
+
if bias is not None:
|
342 |
+
_tmp.linear.bias = bias
|
343 |
+
elif _child_module.__class__ == nn.Conv2d:
|
344 |
+
weight = _child_module.weight
|
345 |
+
bias = _child_module.bias
|
346 |
+
_tmp = LoraInjectedConv2d(
|
347 |
+
_child_module.in_channels,
|
348 |
+
_child_module.out_channels,
|
349 |
+
_child_module.kernel_size,
|
350 |
+
_child_module.stride,
|
351 |
+
_child_module.padding,
|
352 |
+
_child_module.dilation,
|
353 |
+
_child_module.groups,
|
354 |
+
_child_module.bias is not None,
|
355 |
+
r=r,
|
356 |
+
)
|
357 |
+
|
358 |
+
_tmp.conv.weight = weight
|
359 |
+
if bias is not None:
|
360 |
+
_tmp.conv.bias = bias
|
361 |
+
|
362 |
+
# switch the module
|
363 |
+
_tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
|
364 |
+
if bias is not None:
|
365 |
+
_tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
|
366 |
+
|
367 |
+
_module._modules[name] = _tmp
|
368 |
+
|
369 |
+
require_grad_params.append(_module._modules[name].lora_up.parameters())
|
370 |
+
require_grad_params.append(_module._modules[name].lora_down.parameters())
|
371 |
+
|
372 |
+
if loras != None:
|
373 |
+
_module._modules[name].lora_up.weight = loras.pop(0)
|
374 |
+
_module._modules[name].lora_down.weight = loras.pop(0)
|
375 |
+
|
376 |
+
_module._modules[name].lora_up.weight.requires_grad = True
|
377 |
+
_module._modules[name].lora_down.weight.requires_grad = True
|
378 |
+
names.append(name)
|
379 |
+
|
380 |
+
return require_grad_params, names
|
381 |
+
|
382 |
+
|
383 |
+
def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
|
384 |
+
|
385 |
+
loras = []
|
386 |
+
|
387 |
+
for _m, _n, _child_module in _find_modules(
|
388 |
+
model,
|
389 |
+
target_replace_module,
|
390 |
+
search_class=[LoraInjectedLinear, LoraInjectedConv2d],
|
391 |
+
):
|
392 |
+
loras.append((_child_module.lora_up, _child_module.lora_down))
|
393 |
+
|
394 |
+
if len(loras) == 0:
|
395 |
+
raise ValueError("No lora injected.")
|
396 |
+
|
397 |
+
return loras
|
398 |
+
|
399 |
+
|
400 |
+
def extract_lora_as_tensor(
|
401 |
+
model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True
|
402 |
+
):
|
403 |
+
|
404 |
+
loras = []
|
405 |
+
|
406 |
+
for _m, _n, _child_module in _find_modules(
|
407 |
+
model,
|
408 |
+
target_replace_module,
|
409 |
+
search_class=[LoraInjectedLinear, LoraInjectedConv2d],
|
410 |
+
):
|
411 |
+
up, down = _child_module.realize_as_lora()
|
412 |
+
if as_fp16:
|
413 |
+
up = up.to(torch.float16)
|
414 |
+
down = down.to(torch.float16)
|
415 |
+
|
416 |
+
loras.append((up, down))
|
417 |
+
|
418 |
+
if len(loras) == 0:
|
419 |
+
raise ValueError("No lora injected.")
|
420 |
+
|
421 |
+
return loras
|
422 |
+
|
423 |
+
|
424 |
+
def save_lora_weight(
|
425 |
+
model,
|
426 |
+
path="./lora.pt",
|
427 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
428 |
+
):
|
429 |
+
weights = []
|
430 |
+
for _up, _down in extract_lora_ups_down(
|
431 |
+
model, target_replace_module=target_replace_module
|
432 |
+
):
|
433 |
+
weights.append(_up.weight.to("cpu").to(torch.float16))
|
434 |
+
weights.append(_down.weight.to("cpu").to(torch.float16))
|
435 |
+
|
436 |
+
torch.save(weights, path)
|
437 |
+
|
438 |
+
|
439 |
+
def save_lora_as_json(model, path="./lora.json"):
|
440 |
+
weights = []
|
441 |
+
for _up, _down in extract_lora_ups_down(model):
|
442 |
+
weights.append(_up.weight.detach().cpu().numpy().tolist())
|
443 |
+
weights.append(_down.weight.detach().cpu().numpy().tolist())
|
444 |
+
|
445 |
+
import json
|
446 |
+
|
447 |
+
with open(path, "w") as f:
|
448 |
+
json.dump(weights, f)
|
449 |
+
|
450 |
+
|
451 |
+
def save_safeloras_with_embeds(
|
452 |
+
modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
|
453 |
+
embeds: Dict[str, torch.Tensor] = {},
|
454 |
+
outpath="./lora.safetensors",
|
455 |
+
):
|
456 |
+
"""
|
457 |
+
Saves the Lora from multiple modules in a single safetensor file.
|
458 |
+
|
459 |
+
modelmap is a dictionary of {
|
460 |
+
"module name": (module, target_replace_module)
|
461 |
+
}
|
462 |
+
"""
|
463 |
+
weights = {}
|
464 |
+
metadata = {}
|
465 |
+
|
466 |
+
for name, (model, target_replace_module) in modelmap.items():
|
467 |
+
metadata[name] = json.dumps(list(target_replace_module))
|
468 |
+
|
469 |
+
for i, (_up, _down) in enumerate(
|
470 |
+
extract_lora_as_tensor(model, target_replace_module)
|
471 |
+
):
|
472 |
+
rank = _down.shape[0]
|
473 |
+
|
474 |
+
metadata[f"{name}:{i}:rank"] = str(rank)
|
475 |
+
weights[f"{name}:{i}:up"] = _up
|
476 |
+
weights[f"{name}:{i}:down"] = _down
|
477 |
+
|
478 |
+
for token, tensor in embeds.items():
|
479 |
+
metadata[token] = EMBED_FLAG
|
480 |
+
weights[token] = tensor
|
481 |
+
|
482 |
+
print(f"Saving weights to {outpath}")
|
483 |
+
safe_save(weights, outpath, metadata)
|
484 |
+
|
485 |
+
|
486 |
+
def save_safeloras(
|
487 |
+
modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
|
488 |
+
outpath="./lora.safetensors",
|
489 |
+
):
|
490 |
+
return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
|
491 |
+
|
492 |
+
|
493 |
+
def convert_loras_to_safeloras_with_embeds(
|
494 |
+
modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
|
495 |
+
embeds: Dict[str, torch.Tensor] = {},
|
496 |
+
outpath="./lora.safetensors",
|
497 |
+
):
|
498 |
+
"""
|
499 |
+
Converts the Lora from multiple pytorch .pt files into a single safetensor file.
|
500 |
+
|
501 |
+
modelmap is a dictionary of {
|
502 |
+
"module name": (pytorch_model_path, target_replace_module, rank)
|
503 |
+
}
|
504 |
+
"""
|
505 |
+
|
506 |
+
weights = {}
|
507 |
+
metadata = {}
|
508 |
+
|
509 |
+
for name, (path, target_replace_module, r) in modelmap.items():
|
510 |
+
metadata[name] = json.dumps(list(target_replace_module))
|
511 |
+
|
512 |
+
lora = torch.load(path)
|
513 |
+
for i, weight in enumerate(lora):
|
514 |
+
is_up = i % 2 == 0
|
515 |
+
i = i // 2
|
516 |
+
|
517 |
+
if is_up:
|
518 |
+
metadata[f"{name}:{i}:rank"] = str(r)
|
519 |
+
weights[f"{name}:{i}:up"] = weight
|
520 |
+
else:
|
521 |
+
weights[f"{name}:{i}:down"] = weight
|
522 |
+
|
523 |
+
for token, tensor in embeds.items():
|
524 |
+
metadata[token] = EMBED_FLAG
|
525 |
+
weights[token] = tensor
|
526 |
+
|
527 |
+
print(f"Saving weights to {outpath}")
|
528 |
+
safe_save(weights, outpath, metadata)
|
529 |
+
|
530 |
+
|
531 |
+
def convert_loras_to_safeloras(
|
532 |
+
modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
|
533 |
+
outpath="./lora.safetensors",
|
534 |
+
):
|
535 |
+
convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
|
536 |
+
|
537 |
+
|
538 |
+
def parse_safeloras(
|
539 |
+
safeloras,
|
540 |
+
) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
|
541 |
+
"""
|
542 |
+
Converts a loaded safetensor file that contains a set of module Loras
|
543 |
+
into Parameters and other information
|
544 |
+
|
545 |
+
Output is a dictionary of {
|
546 |
+
"module name": (
|
547 |
+
[list of weights],
|
548 |
+
[list of ranks],
|
549 |
+
target_replacement_modules
|
550 |
+
)
|
551 |
+
}
|
552 |
+
"""
|
553 |
+
loras = {}
|
554 |
+
metadata = safeloras.metadata()
|
555 |
+
|
556 |
+
get_name = lambda k: k.split(":")[0]
|
557 |
+
|
558 |
+
keys = list(safeloras.keys())
|
559 |
+
keys.sort(key=get_name)
|
560 |
+
|
561 |
+
for name, module_keys in groupby(keys, get_name):
|
562 |
+
info = metadata.get(name)
|
563 |
+
|
564 |
+
if not info:
|
565 |
+
raise ValueError(
|
566 |
+
f"Tensor {name} has no metadata - is this a Lora safetensor?"
|
567 |
+
)
|
568 |
+
|
569 |
+
# Skip Textual Inversion embeds
|
570 |
+
if info == EMBED_FLAG:
|
571 |
+
continue
|
572 |
+
|
573 |
+
# Handle Loras
|
574 |
+
# Extract the targets
|
575 |
+
target = json.loads(info)
|
576 |
+
|
577 |
+
# Build the result lists - Python needs us to preallocate lists to insert into them
|
578 |
+
module_keys = list(module_keys)
|
579 |
+
ranks = [4] * (len(module_keys) // 2)
|
580 |
+
weights = [None] * len(module_keys)
|
581 |
+
|
582 |
+
for key in module_keys:
|
583 |
+
# Split the model name and index out of the key
|
584 |
+
_, idx, direction = key.split(":")
|
585 |
+
idx = int(idx)
|
586 |
+
|
587 |
+
# Add the rank
|
588 |
+
ranks[idx] = int(metadata[f"{name}:{idx}:rank"])
|
589 |
+
|
590 |
+
# Insert the weight into the list
|
591 |
+
idx = idx * 2 + (1 if direction == "down" else 0)
|
592 |
+
weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key))
|
593 |
+
|
594 |
+
loras[name] = (weights, ranks, target)
|
595 |
+
|
596 |
+
return loras
|
597 |
+
|
598 |
+
|
599 |
+
def parse_safeloras_embeds(
|
600 |
+
safeloras,
|
601 |
+
) -> Dict[str, torch.Tensor]:
|
602 |
+
"""
|
603 |
+
Converts a loaded safetensor file that contains Textual Inversion embeds into
|
604 |
+
a dictionary of embed_token: Tensor
|
605 |
+
"""
|
606 |
+
embeds = {}
|
607 |
+
metadata = safeloras.metadata()
|
608 |
+
|
609 |
+
for key in safeloras.keys():
|
610 |
+
# Only handle Textual Inversion embeds
|
611 |
+
meta = metadata.get(key)
|
612 |
+
if not meta or meta != EMBED_FLAG:
|
613 |
+
continue
|
614 |
+
|
615 |
+
embeds[key] = safeloras.get_tensor(key)
|
616 |
+
|
617 |
+
return embeds
|
618 |
+
|
619 |
+
|
620 |
+
def load_safeloras(path, device="cpu"):
|
621 |
+
safeloras = safe_open(path, framework="pt", device=device)
|
622 |
+
return parse_safeloras(safeloras)
|
623 |
+
|
624 |
+
|
625 |
+
def load_safeloras_embeds(path, device="cpu"):
|
626 |
+
safeloras = safe_open(path, framework="pt", device=device)
|
627 |
+
return parse_safeloras_embeds(safeloras)
|
628 |
+
|
629 |
+
|
630 |
+
def load_safeloras_both(path, device="cpu"):
|
631 |
+
safeloras = safe_open(path, framework="pt", device=device)
|
632 |
+
return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras)
|
633 |
+
|
634 |
+
|
635 |
+
def collapse_lora(model, alpha=1.0):
|
636 |
+
|
637 |
+
for _module, name, _child_module in _find_modules(
|
638 |
+
model,
|
639 |
+
UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE,
|
640 |
+
search_class=[LoraInjectedLinear, LoraInjectedConv2d],
|
641 |
+
):
|
642 |
+
|
643 |
+
if isinstance(_child_module, LoraInjectedLinear):
|
644 |
+
print("Collapsing Lin Lora in", name)
|
645 |
+
|
646 |
+
_child_module.linear.weight = nn.Parameter(
|
647 |
+
_child_module.linear.weight.data
|
648 |
+
+ alpha
|
649 |
+
* (
|
650 |
+
_child_module.lora_up.weight.data
|
651 |
+
@ _child_module.lora_down.weight.data
|
652 |
+
)
|
653 |
+
.type(_child_module.linear.weight.dtype)
|
654 |
+
.to(_child_module.linear.weight.device)
|
655 |
+
)
|
656 |
+
|
657 |
+
else:
|
658 |
+
print("Collapsing Conv Lora in", name)
|
659 |
+
_child_module.conv.weight = nn.Parameter(
|
660 |
+
_child_module.conv.weight.data
|
661 |
+
+ alpha
|
662 |
+
* (
|
663 |
+
_child_module.lora_up.weight.data.flatten(start_dim=1)
|
664 |
+
@ _child_module.lora_down.weight.data.flatten(start_dim=1)
|
665 |
+
)
|
666 |
+
.reshape(_child_module.conv.weight.data.shape)
|
667 |
+
.type(_child_module.conv.weight.dtype)
|
668 |
+
.to(_child_module.conv.weight.device)
|
669 |
+
)
|
670 |
+
|
671 |
+
|
672 |
+
def monkeypatch_or_replace_lora(
|
673 |
+
model,
|
674 |
+
loras,
|
675 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
676 |
+
r: Union[int, List[int]] = 4,
|
677 |
+
):
|
678 |
+
for _module, name, _child_module in _find_modules(
|
679 |
+
model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
|
680 |
+
):
|
681 |
+
_source = (
|
682 |
+
_child_module.linear
|
683 |
+
if isinstance(_child_module, LoraInjectedLinear)
|
684 |
+
else _child_module
|
685 |
+
)
|
686 |
+
|
687 |
+
weight = _source.weight
|
688 |
+
bias = _source.bias
|
689 |
+
_tmp = LoraInjectedLinear(
|
690 |
+
_source.in_features,
|
691 |
+
_source.out_features,
|
692 |
+
_source.bias is not None,
|
693 |
+
r=r.pop(0) if isinstance(r, list) else r,
|
694 |
+
)
|
695 |
+
_tmp.linear.weight = weight
|
696 |
+
|
697 |
+
if bias is not None:
|
698 |
+
_tmp.linear.bias = bias
|
699 |
+
|
700 |
+
# switch the module
|
701 |
+
_module._modules[name] = _tmp
|
702 |
+
|
703 |
+
up_weight = loras.pop(0)
|
704 |
+
down_weight = loras.pop(0)
|
705 |
+
|
706 |
+
_module._modules[name].lora_up.weight = nn.Parameter(
|
707 |
+
up_weight.type(weight.dtype)
|
708 |
+
)
|
709 |
+
_module._modules[name].lora_down.weight = nn.Parameter(
|
710 |
+
down_weight.type(weight.dtype)
|
711 |
+
)
|
712 |
+
|
713 |
+
_module._modules[name].to(weight.device)
|
714 |
+
|
715 |
+
|
716 |
+
def monkeypatch_or_replace_lora_extended(
|
717 |
+
model,
|
718 |
+
loras,
|
719 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
720 |
+
r: Union[int, List[int]] = 4,
|
721 |
+
):
|
722 |
+
for _module, name, _child_module in _find_modules(
|
723 |
+
model,
|
724 |
+
target_replace_module,
|
725 |
+
search_class=[nn.Linear, LoraInjectedLinear, nn.Conv2d, LoraInjectedConv2d],
|
726 |
+
):
|
727 |
+
|
728 |
+
if (_child_module.__class__ == nn.Linear) or (
|
729 |
+
_child_module.__class__ == LoraInjectedLinear
|
730 |
+
):
|
731 |
+
if len(loras[0].shape) != 2:
|
732 |
+
continue
|
733 |
+
|
734 |
+
_source = (
|
735 |
+
_child_module.linear
|
736 |
+
if isinstance(_child_module, LoraInjectedLinear)
|
737 |
+
else _child_module
|
738 |
+
)
|
739 |
+
|
740 |
+
weight = _source.weight
|
741 |
+
bias = _source.bias
|
742 |
+
_tmp = LoraInjectedLinear(
|
743 |
+
_source.in_features,
|
744 |
+
_source.out_features,
|
745 |
+
_source.bias is not None,
|
746 |
+
r=r.pop(0) if isinstance(r, list) else r,
|
747 |
+
)
|
748 |
+
_tmp.linear.weight = weight
|
749 |
+
|
750 |
+
if bias is not None:
|
751 |
+
_tmp.linear.bias = bias
|
752 |
+
|
753 |
+
elif (_child_module.__class__ == nn.Conv2d) or (
|
754 |
+
_child_module.__class__ == LoraInjectedConv2d
|
755 |
+
):
|
756 |
+
if len(loras[0].shape) != 4:
|
757 |
+
continue
|
758 |
+
_source = (
|
759 |
+
_child_module.conv
|
760 |
+
if isinstance(_child_module, LoraInjectedConv2d)
|
761 |
+
else _child_module
|
762 |
+
)
|
763 |
+
|
764 |
+
weight = _source.weight
|
765 |
+
bias = _source.bias
|
766 |
+
_tmp = LoraInjectedConv2d(
|
767 |
+
_source.in_channels,
|
768 |
+
_source.out_channels,
|
769 |
+
_source.kernel_size,
|
770 |
+
_source.stride,
|
771 |
+
_source.padding,
|
772 |
+
_source.dilation,
|
773 |
+
_source.groups,
|
774 |
+
_source.bias is not None,
|
775 |
+
r=r.pop(0) if isinstance(r, list) else r,
|
776 |
+
)
|
777 |
+
|
778 |
+
_tmp.conv.weight = weight
|
779 |
+
|
780 |
+
if bias is not None:
|
781 |
+
_tmp.conv.bias = bias
|
782 |
+
|
783 |
+
# switch the module
|
784 |
+
_module._modules[name] = _tmp
|
785 |
+
|
786 |
+
up_weight = loras.pop(0)
|
787 |
+
down_weight = loras.pop(0)
|
788 |
+
|
789 |
+
_module._modules[name].lora_up.weight = nn.Parameter(
|
790 |
+
up_weight.type(weight.dtype)
|
791 |
+
)
|
792 |
+
_module._modules[name].lora_down.weight = nn.Parameter(
|
793 |
+
down_weight.type(weight.dtype)
|
794 |
+
)
|
795 |
+
|
796 |
+
_module._modules[name].to(weight.device)
|
797 |
+
|
798 |
+
|
799 |
+
def monkeypatch_or_replace_safeloras(models, safeloras):
|
800 |
+
loras = parse_safeloras(safeloras)
|
801 |
+
|
802 |
+
for name, (lora, ranks, target) in loras.items():
|
803 |
+
model = getattr(models, name, None)
|
804 |
+
|
805 |
+
if not model:
|
806 |
+
print(f"No model provided for {name}, contained in Lora")
|
807 |
+
continue
|
808 |
+
|
809 |
+
monkeypatch_or_replace_lora_extended(model, lora, target, ranks)
|
810 |
+
|
811 |
+
|
812 |
+
def monkeypatch_remove_lora(model):
|
813 |
+
for _module, name, _child_module in _find_modules(
|
814 |
+
model, search_class=[LoraInjectedLinear, LoraInjectedConv2d]
|
815 |
+
):
|
816 |
+
if isinstance(_child_module, LoraInjectedLinear):
|
817 |
+
_source = _child_module.linear
|
818 |
+
weight, bias = _source.weight, _source.bias
|
819 |
+
|
820 |
+
_tmp = nn.Linear(
|
821 |
+
_source.in_features, _source.out_features, bias is not None
|
822 |
+
)
|
823 |
+
|
824 |
+
_tmp.weight = weight
|
825 |
+
if bias is not None:
|
826 |
+
_tmp.bias = bias
|
827 |
+
|
828 |
+
else:
|
829 |
+
_source = _child_module.conv
|
830 |
+
weight, bias = _source.weight, _source.bias
|
831 |
+
|
832 |
+
_tmp = nn.Conv2d(
|
833 |
+
in_channels=_source.in_channels,
|
834 |
+
out_channels=_source.out_channels,
|
835 |
+
kernel_size=_source.kernel_size,
|
836 |
+
stride=_source.stride,
|
837 |
+
padding=_source.padding,
|
838 |
+
dilation=_source.dilation,
|
839 |
+
groups=_source.groups,
|
840 |
+
bias=bias is not None,
|
841 |
+
)
|
842 |
+
|
843 |
+
_tmp.weight = weight
|
844 |
+
if bias is not None:
|
845 |
+
_tmp.bias = bias
|
846 |
+
|
847 |
+
_module._modules[name] = _tmp
|
848 |
+
|
849 |
+
|
850 |
+
def monkeypatch_add_lora(
|
851 |
+
model,
|
852 |
+
loras,
|
853 |
+
target_replace_module=DEFAULT_TARGET_REPLACE,
|
854 |
+
alpha: float = 1.0,
|
855 |
+
beta: float = 1.0,
|
856 |
+
):
|
857 |
+
for _module, name, _child_module in _find_modules(
|
858 |
+
model, target_replace_module, search_class=[LoraInjectedLinear]
|
859 |
+
):
|
860 |
+
weight = _child_module.linear.weight
|
861 |
+
|
862 |
+
up_weight = loras.pop(0)
|
863 |
+
down_weight = loras.pop(0)
|
864 |
+
|
865 |
+
_module._modules[name].lora_up.weight = nn.Parameter(
|
866 |
+
up_weight.type(weight.dtype).to(weight.device) * alpha
|
867 |
+
+ _module._modules[name].lora_up.weight.to(weight.device) * beta
|
868 |
+
)
|
869 |
+
_module._modules[name].lora_down.weight = nn.Parameter(
|
870 |
+
down_weight.type(weight.dtype).to(weight.device) * alpha
|
871 |
+
+ _module._modules[name].lora_down.weight.to(weight.device) * beta
|
872 |
+
)
|
873 |
+
|
874 |
+
_module._modules[name].to(weight.device)
|
875 |
+
|
876 |
+
|
877 |
+
def tune_lora_scale(model, alpha: float = 1.0):
|
878 |
+
for _module in model.modules():
|
879 |
+
if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
|
880 |
+
_module.scale = alpha
|
881 |
+
|
882 |
+
|
883 |
+
def set_lora_diag(model, diag: torch.Tensor):
|
884 |
+
for _module in model.modules():
|
885 |
+
if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
|
886 |
+
_module.set_selector_from_diag(diag)
|
887 |
+
|
888 |
+
|
889 |
+
def _text_lora_path(path: str) -> str:
|
890 |
+
assert path.endswith(".pt"), "Only .pt files are supported"
|
891 |
+
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
|
892 |
+
|
893 |
+
|
894 |
+
def _ti_lora_path(path: str) -> str:
|
895 |
+
assert path.endswith(".pt"), "Only .pt files are supported"
|
896 |
+
return ".".join(path.split(".")[:-1] + ["ti", "pt"])
|
897 |
+
|
898 |
+
|
899 |
+
def apply_learned_embed_in_clip(
|
900 |
+
learned_embeds,
|
901 |
+
text_encoder,
|
902 |
+
tokenizer,
|
903 |
+
token: Optional[Union[str, List[str]]] = None,
|
904 |
+
idempotent=False,
|
905 |
+
):
|
906 |
+
if isinstance(token, str):
|
907 |
+
trained_tokens = [token]
|
908 |
+
elif isinstance(token, list):
|
909 |
+
assert len(learned_embeds.keys()) == len(
|
910 |
+
token
|
911 |
+
), "The number of tokens and the number of embeds should be the same"
|
912 |
+
trained_tokens = token
|
913 |
+
else:
|
914 |
+
trained_tokens = list(learned_embeds.keys())
|
915 |
+
|
916 |
+
for token in trained_tokens:
|
917 |
+
print(token)
|
918 |
+
embeds = learned_embeds[token]
|
919 |
+
|
920 |
+
# cast to dtype of text_encoder
|
921 |
+
dtype = text_encoder.get_input_embeddings().weight.dtype
|
922 |
+
num_added_tokens = tokenizer.add_tokens(token)
|
923 |
+
|
924 |
+
i = 1
|
925 |
+
if not idempotent:
|
926 |
+
while num_added_tokens == 0:
|
927 |
+
print(f"The tokenizer already contains the token {token}.")
|
928 |
+
token = f"{token[:-1]}-{i}>"
|
929 |
+
print(f"Attempting to add the token {token}.")
|
930 |
+
num_added_tokens = tokenizer.add_tokens(token)
|
931 |
+
i += 1
|
932 |
+
elif num_added_tokens == 0 and idempotent:
|
933 |
+
print(f"The tokenizer already contains the token {token}.")
|
934 |
+
print(f"Replacing {token} embedding.")
|
935 |
+
|
936 |
+
# resize the token embeddings
|
937 |
+
text_encoder.resize_token_embeddings(len(tokenizer))
|
938 |
+
|
939 |
+
# get the id for the token and assign the embeds
|
940 |
+
token_id = tokenizer.convert_tokens_to_ids(token)
|
941 |
+
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
|
942 |
+
return token
|
943 |
+
|
944 |
+
|
945 |
+
def load_learned_embed_in_clip(
|
946 |
+
learned_embeds_path,
|
947 |
+
text_encoder,
|
948 |
+
tokenizer,
|
949 |
+
token: Optional[Union[str, List[str]]] = None,
|
950 |
+
idempotent=False,
|
951 |
+
):
|
952 |
+
learned_embeds = torch.load(learned_embeds_path)
|
953 |
+
apply_learned_embed_in_clip(
|
954 |
+
learned_embeds, text_encoder, tokenizer, token, idempotent
|
955 |
+
)
|
956 |
+
|
957 |
+
|
958 |
+
def patch_pipe(
|
959 |
+
pipe,
|
960 |
+
maybe_unet_path,
|
961 |
+
token: Optional[str] = None,
|
962 |
+
r: int = 4,
|
963 |
+
patch_unet=True,
|
964 |
+
patch_text=True,
|
965 |
+
patch_ti=True,
|
966 |
+
idempotent_token=True,
|
967 |
+
unet_target_replace_module=DEFAULT_TARGET_REPLACE,
|
968 |
+
text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
|
969 |
+
):
|
970 |
+
if maybe_unet_path.endswith(".pt"):
|
971 |
+
# torch format
|
972 |
+
|
973 |
+
if maybe_unet_path.endswith(".ti.pt"):
|
974 |
+
unet_path = maybe_unet_path[:-6] + ".pt"
|
975 |
+
elif maybe_unet_path.endswith(".text_encoder.pt"):
|
976 |
+
unet_path = maybe_unet_path[:-16] + ".pt"
|
977 |
+
else:
|
978 |
+
unet_path = maybe_unet_path
|
979 |
+
|
980 |
+
ti_path = _ti_lora_path(unet_path)
|
981 |
+
text_path = _text_lora_path(unet_path)
|
982 |
+
|
983 |
+
if patch_unet:
|
984 |
+
print("LoRA : Patching Unet")
|
985 |
+
monkeypatch_or_replace_lora(
|
986 |
+
pipe.unet,
|
987 |
+
torch.load(unet_path),
|
988 |
+
r=r,
|
989 |
+
target_replace_module=unet_target_replace_module,
|
990 |
+
)
|
991 |
+
|
992 |
+
if patch_text:
|
993 |
+
print("LoRA : Patching text encoder")
|
994 |
+
monkeypatch_or_replace_lora(
|
995 |
+
pipe.text_encoder,
|
996 |
+
torch.load(text_path),
|
997 |
+
target_replace_module=text_target_replace_module,
|
998 |
+
r=r,
|
999 |
+
)
|
1000 |
+
if patch_ti:
|
1001 |
+
print("LoRA : Patching token input")
|
1002 |
+
token = load_learned_embed_in_clip(
|
1003 |
+
ti_path,
|
1004 |
+
pipe.text_encoder,
|
1005 |
+
pipe.tokenizer,
|
1006 |
+
token=token,
|
1007 |
+
idempotent=idempotent_token,
|
1008 |
+
)
|
1009 |
+
|
1010 |
+
elif maybe_unet_path.endswith(".safetensors"):
|
1011 |
+
safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
|
1012 |
+
monkeypatch_or_replace_safeloras(pipe, safeloras)
|
1013 |
+
tok_dict = parse_safeloras_embeds(safeloras)
|
1014 |
+
if patch_ti:
|
1015 |
+
apply_learned_embed_in_clip(
|
1016 |
+
tok_dict,
|
1017 |
+
pipe.text_encoder,
|
1018 |
+
pipe.tokenizer,
|
1019 |
+
token=token,
|
1020 |
+
idempotent=idempotent_token,
|
1021 |
+
)
|
1022 |
+
return tok_dict
|
1023 |
+
|
1024 |
+
|
1025 |
+
@torch.no_grad()
|
1026 |
+
def inspect_lora(model):
|
1027 |
+
moved = {}
|
1028 |
+
|
1029 |
+
for name, _module in model.named_modules():
|
1030 |
+
if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
|
1031 |
+
ups = _module.lora_up.weight.data.clone()
|
1032 |
+
downs = _module.lora_down.weight.data.clone()
|
1033 |
+
|
1034 |
+
wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1)
|
1035 |
+
|
1036 |
+
dist = wght.flatten().abs().mean().item()
|
1037 |
+
if name in moved:
|
1038 |
+
moved[name].append(dist)
|
1039 |
+
else:
|
1040 |
+
moved[name] = [dist]
|
1041 |
+
|
1042 |
+
return moved
|
1043 |
+
|
1044 |
+
|
1045 |
+
def save_all(
|
1046 |
+
unet,
|
1047 |
+
text_encoder,
|
1048 |
+
save_path,
|
1049 |
+
placeholder_token_ids=None,
|
1050 |
+
placeholder_tokens=None,
|
1051 |
+
save_lora=True,
|
1052 |
+
save_ti=True,
|
1053 |
+
target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
|
1054 |
+
target_replace_module_unet=DEFAULT_TARGET_REPLACE,
|
1055 |
+
safe_form=True,
|
1056 |
+
):
|
1057 |
+
if not safe_form:
|
1058 |
+
# save ti
|
1059 |
+
if save_ti:
|
1060 |
+
ti_path = _ti_lora_path(save_path)
|
1061 |
+
learned_embeds_dict = {}
|
1062 |
+
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
|
1063 |
+
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
|
1064 |
+
print(
|
1065 |
+
f"Current Learned Embeddings for {tok}:, id {tok_id} ",
|
1066 |
+
learned_embeds[:4],
|
1067 |
+
)
|
1068 |
+
learned_embeds_dict[tok] = learned_embeds.detach().cpu()
|
1069 |
+
|
1070 |
+
torch.save(learned_embeds_dict, ti_path)
|
1071 |
+
print("Ti saved to ", ti_path)
|
1072 |
+
|
1073 |
+
# save text encoder
|
1074 |
+
if save_lora:
|
1075 |
+
|
1076 |
+
save_lora_weight(
|
1077 |
+
unet, save_path, target_replace_module=target_replace_module_unet
|
1078 |
+
)
|
1079 |
+
print("Unet saved to ", save_path)
|
1080 |
+
|
1081 |
+
save_lora_weight(
|
1082 |
+
text_encoder,
|
1083 |
+
_text_lora_path(save_path),
|
1084 |
+
target_replace_module=target_replace_module_text,
|
1085 |
+
)
|
1086 |
+
print("Text Encoder saved to ", _text_lora_path(save_path))
|
1087 |
+
|
1088 |
+
else:
|
1089 |
+
assert save_path.endswith(
|
1090 |
+
".safetensors"
|
1091 |
+
), f"Save path : {save_path} should end with .safetensors"
|
1092 |
+
|
1093 |
+
loras = {}
|
1094 |
+
embeds = {}
|
1095 |
+
|
1096 |
+
if save_lora:
|
1097 |
+
|
1098 |
+
loras["unet"] = (unet, target_replace_module_unet)
|
1099 |
+
loras["text_encoder"] = (text_encoder, target_replace_module_text)
|
1100 |
+
|
1101 |
+
if save_ti:
|
1102 |
+
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
|
1103 |
+
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
|
1104 |
+
print(
|
1105 |
+
f"Current Learned Embeddings for {tok}:, id {tok_id} ",
|
1106 |
+
learned_embeds[:4],
|
1107 |
+
)
|
1108 |
+
embeds[tok] = learned_embeds.detach().cpu()
|
1109 |
+
|
1110 |
+
save_safeloras_with_embeds(loras, embeds, save_path)
|
lora_diffusion/lora_manager.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
import torch
|
3 |
+
from safetensors import safe_open
|
4 |
+
from diffusers import StableDiffusionPipeline
|
5 |
+
from .lora import (
|
6 |
+
monkeypatch_or_replace_safeloras,
|
7 |
+
apply_learned_embed_in_clip,
|
8 |
+
set_lora_diag,
|
9 |
+
parse_safeloras_embeds,
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
def lora_join(lora_safetenors: list):
|
14 |
+
metadatas = [dict(safelora.metadata()) for safelora in lora_safetenors]
|
15 |
+
_total_metadata = {}
|
16 |
+
total_metadata = {}
|
17 |
+
total_tensor = {}
|
18 |
+
total_rank = 0
|
19 |
+
ranklist = []
|
20 |
+
for _metadata in metadatas:
|
21 |
+
rankset = []
|
22 |
+
for k, v in _metadata.items():
|
23 |
+
if k.endswith("rank"):
|
24 |
+
rankset.append(int(v))
|
25 |
+
|
26 |
+
assert len(set(rankset)) <= 1, "Rank should be the same per model"
|
27 |
+
if len(rankset) == 0:
|
28 |
+
rankset = [0]
|
29 |
+
|
30 |
+
total_rank += rankset[0]
|
31 |
+
_total_metadata.update(_metadata)
|
32 |
+
ranklist.append(rankset[0])
|
33 |
+
|
34 |
+
# remove metadata about tokens
|
35 |
+
for k, v in _total_metadata.items():
|
36 |
+
if v != "<embed>":
|
37 |
+
total_metadata[k] = v
|
38 |
+
|
39 |
+
tensorkeys = set()
|
40 |
+
for safelora in lora_safetenors:
|
41 |
+
tensorkeys.update(safelora.keys())
|
42 |
+
|
43 |
+
for keys in tensorkeys:
|
44 |
+
if keys.startswith("text_encoder") or keys.startswith("unet"):
|
45 |
+
tensorset = [safelora.get_tensor(keys) for safelora in lora_safetenors]
|
46 |
+
|
47 |
+
is_down = keys.endswith("down")
|
48 |
+
|
49 |
+
if is_down:
|
50 |
+
_tensor = torch.cat(tensorset, dim=0)
|
51 |
+
assert _tensor.shape[0] == total_rank
|
52 |
+
else:
|
53 |
+
_tensor = torch.cat(tensorset, dim=1)
|
54 |
+
assert _tensor.shape[1] == total_rank
|
55 |
+
|
56 |
+
total_tensor[keys] = _tensor
|
57 |
+
keys_rank = ":".join(keys.split(":")[:-1]) + ":rank"
|
58 |
+
total_metadata[keys_rank] = str(total_rank)
|
59 |
+
token_size_list = []
|
60 |
+
for idx, safelora in enumerate(lora_safetenors):
|
61 |
+
tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
|
62 |
+
for jdx, token in enumerate(sorted(tokens)):
|
63 |
+
|
64 |
+
total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
|
65 |
+
total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"
|
66 |
+
|
67 |
+
print(f"Embedding {token} replaced to <s{idx}-{jdx}>")
|
68 |
+
|
69 |
+
token_size_list.append(len(tokens))
|
70 |
+
|
71 |
+
return total_tensor, total_metadata, ranklist, token_size_list
|
72 |
+
|
73 |
+
|
74 |
+
class DummySafeTensorObject:
|
75 |
+
def __init__(self, tensor: dict, metadata):
|
76 |
+
self.tensor = tensor
|
77 |
+
self._metadata = metadata
|
78 |
+
|
79 |
+
def keys(self):
|
80 |
+
return self.tensor.keys()
|
81 |
+
|
82 |
+
def metadata(self):
|
83 |
+
return self._metadata
|
84 |
+
|
85 |
+
def get_tensor(self, key):
|
86 |
+
return self.tensor[key]
|
87 |
+
|
88 |
+
|
89 |
+
class LoRAManager:
|
90 |
+
def __init__(self, lora_paths_list: List[str], pipe: StableDiffusionPipeline):
|
91 |
+
|
92 |
+
self.lora_paths_list = lora_paths_list
|
93 |
+
self.pipe = pipe
|
94 |
+
self._setup()
|
95 |
+
|
96 |
+
def _setup(self):
|
97 |
+
|
98 |
+
self._lora_safetenors = [
|
99 |
+
safe_open(path, framework="pt", device="cpu")
|
100 |
+
for path in self.lora_paths_list
|
101 |
+
]
|
102 |
+
|
103 |
+
(
|
104 |
+
total_tensor,
|
105 |
+
total_metadata,
|
106 |
+
self.ranklist,
|
107 |
+
self.token_size_list,
|
108 |
+
) = lora_join(self._lora_safetenors)
|
109 |
+
|
110 |
+
self.total_safelora = DummySafeTensorObject(total_tensor, total_metadata)
|
111 |
+
|
112 |
+
monkeypatch_or_replace_safeloras(self.pipe, self.total_safelora)
|
113 |
+
tok_dict = parse_safeloras_embeds(self.total_safelora)
|
114 |
+
|
115 |
+
apply_learned_embed_in_clip(
|
116 |
+
tok_dict,
|
117 |
+
self.pipe.text_encoder,
|
118 |
+
self.pipe.tokenizer,
|
119 |
+
token=None,
|
120 |
+
idempotent=True,
|
121 |
+
)
|
122 |
+
|
123 |
+
def tune(self, scales):
|
124 |
+
|
125 |
+
assert len(scales) == len(
|
126 |
+
self.ranklist
|
127 |
+
), "Scale list should be the same length as ranklist"
|
128 |
+
|
129 |
+
diags = []
|
130 |
+
for scale, rank in zip(scales, self.ranklist):
|
131 |
+
diags = diags + [scale] * rank
|
132 |
+
|
133 |
+
set_lora_diag(self.pipe.unet, torch.tensor(diags))
|
134 |
+
|
135 |
+
def prompt(self, prompt):
|
136 |
+
if prompt is not None:
|
137 |
+
for idx, tok_size in enumerate(self.token_size_list):
|
138 |
+
prompt = prompt.replace(
|
139 |
+
f"<{idx + 1}>",
|
140 |
+
"".join([f"<s{idx}-{jdx}>" for jdx in range(tok_size)]),
|
141 |
+
)
|
142 |
+
# TODO : Rescale LoRA + Text inputs based on prompt scale params
|
143 |
+
|
144 |
+
return prompt
|
lora_diffusion/preprocess_files.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Have SwinIR upsample
|
2 |
+
# Have BLIP auto caption
|
3 |
+
# Have CLIPSeg auto mask concept
|
4 |
+
|
5 |
+
from typing import List, Literal, Union, Optional, Tuple
|
6 |
+
import os
|
7 |
+
from PIL import Image, ImageFilter
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
import fire
|
11 |
+
from tqdm import tqdm
|
12 |
+
import glob
|
13 |
+
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
|
14 |
+
|
15 |
+
|
16 |
+
@torch.no_grad()
|
17 |
+
def swin_ir_sr(
|
18 |
+
images: List[Image.Image],
|
19 |
+
model_id: Literal[
|
20 |
+
"caidas/swin2SR-classical-sr-x2-64", "caidas/swin2SR-classical-sr-x4-48"
|
21 |
+
] = "caidas/swin2SR-classical-sr-x2-64",
|
22 |
+
target_size: Optional[Tuple[int, int]] = None,
|
23 |
+
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
|
24 |
+
**kwargs,
|
25 |
+
) -> List[Image.Image]:
|
26 |
+
"""
|
27 |
+
Upscales images using SwinIR. Returns a list of PIL images.
|
28 |
+
"""
|
29 |
+
# So this is currently in main branch, so this can be used in the future I guess?
|
30 |
+
from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor
|
31 |
+
|
32 |
+
model = Swin2SRForImageSuperResolution.from_pretrained(
|
33 |
+
model_id,
|
34 |
+
).to(device)
|
35 |
+
processor = Swin2SRImageProcessor()
|
36 |
+
|
37 |
+
out_images = []
|
38 |
+
|
39 |
+
for image in tqdm(images):
|
40 |
+
|
41 |
+
ori_w, ori_h = image.size
|
42 |
+
if target_size is not None:
|
43 |
+
if ori_w >= target_size[0] and ori_h >= target_size[1]:
|
44 |
+
out_images.append(image)
|
45 |
+
continue
|
46 |
+
|
47 |
+
inputs = processor(image, return_tensors="pt").to(device)
|
48 |
+
with torch.no_grad():
|
49 |
+
outputs = model(**inputs)
|
50 |
+
|
51 |
+
output = (
|
52 |
+
outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
53 |
+
)
|
54 |
+
output = np.moveaxis(output, source=0, destination=-1)
|
55 |
+
output = (output * 255.0).round().astype(np.uint8)
|
56 |
+
output = Image.fromarray(output)
|
57 |
+
|
58 |
+
out_images.append(output)
|
59 |
+
|
60 |
+
return out_images
|
61 |
+
|
62 |
+
|
63 |
+
@torch.no_grad()
|
64 |
+
def clipseg_mask_generator(
|
65 |
+
images: List[Image.Image],
|
66 |
+
target_prompts: Union[List[str], str],
|
67 |
+
model_id: Literal[
|
68 |
+
"CIDAS/clipseg-rd64-refined", "CIDAS/clipseg-rd16"
|
69 |
+
] = "CIDAS/clipseg-rd64-refined",
|
70 |
+
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
|
71 |
+
bias: float = 0.01,
|
72 |
+
temp: float = 1.0,
|
73 |
+
**kwargs,
|
74 |
+
) -> List[Image.Image]:
|
75 |
+
"""
|
76 |
+
Returns a greyscale mask for each image, where the mask is the probability of the target prompt being present in the image
|
77 |
+
"""
|
78 |
+
|
79 |
+
if isinstance(target_prompts, str):
|
80 |
+
print(
|
81 |
+
f'Warning: only one target prompt "{target_prompts}" was given, so it will be used for all images'
|
82 |
+
)
|
83 |
+
|
84 |
+
target_prompts = [target_prompts] * len(images)
|
85 |
+
|
86 |
+
processor = CLIPSegProcessor.from_pretrained(model_id)
|
87 |
+
model = CLIPSegForImageSegmentation.from_pretrained(model_id).to(device)
|
88 |
+
|
89 |
+
masks = []
|
90 |
+
|
91 |
+
for image, prompt in tqdm(zip(images, target_prompts)):
|
92 |
+
|
93 |
+
original_size = image.size
|
94 |
+
|
95 |
+
inputs = processor(
|
96 |
+
text=[prompt, ""],
|
97 |
+
images=[image] * 2,
|
98 |
+
padding="max_length",
|
99 |
+
truncation=True,
|
100 |
+
return_tensors="pt",
|
101 |
+
).to(device)
|
102 |
+
|
103 |
+
outputs = model(**inputs)
|
104 |
+
|
105 |
+
logits = outputs.logits
|
106 |
+
probs = torch.nn.functional.softmax(logits / temp, dim=0)[0]
|
107 |
+
probs = (probs + bias).clamp_(0, 1)
|
108 |
+
probs = 255 * probs / probs.max()
|
109 |
+
|
110 |
+
# make mask greyscale
|
111 |
+
mask = Image.fromarray(probs.cpu().numpy()).convert("L")
|
112 |
+
|
113 |
+
# resize mask to original size
|
114 |
+
mask = mask.resize(original_size)
|
115 |
+
|
116 |
+
masks.append(mask)
|
117 |
+
|
118 |
+
return masks
|
119 |
+
|
120 |
+
|
121 |
+
@torch.no_grad()
|
122 |
+
def blip_captioning_dataset(
|
123 |
+
images: List[Image.Image],
|
124 |
+
text: Optional[str] = None,
|
125 |
+
model_id: Literal[
|
126 |
+
"Salesforce/blip-image-captioning-large",
|
127 |
+
"Salesforce/blip-image-captioning-base",
|
128 |
+
] = "Salesforce/blip-image-captioning-large",
|
129 |
+
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
130 |
+
**kwargs,
|
131 |
+
) -> List[str]:
|
132 |
+
"""
|
133 |
+
Returns a list of captions for the given images
|
134 |
+
"""
|
135 |
+
|
136 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
137 |
+
|
138 |
+
processor = BlipProcessor.from_pretrained(model_id)
|
139 |
+
model = BlipForConditionalGeneration.from_pretrained(model_id).to(device)
|
140 |
+
captions = []
|
141 |
+
|
142 |
+
for image in tqdm(images):
|
143 |
+
inputs = processor(image, text=text, return_tensors="pt").to("cuda")
|
144 |
+
out = model.generate(
|
145 |
+
**inputs, max_length=150, do_sample=True, top_k=50, temperature=0.7
|
146 |
+
)
|
147 |
+
caption = processor.decode(out[0], skip_special_tokens=True)
|
148 |
+
|
149 |
+
captions.append(caption)
|
150 |
+
|
151 |
+
return captions
|
152 |
+
|
153 |
+
|
154 |
+
def face_mask_google_mediapipe(
|
155 |
+
images: List[Image.Image], blur_amount: float = 80.0, bias: float = 0.05
|
156 |
+
) -> List[Image.Image]:
|
157 |
+
"""
|
158 |
+
Returns a list of images with mask on the face parts.
|
159 |
+
"""
|
160 |
+
import mediapipe as mp
|
161 |
+
|
162 |
+
mp_face_detection = mp.solutions.face_detection
|
163 |
+
|
164 |
+
face_detection = mp_face_detection.FaceDetection(
|
165 |
+
model_selection=1, min_detection_confidence=0.5
|
166 |
+
)
|
167 |
+
|
168 |
+
masks = []
|
169 |
+
for image in tqdm(images):
|
170 |
+
|
171 |
+
image = np.array(image)
|
172 |
+
|
173 |
+
results = face_detection.process(image)
|
174 |
+
black_image = np.ones((image.shape[0], image.shape[1]), dtype=np.uint8)
|
175 |
+
|
176 |
+
if results.detections:
|
177 |
+
|
178 |
+
for detection in results.detections:
|
179 |
+
|
180 |
+
x_min = int(
|
181 |
+
detection.location_data.relative_bounding_box.xmin * image.shape[1]
|
182 |
+
)
|
183 |
+
y_min = int(
|
184 |
+
detection.location_data.relative_bounding_box.ymin * image.shape[0]
|
185 |
+
)
|
186 |
+
width = int(
|
187 |
+
detection.location_data.relative_bounding_box.width * image.shape[1]
|
188 |
+
)
|
189 |
+
height = int(
|
190 |
+
detection.location_data.relative_bounding_box.height
|
191 |
+
* image.shape[0]
|
192 |
+
)
|
193 |
+
|
194 |
+
# draw the colored rectangle
|
195 |
+
black_image[y_min : y_min + height, x_min : x_min + width] = 255
|
196 |
+
|
197 |
+
black_image = Image.fromarray(black_image)
|
198 |
+
masks.append(black_image)
|
199 |
+
|
200 |
+
return masks
|
201 |
+
|
202 |
+
|
203 |
+
def _crop_to_square(
|
204 |
+
image: Image.Image, com: List[Tuple[int, int]], resize_to: Optional[int] = None
|
205 |
+
):
|
206 |
+
cx, cy = com
|
207 |
+
width, height = image.size
|
208 |
+
if width > height:
|
209 |
+
left_possible = max(cx - height / 2, 0)
|
210 |
+
left = min(left_possible, width - height)
|
211 |
+
right = left + height
|
212 |
+
top = 0
|
213 |
+
bottom = height
|
214 |
+
else:
|
215 |
+
left = 0
|
216 |
+
right = width
|
217 |
+
top_possible = max(cy - width / 2, 0)
|
218 |
+
top = min(top_possible, height - width)
|
219 |
+
bottom = top + width
|
220 |
+
|
221 |
+
image = image.crop((left, top, right, bottom))
|
222 |
+
|
223 |
+
if resize_to:
|
224 |
+
image = image.resize((resize_to, resize_to), Image.Resampling.LANCZOS)
|
225 |
+
|
226 |
+
return image
|
227 |
+
|
228 |
+
|
229 |
+
def _center_of_mass(mask: Image.Image):
|
230 |
+
"""
|
231 |
+
Returns the center of mass of the mask
|
232 |
+
"""
|
233 |
+
x, y = np.meshgrid(np.arange(mask.size[0]), np.arange(mask.size[1]))
|
234 |
+
|
235 |
+
x_ = x * np.array(mask)
|
236 |
+
y_ = y * np.array(mask)
|
237 |
+
|
238 |
+
x = np.sum(x_) / np.sum(mask)
|
239 |
+
y = np.sum(y_) / np.sum(mask)
|
240 |
+
|
241 |
+
return x, y
|
242 |
+
|
243 |
+
|
244 |
+
def load_and_save_masks_and_captions(
|
245 |
+
files: Union[str, List[str]],
|
246 |
+
output_dir: str,
|
247 |
+
caption_text: Optional[str] = None,
|
248 |
+
target_prompts: Optional[Union[List[str], str]] = None,
|
249 |
+
target_size: int = 512,
|
250 |
+
crop_based_on_salience: bool = True,
|
251 |
+
use_face_detection_instead: bool = False,
|
252 |
+
temp: float = 1.0,
|
253 |
+
n_length: int = -1,
|
254 |
+
):
|
255 |
+
"""
|
256 |
+
Loads images from the given files, generates masks for them, and saves the masks and captions and upscale images
|
257 |
+
to output dir.
|
258 |
+
"""
|
259 |
+
os.makedirs(output_dir, exist_ok=True)
|
260 |
+
|
261 |
+
# load images
|
262 |
+
if isinstance(files, str):
|
263 |
+
# check if it is a directory
|
264 |
+
if os.path.isdir(files):
|
265 |
+
# get all the .png .jpg in the directory
|
266 |
+
files = glob.glob(os.path.join(files, "*.png")) + glob.glob(
|
267 |
+
os.path.join(files, "*.jpg")
|
268 |
+
)
|
269 |
+
|
270 |
+
if len(files) == 0:
|
271 |
+
raise Exception(
|
272 |
+
f"No files found in {files}. Either {files} is not a directory or it does not contain any .png or .jpg files."
|
273 |
+
)
|
274 |
+
if n_length == -1:
|
275 |
+
n_length = len(files)
|
276 |
+
files = sorted(files)[:n_length]
|
277 |
+
|
278 |
+
images = [Image.open(file) for file in files]
|
279 |
+
|
280 |
+
# captions
|
281 |
+
print(f"Generating {len(images)} captions...")
|
282 |
+
captions = blip_captioning_dataset(images, text=caption_text)
|
283 |
+
|
284 |
+
if target_prompts is None:
|
285 |
+
target_prompts = captions
|
286 |
+
|
287 |
+
print(f"Generating {len(images)} masks...")
|
288 |
+
if not use_face_detection_instead:
|
289 |
+
seg_masks = clipseg_mask_generator(
|
290 |
+
images=images, target_prompts=target_prompts, temp=temp
|
291 |
+
)
|
292 |
+
else:
|
293 |
+
seg_masks = face_mask_google_mediapipe(images=images)
|
294 |
+
|
295 |
+
# find the center of mass of the mask
|
296 |
+
if crop_based_on_salience:
|
297 |
+
coms = [_center_of_mass(mask) for mask in seg_masks]
|
298 |
+
else:
|
299 |
+
coms = [(image.size[0] / 2, image.size[1] / 2) for image in images]
|
300 |
+
# based on the center of mass, crop the image to a square
|
301 |
+
images = [
|
302 |
+
_crop_to_square(image, com, resize_to=None) for image, com in zip(images, coms)
|
303 |
+
]
|
304 |
+
|
305 |
+
print(f"Upscaling {len(images)} images...")
|
306 |
+
# upscale images anyways
|
307 |
+
images = swin_ir_sr(images, target_size=(target_size, target_size))
|
308 |
+
images = [
|
309 |
+
image.resize((target_size, target_size), Image.Resampling.LANCZOS)
|
310 |
+
for image in images
|
311 |
+
]
|
312 |
+
|
313 |
+
seg_masks = [
|
314 |
+
_crop_to_square(mask, com, resize_to=target_size)
|
315 |
+
for mask, com in zip(seg_masks, coms)
|
316 |
+
]
|
317 |
+
with open(os.path.join(output_dir, "caption.txt"), "w") as f:
|
318 |
+
# save images and masks
|
319 |
+
for idx, (image, mask, caption) in enumerate(zip(images, seg_masks, captions)):
|
320 |
+
image.save(os.path.join(output_dir, f"{idx}.src.jpg"), quality=99)
|
321 |
+
mask.save(os.path.join(output_dir, f"{idx}.mask.png"))
|
322 |
+
|
323 |
+
f.write(caption + "\n")
|
324 |
+
|
325 |
+
|
326 |
+
def main():
|
327 |
+
fire.Fire(load_and_save_masks_and_captions)
|
lora_diffusion/safe_open.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Pure python version of Safetensors safe_open
|
3 |
+
From https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282
|
4 |
+
"""
|
5 |
+
|
6 |
+
import json
|
7 |
+
import mmap
|
8 |
+
import os
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
|
13 |
+
class SafetensorsWrapper:
|
14 |
+
def __init__(self, metadata, tensors):
|
15 |
+
self._metadata = metadata
|
16 |
+
self._tensors = tensors
|
17 |
+
|
18 |
+
def metadata(self):
|
19 |
+
return self._metadata
|
20 |
+
|
21 |
+
def keys(self):
|
22 |
+
return self._tensors.keys()
|
23 |
+
|
24 |
+
def get_tensor(self, k):
|
25 |
+
return self._tensors[k]
|
26 |
+
|
27 |
+
|
28 |
+
DTYPES = {
|
29 |
+
"F32": torch.float32,
|
30 |
+
"F16": torch.float16,
|
31 |
+
"BF16": torch.bfloat16,
|
32 |
+
}
|
33 |
+
|
34 |
+
|
35 |
+
def create_tensor(storage, info, offset):
|
36 |
+
dtype = DTYPES[info["dtype"]]
|
37 |
+
shape = info["shape"]
|
38 |
+
start, stop = info["data_offsets"]
|
39 |
+
return (
|
40 |
+
torch.asarray(storage[start + offset : stop + offset], dtype=torch.uint8)
|
41 |
+
.view(dtype=dtype)
|
42 |
+
.reshape(shape)
|
43 |
+
)
|
44 |
+
|
45 |
+
|
46 |
+
def safe_open(filename, framework="pt", device="cpu"):
|
47 |
+
if framework != "pt":
|
48 |
+
raise ValueError("`framework` must be 'pt'")
|
49 |
+
|
50 |
+
with open(filename, mode="r", encoding="utf8") as file_obj:
|
51 |
+
with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
|
52 |
+
header = m.read(8)
|
53 |
+
n = int.from_bytes(header, "little")
|
54 |
+
metadata_bytes = m.read(n)
|
55 |
+
metadata = json.loads(metadata_bytes)
|
56 |
+
|
57 |
+
size = os.stat(filename).st_size
|
58 |
+
storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped()
|
59 |
+
offset = n + 8
|
60 |
+
|
61 |
+
return SafetensorsWrapper(
|
62 |
+
metadata=metadata.get("__metadata__", {}),
|
63 |
+
tensors={
|
64 |
+
name: create_tensor(storage, info, offset).to(device)
|
65 |
+
for name, info in metadata.items()
|
66 |
+
if name != "__metadata__"
|
67 |
+
},
|
68 |
+
)
|
lora_diffusion/to_ckpt_v2.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05
|
2 |
+
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
|
3 |
+
# *Only* converts the UNet, VAE, and Text Encoder.
|
4 |
+
# Does not convert optimizer state or any other thing.
|
5 |
+
# Written by jachiam
|
6 |
+
import argparse
|
7 |
+
import os.path as osp
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
# =================#
|
13 |
+
# UNet Conversion #
|
14 |
+
# =================#
|
15 |
+
|
16 |
+
unet_conversion_map = [
|
17 |
+
# (stable-diffusion, HF Diffusers)
|
18 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
19 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
20 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
21 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
22 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
23 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
24 |
+
("out.0.weight", "conv_norm_out.weight"),
|
25 |
+
("out.0.bias", "conv_norm_out.bias"),
|
26 |
+
("out.2.weight", "conv_out.weight"),
|
27 |
+
("out.2.bias", "conv_out.bias"),
|
28 |
+
]
|
29 |
+
|
30 |
+
unet_conversion_map_resnet = [
|
31 |
+
# (stable-diffusion, HF Diffusers)
|
32 |
+
("in_layers.0", "norm1"),
|
33 |
+
("in_layers.2", "conv1"),
|
34 |
+
("out_layers.0", "norm2"),
|
35 |
+
("out_layers.3", "conv2"),
|
36 |
+
("emb_layers.1", "time_emb_proj"),
|
37 |
+
("skip_connection", "conv_shortcut"),
|
38 |
+
]
|
39 |
+
|
40 |
+
unet_conversion_map_layer = []
|
41 |
+
# hardcoded number of downblocks and resnets/attentions...
|
42 |
+
# would need smarter logic for other networks.
|
43 |
+
for i in range(4):
|
44 |
+
# loop over downblocks/upblocks
|
45 |
+
|
46 |
+
for j in range(2):
|
47 |
+
# loop over resnets/attentions for downblocks
|
48 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
49 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
50 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
51 |
+
|
52 |
+
if i < 3:
|
53 |
+
# no attention layers in down_blocks.3
|
54 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
55 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
56 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
57 |
+
|
58 |
+
for j in range(3):
|
59 |
+
# loop over resnets/attentions for upblocks
|
60 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
61 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
62 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
63 |
+
|
64 |
+
if i > 0:
|
65 |
+
# no attention layers in up_blocks.0
|
66 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
67 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
68 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
69 |
+
|
70 |
+
if i < 3:
|
71 |
+
# no downsample in down_blocks.3
|
72 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
73 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
74 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
75 |
+
|
76 |
+
# no upsample in up_blocks.3
|
77 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
78 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
79 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
80 |
+
|
81 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
82 |
+
sd_mid_atn_prefix = "middle_block.1."
|
83 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
84 |
+
|
85 |
+
for j in range(2):
|
86 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
87 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
88 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
89 |
+
|
90 |
+
|
91 |
+
def convert_unet_state_dict(unet_state_dict):
|
92 |
+
# buyer beware: this is a *brittle* function,
|
93 |
+
# and correct output requires that all of these pieces interact in
|
94 |
+
# the exact order in which I have arranged them.
|
95 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
96 |
+
for sd_name, hf_name in unet_conversion_map:
|
97 |
+
mapping[hf_name] = sd_name
|
98 |
+
for k, v in mapping.items():
|
99 |
+
if "resnets" in k:
|
100 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
101 |
+
v = v.replace(hf_part, sd_part)
|
102 |
+
mapping[k] = v
|
103 |
+
for k, v in mapping.items():
|
104 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
105 |
+
v = v.replace(hf_part, sd_part)
|
106 |
+
mapping[k] = v
|
107 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
108 |
+
return new_state_dict
|
109 |
+
|
110 |
+
|
111 |
+
# ================#
|
112 |
+
# VAE Conversion #
|
113 |
+
# ================#
|
114 |
+
|
115 |
+
vae_conversion_map = [
|
116 |
+
# (stable-diffusion, HF Diffusers)
|
117 |
+
("nin_shortcut", "conv_shortcut"),
|
118 |
+
("norm_out", "conv_norm_out"),
|
119 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
120 |
+
]
|
121 |
+
|
122 |
+
for i in range(4):
|
123 |
+
# down_blocks have two resnets
|
124 |
+
for j in range(2):
|
125 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
126 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
127 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
128 |
+
|
129 |
+
if i < 3:
|
130 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
131 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
132 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
133 |
+
|
134 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
135 |
+
sd_upsample_prefix = f"up.{3-i}.upsample."
|
136 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
137 |
+
|
138 |
+
# up_blocks have three resnets
|
139 |
+
# also, up blocks in hf are numbered in reverse from sd
|
140 |
+
for j in range(3):
|
141 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
142 |
+
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
143 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
144 |
+
|
145 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
146 |
+
for i in range(2):
|
147 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
148 |
+
sd_mid_res_prefix = f"mid.block_{i+1}."
|
149 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
150 |
+
|
151 |
+
|
152 |
+
vae_conversion_map_attn = [
|
153 |
+
# (stable-diffusion, HF Diffusers)
|
154 |
+
("norm.", "group_norm."),
|
155 |
+
("q.", "query."),
|
156 |
+
("k.", "key."),
|
157 |
+
("v.", "value."),
|
158 |
+
("proj_out.", "proj_attn."),
|
159 |
+
]
|
160 |
+
|
161 |
+
|
162 |
+
def reshape_weight_for_sd(w):
|
163 |
+
# convert HF linear weights to SD conv2d weights
|
164 |
+
return w.reshape(*w.shape, 1, 1)
|
165 |
+
|
166 |
+
|
167 |
+
def convert_vae_state_dict(vae_state_dict):
|
168 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
169 |
+
for k, v in mapping.items():
|
170 |
+
for sd_part, hf_part in vae_conversion_map:
|
171 |
+
v = v.replace(hf_part, sd_part)
|
172 |
+
mapping[k] = v
|
173 |
+
for k, v in mapping.items():
|
174 |
+
if "attentions" in k:
|
175 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
176 |
+
v = v.replace(hf_part, sd_part)
|
177 |
+
mapping[k] = v
|
178 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
179 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
180 |
+
for k, v in new_state_dict.items():
|
181 |
+
for weight_name in weights_to_convert:
|
182 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
183 |
+
print(f"Reshaping {k} for SD format")
|
184 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
185 |
+
return new_state_dict
|
186 |
+
|
187 |
+
|
188 |
+
# =========================#
|
189 |
+
# Text Encoder Conversion #
|
190 |
+
# =========================#
|
191 |
+
# pretty much a no-op
|
192 |
+
|
193 |
+
|
194 |
+
def convert_text_enc_state_dict(text_enc_dict):
|
195 |
+
return text_enc_dict
|
196 |
+
|
197 |
+
|
198 |
+
def convert_to_ckpt(model_path, checkpoint_path, as_half):
|
199 |
+
|
200 |
+
assert model_path is not None, "Must provide a model path!"
|
201 |
+
|
202 |
+
assert checkpoint_path is not None, "Must provide a checkpoint path!"
|
203 |
+
|
204 |
+
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
|
205 |
+
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
|
206 |
+
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
|
207 |
+
|
208 |
+
# Convert the UNet model
|
209 |
+
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
210 |
+
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
211 |
+
unet_state_dict = {
|
212 |
+
"model.diffusion_model." + k: v for k, v in unet_state_dict.items()
|
213 |
+
}
|
214 |
+
|
215 |
+
# Convert the VAE model
|
216 |
+
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
217 |
+
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
218 |
+
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
219 |
+
|
220 |
+
# Convert the text encoder model
|
221 |
+
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
222 |
+
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
|
223 |
+
text_enc_dict = {
|
224 |
+
"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()
|
225 |
+
}
|
226 |
+
|
227 |
+
# Put together new checkpoint
|
228 |
+
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
229 |
+
if as_half:
|
230 |
+
state_dict = {k: v.half() for k, v in state_dict.items()}
|
231 |
+
state_dict = {"state_dict": state_dict}
|
232 |
+
torch.save(state_dict, checkpoint_path)
|
lora_diffusion/utils.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
from transformers import (
|
6 |
+
CLIPProcessor,
|
7 |
+
CLIPTextModelWithProjection,
|
8 |
+
CLIPTokenizer,
|
9 |
+
CLIPVisionModelWithProjection,
|
10 |
+
)
|
11 |
+
|
12 |
+
from diffusers import StableDiffusionPipeline
|
13 |
+
from .lora import patch_pipe, tune_lora_scale, _text_lora_path, _ti_lora_path
|
14 |
+
import os
|
15 |
+
import glob
|
16 |
+
import math
|
17 |
+
|
18 |
+
EXAMPLE_PROMPTS = [
|
19 |
+
"<obj> swimming in a pool",
|
20 |
+
"<obj> at a beach with a view of seashore",
|
21 |
+
"<obj> in times square",
|
22 |
+
"<obj> wearing sunglasses",
|
23 |
+
"<obj> in a construction outfit",
|
24 |
+
"<obj> playing with a ball",
|
25 |
+
"<obj> wearing headphones",
|
26 |
+
"<obj> oil painting ghibli inspired",
|
27 |
+
"<obj> working on the laptop",
|
28 |
+
"<obj> with mountains and sunset in background",
|
29 |
+
"Painting of <obj> at a beach by artist claude monet",
|
30 |
+
"<obj> digital painting 3d render geometric style",
|
31 |
+
"A screaming <obj>",
|
32 |
+
"A depressed <obj>",
|
33 |
+
"A sleeping <obj>",
|
34 |
+
"A sad <obj>",
|
35 |
+
"A joyous <obj>",
|
36 |
+
"A frowning <obj>",
|
37 |
+
"A sculpture of <obj>",
|
38 |
+
"<obj> near a pool",
|
39 |
+
"<obj> at a beach with a view of seashore",
|
40 |
+
"<obj> in a garden",
|
41 |
+
"<obj> in grand canyon",
|
42 |
+
"<obj> floating in ocean",
|
43 |
+
"<obj> and an armchair",
|
44 |
+
"A maple tree on the side of <obj>",
|
45 |
+
"<obj> and an orange sofa",
|
46 |
+
"<obj> with chocolate cake on it",
|
47 |
+
"<obj> with a vase of rose flowers on it",
|
48 |
+
"A digital illustration of <obj>",
|
49 |
+
"Georgia O'Keeffe style <obj> painting",
|
50 |
+
"A watercolor painting of <obj> on a beach",
|
51 |
+
]
|
52 |
+
|
53 |
+
|
54 |
+
def image_grid(_imgs, rows=None, cols=None):
|
55 |
+
|
56 |
+
if rows is None and cols is None:
|
57 |
+
rows = cols = math.ceil(len(_imgs) ** 0.5)
|
58 |
+
|
59 |
+
if rows is None:
|
60 |
+
rows = math.ceil(len(_imgs) / cols)
|
61 |
+
if cols is None:
|
62 |
+
cols = math.ceil(len(_imgs) / rows)
|
63 |
+
|
64 |
+
w, h = _imgs[0].size
|
65 |
+
grid = Image.new("RGB", size=(cols * w, rows * h))
|
66 |
+
grid_w, grid_h = grid.size
|
67 |
+
|
68 |
+
for i, img in enumerate(_imgs):
|
69 |
+
grid.paste(img, box=(i % cols * w, i // cols * h))
|
70 |
+
return grid
|
71 |
+
|
72 |
+
|
73 |
+
def text_img_alignment(img_embeds, text_embeds, target_img_embeds):
|
74 |
+
# evaluation inspired from textual inversion paper
|
75 |
+
# https://arxiv.org/abs/2208.01618
|
76 |
+
|
77 |
+
# text alignment
|
78 |
+
assert img_embeds.shape[0] == text_embeds.shape[0]
|
79 |
+
text_img_sim = (img_embeds * text_embeds).sum(dim=-1) / (
|
80 |
+
img_embeds.norm(dim=-1) * text_embeds.norm(dim=-1)
|
81 |
+
)
|
82 |
+
|
83 |
+
# image alignment
|
84 |
+
img_embed_normalized = img_embeds / img_embeds.norm(dim=-1, keepdim=True)
|
85 |
+
|
86 |
+
avg_target_img_embed = (
|
87 |
+
(target_img_embeds / target_img_embeds.norm(dim=-1, keepdim=True))
|
88 |
+
.mean(dim=0)
|
89 |
+
.unsqueeze(0)
|
90 |
+
.repeat(img_embeds.shape[0], 1)
|
91 |
+
)
|
92 |
+
|
93 |
+
img_img_sim = (img_embed_normalized * avg_target_img_embed).sum(dim=-1)
|
94 |
+
|
95 |
+
return {
|
96 |
+
"text_alignment_avg": text_img_sim.mean().item(),
|
97 |
+
"image_alignment_avg": img_img_sim.mean().item(),
|
98 |
+
"text_alignment_all": text_img_sim.tolist(),
|
99 |
+
"image_alignment_all": img_img_sim.tolist(),
|
100 |
+
}
|
101 |
+
|
102 |
+
|
103 |
+
def prepare_clip_model_sets(eval_clip_id: str = "openai/clip-vit-large-patch14"):
|
104 |
+
text_model = CLIPTextModelWithProjection.from_pretrained(eval_clip_id)
|
105 |
+
tokenizer = CLIPTokenizer.from_pretrained(eval_clip_id)
|
106 |
+
vis_model = CLIPVisionModelWithProjection.from_pretrained(eval_clip_id)
|
107 |
+
processor = CLIPProcessor.from_pretrained(eval_clip_id)
|
108 |
+
|
109 |
+
return text_model, tokenizer, vis_model, processor
|
110 |
+
|
111 |
+
|
112 |
+
def evaluate_pipe(
|
113 |
+
pipe,
|
114 |
+
target_images: List[Image.Image],
|
115 |
+
class_token: str = "",
|
116 |
+
learnt_token: str = "",
|
117 |
+
guidance_scale: float = 5.0,
|
118 |
+
seed=0,
|
119 |
+
clip_model_sets=None,
|
120 |
+
eval_clip_id: str = "openai/clip-vit-large-patch14",
|
121 |
+
n_test: int = 10,
|
122 |
+
n_step: int = 50,
|
123 |
+
):
|
124 |
+
|
125 |
+
if clip_model_sets is not None:
|
126 |
+
text_model, tokenizer, vis_model, processor = clip_model_sets
|
127 |
+
else:
|
128 |
+
text_model, tokenizer, vis_model, processor = prepare_clip_model_sets(
|
129 |
+
eval_clip_id
|
130 |
+
)
|
131 |
+
|
132 |
+
images = []
|
133 |
+
img_embeds = []
|
134 |
+
text_embeds = []
|
135 |
+
for prompt in EXAMPLE_PROMPTS[:n_test]:
|
136 |
+
prompt = prompt.replace("<obj>", learnt_token)
|
137 |
+
torch.manual_seed(seed)
|
138 |
+
with torch.autocast("cuda"):
|
139 |
+
img = pipe(
|
140 |
+
prompt, num_inference_steps=n_step, guidance_scale=guidance_scale
|
141 |
+
).images[0]
|
142 |
+
images.append(img)
|
143 |
+
|
144 |
+
# image
|
145 |
+
inputs = processor(images=img, return_tensors="pt")
|
146 |
+
img_embed = vis_model(**inputs).image_embeds
|
147 |
+
img_embeds.append(img_embed)
|
148 |
+
|
149 |
+
prompt = prompt.replace(learnt_token, class_token)
|
150 |
+
# prompts
|
151 |
+
inputs = tokenizer([prompt], padding=True, return_tensors="pt")
|
152 |
+
outputs = text_model(**inputs)
|
153 |
+
text_embed = outputs.text_embeds
|
154 |
+
text_embeds.append(text_embed)
|
155 |
+
|
156 |
+
# target images
|
157 |
+
inputs = processor(images=target_images, return_tensors="pt")
|
158 |
+
target_img_embeds = vis_model(**inputs).image_embeds
|
159 |
+
|
160 |
+
img_embeds = torch.cat(img_embeds, dim=0)
|
161 |
+
text_embeds = torch.cat(text_embeds, dim=0)
|
162 |
+
|
163 |
+
return text_img_alignment(img_embeds, text_embeds, target_img_embeds)
|
164 |
+
|
165 |
+
|
166 |
+
def visualize_progress(
|
167 |
+
path_alls: Union[str, List[str]],
|
168 |
+
prompt: str,
|
169 |
+
model_id: str = "runwayml/stable-diffusion-v1-5",
|
170 |
+
device="cuda:0",
|
171 |
+
patch_unet=True,
|
172 |
+
patch_text=True,
|
173 |
+
patch_ti=True,
|
174 |
+
unet_scale=1.0,
|
175 |
+
text_sclae=1.0,
|
176 |
+
num_inference_steps=50,
|
177 |
+
guidance_scale=5.0,
|
178 |
+
offset: int = 0,
|
179 |
+
limit: int = 10,
|
180 |
+
seed: int = 0,
|
181 |
+
):
|
182 |
+
|
183 |
+
imgs = []
|
184 |
+
if isinstance(path_alls, str):
|
185 |
+
alls = list(set(glob.glob(path_alls)))
|
186 |
+
|
187 |
+
alls.sort(key=os.path.getmtime)
|
188 |
+
else:
|
189 |
+
alls = path_alls
|
190 |
+
|
191 |
+
pipe = StableDiffusionPipeline.from_pretrained(
|
192 |
+
model_id, torch_dtype=torch.float16
|
193 |
+
).to(device)
|
194 |
+
|
195 |
+
print(f"Found {len(alls)} checkpoints")
|
196 |
+
for path in alls[offset:limit]:
|
197 |
+
print(path)
|
198 |
+
|
199 |
+
patch_pipe(
|
200 |
+
pipe, path, patch_unet=patch_unet, patch_text=patch_text, patch_ti=patch_ti
|
201 |
+
)
|
202 |
+
|
203 |
+
tune_lora_scale(pipe.unet, unet_scale)
|
204 |
+
tune_lora_scale(pipe.text_encoder, text_sclae)
|
205 |
+
|
206 |
+
torch.manual_seed(seed)
|
207 |
+
image = pipe(
|
208 |
+
prompt,
|
209 |
+
num_inference_steps=num_inference_steps,
|
210 |
+
guidance_scale=guidance_scale,
|
211 |
+
).images[0]
|
212 |
+
imgs.append(image)
|
213 |
+
|
214 |
+
return imgs
|
lora_diffusion/xformers_utils.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from diffusers.models.attention import BasicTransformerBlock
|
5 |
+
from diffusers.utils.import_utils import is_xformers_available
|
6 |
+
|
7 |
+
from .lora import LoraInjectedLinear
|
8 |
+
|
9 |
+
if is_xformers_available():
|
10 |
+
import xformers
|
11 |
+
import xformers.ops
|
12 |
+
else:
|
13 |
+
xformers = None
|
14 |
+
|
15 |
+
|
16 |
+
@functools.cache
|
17 |
+
def test_xformers_backwards(size):
|
18 |
+
@torch.enable_grad()
|
19 |
+
def _grad(size):
|
20 |
+
q = torch.randn((1, 4, size), device="cuda")
|
21 |
+
k = torch.randn((1, 4, size), device="cuda")
|
22 |
+
v = torch.randn((1, 4, size), device="cuda")
|
23 |
+
|
24 |
+
q = q.detach().requires_grad_()
|
25 |
+
k = k.detach().requires_grad_()
|
26 |
+
v = v.detach().requires_grad_()
|
27 |
+
|
28 |
+
out = xformers.ops.memory_efficient_attention(q, k, v)
|
29 |
+
loss = out.sum(2).mean(0).sum()
|
30 |
+
|
31 |
+
return torch.autograd.grad(loss, v)
|
32 |
+
|
33 |
+
try:
|
34 |
+
_grad(size)
|
35 |
+
print(size, "pass")
|
36 |
+
return True
|
37 |
+
except Exception as e:
|
38 |
+
print(size, "fail")
|
39 |
+
return False
|
40 |
+
|
41 |
+
|
42 |
+
def set_use_memory_efficient_attention_xformers(
|
43 |
+
module: torch.nn.Module, valid: bool
|
44 |
+
) -> None:
|
45 |
+
def fn_test_dim_head(module: torch.nn.Module):
|
46 |
+
if isinstance(module, BasicTransformerBlock):
|
47 |
+
# dim_head isn't stored anywhere, so back-calculate
|
48 |
+
source = module.attn1.to_v
|
49 |
+
if isinstance(source, LoraInjectedLinear):
|
50 |
+
source = source.linear
|
51 |
+
|
52 |
+
dim_head = source.out_features // module.attn1.heads
|
53 |
+
|
54 |
+
result = test_xformers_backwards(dim_head)
|
55 |
+
|
56 |
+
# If dim_head > dim_head_max, turn xformers off
|
57 |
+
if not result:
|
58 |
+
module.set_use_memory_efficient_attention_xformers(False)
|
59 |
+
|
60 |
+
for child in module.children():
|
61 |
+
fn_test_dim_head(child)
|
62 |
+
|
63 |
+
if not is_xformers_available() and valid:
|
64 |
+
print("XFormers is not available. Skipping.")
|
65 |
+
return
|
66 |
+
|
67 |
+
module.set_use_memory_efficient_attention_xformers(valid)
|
68 |
+
|
69 |
+
if valid:
|
70 |
+
fn_test_dim_head(module)
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
diffusers
|
2 |
+
accelerate
|
3 |
+
transformers>=4.25.1
|
train_dreambooth_cloneofsimo_lora.py
ADDED
@@ -0,0 +1,1008 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Bootstrapped from:
|
2 |
+
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import hashlib
|
6 |
+
import itertools
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import inspect
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import Optional
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import torch.utils.checkpoint
|
16 |
+
|
17 |
+
|
18 |
+
from accelerate import Accelerator
|
19 |
+
from accelerate.logging import get_logger
|
20 |
+
from accelerate.utils import set_seed
|
21 |
+
from diffusers import (
|
22 |
+
AutoencoderKL,
|
23 |
+
DDPMScheduler,
|
24 |
+
StableDiffusionPipeline,
|
25 |
+
UNet2DConditionModel,
|
26 |
+
)
|
27 |
+
from diffusers.optimization import get_scheduler
|
28 |
+
from huggingface_hub import HfFolder, Repository, whoami
|
29 |
+
|
30 |
+
from tqdm.auto import tqdm
|
31 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
32 |
+
|
33 |
+
from lora_diffusion import (
|
34 |
+
extract_lora_ups_down,
|
35 |
+
inject_trainable_lora,
|
36 |
+
safetensors_available,
|
37 |
+
save_lora_weight,
|
38 |
+
save_safeloras,
|
39 |
+
)
|
40 |
+
from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers
|
41 |
+
from PIL import Image
|
42 |
+
from torch.utils.data import Dataset
|
43 |
+
from torchvision import transforms
|
44 |
+
|
45 |
+
from pathlib import Path
|
46 |
+
|
47 |
+
import random
|
48 |
+
import re
|
49 |
+
|
50 |
+
|
51 |
+
class DreamBoothDataset(Dataset):
|
52 |
+
"""
|
53 |
+
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
54 |
+
It pre-processes the images and the tokenizes prompts.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(
|
58 |
+
self,
|
59 |
+
instance_data_root,
|
60 |
+
instance_prompt,
|
61 |
+
tokenizer,
|
62 |
+
class_data_root=None,
|
63 |
+
class_prompt=None,
|
64 |
+
size=512,
|
65 |
+
center_crop=False,
|
66 |
+
color_jitter=False,
|
67 |
+
h_flip=False,
|
68 |
+
resize=False,
|
69 |
+
):
|
70 |
+
self.size = size
|
71 |
+
self.center_crop = center_crop
|
72 |
+
self.tokenizer = tokenizer
|
73 |
+
self.resize = resize
|
74 |
+
|
75 |
+
self.instance_data_root = Path(instance_data_root)
|
76 |
+
if not self.instance_data_root.exists():
|
77 |
+
raise ValueError("Instance images root doesn't exists.")
|
78 |
+
|
79 |
+
self.instance_images_path = list(Path(instance_data_root).iterdir())
|
80 |
+
self.num_instance_images = len(self.instance_images_path)
|
81 |
+
self.instance_prompt = instance_prompt
|
82 |
+
self._length = self.num_instance_images
|
83 |
+
|
84 |
+
if class_data_root is not None:
|
85 |
+
self.class_data_root = Path(class_data_root)
|
86 |
+
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
87 |
+
self.class_images_path = list(self.class_data_root.iterdir())
|
88 |
+
self.num_class_images = len(self.class_images_path)
|
89 |
+
self._length = max(self.num_class_images, self.num_instance_images)
|
90 |
+
self.class_prompt = class_prompt
|
91 |
+
else:
|
92 |
+
self.class_data_root = None
|
93 |
+
|
94 |
+
img_transforms = []
|
95 |
+
|
96 |
+
if resize:
|
97 |
+
img_transforms.append(
|
98 |
+
transforms.Resize(
|
99 |
+
size, interpolation=transforms.InterpolationMode.BILINEAR
|
100 |
+
)
|
101 |
+
)
|
102 |
+
if center_crop:
|
103 |
+
img_transforms.append(transforms.CenterCrop(size))
|
104 |
+
if color_jitter:
|
105 |
+
img_transforms.append(transforms.ColorJitter(0.2, 0.1))
|
106 |
+
if h_flip:
|
107 |
+
img_transforms.append(transforms.RandomHorizontalFlip())
|
108 |
+
|
109 |
+
self.image_transforms = transforms.Compose(
|
110 |
+
[*img_transforms, transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
|
111 |
+
)
|
112 |
+
|
113 |
+
def __len__(self):
|
114 |
+
return self._length
|
115 |
+
|
116 |
+
def __getitem__(self, index):
|
117 |
+
example = {}
|
118 |
+
instance_image = Image.open(
|
119 |
+
self.instance_images_path[index % self.num_instance_images]
|
120 |
+
)
|
121 |
+
if not instance_image.mode == "RGB":
|
122 |
+
instance_image = instance_image.convert("RGB")
|
123 |
+
example["instance_images"] = self.image_transforms(instance_image)
|
124 |
+
example["instance_prompt_ids"] = self.tokenizer(
|
125 |
+
self.instance_prompt,
|
126 |
+
padding="do_not_pad",
|
127 |
+
truncation=True,
|
128 |
+
max_length=self.tokenizer.model_max_length,
|
129 |
+
).input_ids
|
130 |
+
|
131 |
+
if self.class_data_root:
|
132 |
+
class_image = Image.open(
|
133 |
+
self.class_images_path[index % self.num_class_images]
|
134 |
+
)
|
135 |
+
if not class_image.mode == "RGB":
|
136 |
+
class_image = class_image.convert("RGB")
|
137 |
+
example["class_images"] = self.image_transforms(class_image)
|
138 |
+
example["class_prompt_ids"] = self.tokenizer(
|
139 |
+
self.class_prompt,
|
140 |
+
padding="do_not_pad",
|
141 |
+
truncation=True,
|
142 |
+
max_length=self.tokenizer.model_max_length,
|
143 |
+
).input_ids
|
144 |
+
|
145 |
+
return example
|
146 |
+
|
147 |
+
|
148 |
+
class PromptDataset(Dataset):
|
149 |
+
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
|
150 |
+
|
151 |
+
def __init__(self, prompt, num_samples):
|
152 |
+
self.prompt = prompt
|
153 |
+
self.num_samples = num_samples
|
154 |
+
|
155 |
+
def __len__(self):
|
156 |
+
return self.num_samples
|
157 |
+
|
158 |
+
def __getitem__(self, index):
|
159 |
+
example = {}
|
160 |
+
example["prompt"] = self.prompt
|
161 |
+
example["index"] = index
|
162 |
+
return example
|
163 |
+
|
164 |
+
|
165 |
+
logger = get_logger(__name__)
|
166 |
+
|
167 |
+
|
168 |
+
def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained vae or vae identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        required=True,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss.",
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images, additional images"
            " will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--output_format",
        type=str,
        choices=["pt", "safe", "both"],
        default="both",
        help="The output format of the model predictions and checkpoints.",
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        action="store_true",
        help="Whether to center crop images before resizing to resolution",
    )
    parser.add_argument(
        "--color_jitter",
        action="store_true",
        help="Whether to apply color jitter to images",
    )
    parser.add_argument(
        "--train_text_encoder",
        action="store_true",
        help="Whether to train the text encoder",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for sampling images.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save checkpoint every X update steps.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--lora_rank",
        type=int,
        default=4,
        help="Rank of LoRA approximation.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=None,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--learning_rate_text",
        type=float,
        default=5e-6,
        help="Initial learning rate for text encoder (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=500,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="The token to use to push to the Model Hub.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or"
            " the flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )
    parser.add_argument(
        "--resume_unet",
        type=str,
        default=None,
        help=("File path for unet lora to resume training."),
    )
    parser.add_argument(
        "--resume_text_encoder",
        type=str,
        default=None,
        help=("File path for text encoder lora to resume training."),
    )
    parser.add_argument(
        "--resize",
        type=bool,
        default=True,
        required=False,
        help="Should images be resized to --resolution before training?",
    )
    parser.add_argument(
        "--use_xformers", action="store_true", help="Whether or not to use xformers"
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify prompt for class images.")
    else:
        if args.class_data_dir is not None:
            logger.warning(
                "You need not use --class_data_dir without --with_prior_preservation."
            )
        if args.class_prompt is not None:
            logger.warning(
                "You need not use --class_prompt without --with_prior_preservation."
            )

    if not safetensors_available:
        if args.output_format == "both":
            print(
                "Safetensors is not available - changing output format to just output PyTorch files"
            )
            args.output_format = "pt"
        elif args.output_format == "safe":
            raise ValueError(
                "Safetensors is not available - either install it, or change output_format."
            )

    return args

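# Example invocation (illustrative only; the model id, directories and prompts below are
# placeholders, not values used by this Space):
#
#   accelerate launch <path to this script> \
#     --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
#     --instance_data_dir="./instance-images" \
#     --instance_prompt="a photo of sks person" \
#     --output_dir="./lora-output" \
#     --resolution=512 \
#     --train_batch_size=1 \
#     --learning_rate=1e-4 \
#     --lora_rank=4 \
#     --max_train_steps=1000
#
# Every flag maps to an argument defined in parse_args() above.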
def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
    )

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    if (
        args.train_text_encoder
        and args.gradient_accumulation_steps > 1
        and accelerator.num_processes > 1
    ):
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

    if args.seed is not None:
        set_seed(args.seed)

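    # Prior preservation (DreamBooth): before fine-tuning, generate a pool of generic images
    # for the class prompt with the unmodified base model. These are mixed into training below
    # so the model keeps its prior for the class while learning the new instance.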
    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
            class_images_dir.mkdir(parents=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            torch_dtype = (
                torch.float16 if accelerator.device.type == "cuda" else torch.float32
            )
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                safety_checker=None,
                revision=args.revision,
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(
                sample_dataset, batch_size=args.sample_batch_size
            )

            sample_dataloader = accelerator.prepare(sample_dataloader)
            pipeline.to(accelerator.device)

            for example in tqdm(
                sample_dataloader,
                desc="Generating class images",
                disable=not accelerator.is_local_main_process,
            ):
                images = pipeline(example["prompt"]).images

                for i, image in enumerate(images):
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = (
                        class_images_dir
                        / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    )
                    image.save(image_filename)

            del pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Handle the repository creation
    if accelerator.is_main_process:

        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.tokenizer_name,
            revision=args.revision,
        )
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
        )

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
        subfolder=None if args.pretrained_vae_name_or_path else "vae",
        revision=None if args.pretrained_vae_name_or_path else args.revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
    )
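    # LoRA: the base UNet weights are frozen and inject_trainable_lora() (imported from the
    # bundled lora_diffusion package) adds a low-rank adapter pair to each targeted module, so
    # an adapted layer computes roughly W(x) + up(down(x)) with down: d -> r and up: r -> d,
    # where r = --lora_rank. Only these small adapter matrices receive gradients and are saved.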
    unet.requires_grad_(False)
    unet_lora_params, _ = inject_trainable_lora(
        unet, r=args.lora_rank, loras=args.resume_unet
    )

    for _up, _down in extract_lora_ups_down(unet):
        print("Before training: Unet First Layer lora up", _up.weight.data)
        print("Before training: Unet First Layer lora down", _down.weight.data)
        break

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    if args.train_text_encoder:
        text_encoder_lora_params, _ = inject_trainable_lora(
            text_encoder,
            target_replace_module=["CLIPAttention"],
            r=args.lora_rank,
        )
        for _up, _down in extract_lora_ups_down(
            text_encoder, target_replace_module=["CLIPAttention"]
        ):
            print("Before training: text encoder First Layer lora up", _up.weight.data)
            print(
                "Before training: text encoder First Layer lora down", _down.weight.data
            )
            break

    if args.use_xformers:
        set_use_memory_efficient_attention_xformers(unet, True)
        set_use_memory_efficient_attention_xformers(vae, True)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate
            * args.gradient_accumulation_steps
            * args.train_batch_size
            * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    text_lr = (
        args.learning_rate
        if args.learning_rate_text is None
        else args.learning_rate_text
    )

    params_to_optimize = (
        [
            {"params": itertools.chain(*unet_lora_params), "lr": args.learning_rate},
            {
                "params": itertools.chain(*text_encoder_lora_params),
                "lr": text_lr,
            },
        ]
        if args.train_text_encoder
        else itertools.chain(*unet_lora_params)
    )
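    # Note: only the injected LoRA parameters collected above are handed to the optimizer;
    # the frozen UNet / text encoder base weights are never updated.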
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_config(
        args.pretrained_model_name_or_path, subfolder="scheduler"
    )

    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
        color_jitter=args.color_jitter,
        resize=args.resize,
    )

    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }
        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=1,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    if args.train_text_encoder:
        (
            unet,
            text_encoder,
            optimizer,
            train_dataloader,
            lr_scheduler,
        ) = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encoder and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth", config=vars(args))

    # Train!
    total_batch_size = (
        args.train_batch_size
        * accelerator.num_processes
        * args.gradient_accumulation_steps
    )

    print("***** Running training *****")
    print(f" Num examples = {len(train_dataset)}")
    print(f" Num batches each epoch = {len(train_dataloader)}")
    print(f" Num Epochs = {args.num_train_epochs}")
    print(f" Instantaneous batch size per device = {args.train_batch_size}")
    print(
        f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    print(f" Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(args.max_train_steps), disable=not accelerator.is_local_main_process
    )
    progress_bar.set_description("Steps")
    global_step = 0
    last_save = 0

    for epoch in range(args.num_train_epochs):
        unet.train()
        if args.train_text_encoder:
            text_encoder.train()

        for step, batch in enumerate(train_dataloader):
            # Convert images to latent space
            latents = vae.encode(
                batch["pixel_values"].to(dtype=weight_dtype)
            ).latent_dist.sample()
            latents = latents * 0.18215
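            # 0.18215 is Stable Diffusion's VAE latent scaling factor; it brings the latents
            # to roughly unit variance before they are fed to the UNet.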

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (bsz,),
                device=latents.device,
            )
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]

            # Predict the noise residual
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(
                    f"Unknown prediction type {noise_scheduler.config.prediction_type}"
                )

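            # With prior preservation, collate_fn concatenated instance and class examples
            # along the batch dimension, so the first half of model_pred/target is the
            # instance batch and the second half is the class (prior) batch.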
            if args.with_prior_preservation:
                # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)

                # Compute instance loss
                loss = (
                    F.mse_loss(model_pred.float(), target.float(), reduction="none")
                    .mean([1, 2, 3])
                    .mean()
                )

                # Compute prior loss
                prior_loss = F.mse_loss(
                    model_pred_prior.float(), target_prior.float(), reduction="mean"
                )

                # Add the prior loss to the instance loss.
                loss = loss + args.prior_loss_weight * prior_loss
            else:
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                params_to_clip = (
                    itertools.chain(unet.parameters(), text_encoder.parameters())
                    if args.train_text_encoder
                    else unet.parameters()
                )
                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            progress_bar.update(1)
            optimizer.zero_grad()

            global_step += 1

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.save_steps and global_step - last_save >= args.save_steps:
                    if accelerator.is_main_process:
                        # Newer versions of accelerate allow the 'keep_fp32_wrapper' arg. Without passing
                        # it, the models will be unwrapped, and when they are then used for further training,
                        # we will crash. Pass this, but only to newer versions of accelerate. Fixes
                        # https://github.com/huggingface/diffusers/issues/1566
                        accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
                            inspect.signature(
                                accelerator.unwrap_model
                            ).parameters.keys()
                        )
                        extra_args = (
                            {"keep_fp32_wrapper": True}
                            if accepts_keep_fp32_wrapper
                            else {}
                        )
                        pipeline = StableDiffusionPipeline.from_pretrained(
                            args.pretrained_model_name_or_path,
                            unet=accelerator.unwrap_model(unet, **extra_args),
                            text_encoder=accelerator.unwrap_model(
                                text_encoder, **extra_args
                            ),
                            revision=args.revision,
                        )

                        filename_unet = (
                            f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
                        )
                        filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt"
                        print(f"save weights {filename_unet}, {filename_text_encoder}")
                        save_lora_weight(pipeline.unet, filename_unet)
                        if args.train_text_encoder:
                            save_lora_weight(
                                pipeline.text_encoder,
                                filename_text_encoder,
                                target_replace_module=["CLIPAttention"],
                            )

                        for _up, _down in extract_lora_ups_down(pipeline.unet):
                            print(
                                "First Unet Layer's Up Weight is now : ",
                                _up.weight.data,
                            )
                            print(
                                "First Unet Layer's Down Weight is now : ",
                                _down.weight.data,
                            )
                            break
                        if args.train_text_encoder:
                            for _up, _down in extract_lora_ups_down(
                                pipeline.text_encoder,
                                target_replace_module=["CLIPAttention"],
                            ):
                                print(
                                    "First Text Encoder Layer's Up Weight is now : ",
                                    _up.weight.data,
                                )
                                print(
                                    "First Text Encoder Layer's Down Weight is now : ",
                                    _down.weight.data,
                                )
                                break

                        last_save = global_step

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=accelerator.unwrap_model(unet),
            text_encoder=accelerator.unwrap_model(text_encoder),
            revision=args.revision,
        )

        print("\n\nLora TRAINING DONE!\n\n")

        if args.output_format == "pt" or args.output_format == "both":
            save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
            if args.train_text_encoder:
                save_lora_weight(
                    pipeline.text_encoder,
                    args.output_dir + "/lora_weight.text_encoder.pt",
                    target_replace_module=["CLIPAttention"],
                )

        if args.output_format == "safe" or args.output_format == "both":
            loras = {}
            loras["unet"] = (pipeline.unet, {"CrossAttention", "Attention", "GEGLU"})
            if args.train_text_encoder:
                loras["text_encoder"] = (pipeline.text_encoder, {"CLIPAttention"})

            save_safeloras(loras, args.output_dir + "/lora_weight.safetensors")

        if args.push_to_hub:
            repo.push_to_hub(
                commit_message="End of training",
                blocking=False,
                auto_lfs_prune=True,
            )

    accelerator.end_training()

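# Illustrative inference sketch (not executed by this script): the saved LoRA weights can be
# applied to a base pipeline with helpers from the bundled lora_diffusion package, e.g.
#
#   from diffusers import StableDiffusionPipeline
#   from lora_diffusion import patch_pipe, tune_lora_scale
#
#   pipe = StableDiffusionPipeline.from_pretrained("<base model or path>").to("cuda")
#   patch_pipe(pipe, "<output_dir>/lora_weight.safetensors", patch_text=True, patch_unet=True)
#   tune_lora_scale(pipe.unet, 0.8)
#   image = pipe("a photo of sks person").images[0]
#
# Exact helper names and signatures follow the lora_diffusion version vendored in this repo.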
if __name__ == "__main__":
    args = parse_args()
    main(args)