shaokun's picture
WIP
a4d7b31
raw
history blame
6.33 kB
from typing import Literal, Union, Dict
import os
import shutil
import fire
from diffusers import StableDiffusionPipeline
from safetensors.torch import safe_open, save_file
import torch
from .lora import (
tune_lora_scale,
patch_pipe,
collapse_lora,
monkeypatch_remove_lora,
)
from .lora_manager import lora_join
from .to_ckpt_v2 import convert_to_ckpt
def _text_lora_path(path: str) -> str:
assert path.endswith(".pt"), "Only .pt files are supported"
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
def _merge_pt_loras(
    path_1: str,
    path_2: str,
    output_path: str,
    alpha_1: float,
    alpha_2: float,
    with_text_lora: bool,
) -> None:
    """Blend two ``.pt`` LoRA weight lists and save the result with ``torch.save``.

    The UNET file is always processed; the companion ``*.text_encoder.pt``
    files are processed too when ``with_text_lora`` is set and both exist.
    """
    jobs = [(path_1, path_2, "unet")]
    if with_text_lora:
        jobs.append(
            (_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")
        )

    for _path_1, _path_2, opt in jobs:
        print("Loading", _path_1, _path_2)

        if opt == "text_encoder":
            # Report the file that is actually missing (the original message
            # always named the first path, even when the second was absent).
            missing = next(
                (p for p in (_path_1, _path_2) if not os.path.exists(p)), None
            )
            if missing is not None:
                print(f"No text encoder found in {missing}, skipping...")
                continue

        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the installed torch version's defaults; only load trusted files.
        l1 = torch.load(_path_1)
        l2 = torch.load(_path_2)

        # Each list is walked in consecutive (even, odd) pairs; the entries of
        # the first list are updated in place and re-collected in order.
        out_list = []
        pairs_1 = zip(l1[::2], l1[1::2])
        pairs_2 = zip(l2[::2], l2[1::2])
        for (x1, y1), (x2, y2) in zip(pairs_1, pairs_2):
            x1.data = alpha_1 * x1.data + alpha_2 * x2.data
            y1.data = alpha_1 * y1.data + alpha_2 * y2.data
            out_list.extend((x1, y1))

        if opt == "unet":
            print("Saving merged UNET to", output_path)
            torch.save(out_list, output_path)
        else:
            text_output_path = _text_lora_path(output_path)
            print("Saving merged text encoder to", text_output_path)
            torch.save(out_list, text_output_path)


def _merge_safetensors_loras(
    path_1: str,
    path_2: str,
    output_path: str,
    alpha_1: float,
    alpha_2: float,
) -> None:
    """Blend two ``.safetensors`` LoRA files and save the result with ``save_file``.

    Tensors whose key starts with ``text_encoder`` or ``unet`` are combined as
    ``alpha_1 * t1 + alpha_2 * t2``; every other tensor is copied from the
    first file when present there, otherwise from the second.
    """
    with safe_open(path_1, framework="pt", device="cpu") as safeloras_1, safe_open(
        path_2, framework="pt", device="cpu"
    ) as safeloras_2:
        # metadata() can be None for files written without metadata.
        metadata = dict(safeloras_1.metadata() or {})
        metadata.update(safeloras_2.metadata() or {})

        # Build the key sets once instead of re-listing keys per iteration.
        keys_1 = set(safeloras_1.keys())
        keys_2 = set(safeloras_2.keys())

        ret_tensor = {}
        for key in keys_1 | keys_2:
            if key.startswith(("text_encoder", "unet")):
                # NOTE(review): a weight key present in only one of the files
                # is still requested from both, as in the original code.
                tens1 = safeloras_1.get_tensor(key)
                tens2 = safeloras_2.get_tensor(key)
                ret_tensor[key] = alpha_1 * tens1 + alpha_2 * tens2
            elif key in keys_1:
                ret_tensor[key] = safeloras_1.get_tensor(key)
            else:
                ret_tensor[key] = safeloras_2.get_tensor(key)

        save_file(ret_tensor, output_path, metadata)


def _load_collapsed_pipeline(path_1: str, path_2: str, alpha_1: float, **patch_kwargs):
    """Load the pipeline at ``path_1``, patch in the LoRA at ``path_2`` and collapse it.

    Returns the pipeline together with whatever ``patch_pipe`` returned.
    """
    loaded_pipeline = StableDiffusionPipeline.from_pretrained(
        path_1,
    ).to("cpu")
    patch_result = patch_pipe(loaded_pipeline, path_2, **patch_kwargs)

    collapse_lora(loaded_pipeline.unet, alpha_1)
    collapse_lora(loaded_pipeline.text_encoder, alpha_1)

    monkeypatch_remove_lora(loaded_pipeline.unet)
    monkeypatch_remove_lora(loaded_pipeline.text_encoder)
    return loaded_pipeline, patch_result


def add(
    path_1: str,
    path_2: str,
    output_path: str,
    alpha_1: float = 0.5,
    alpha_2: float = 0.5,
    mode: Literal[
        "lpl",
        "upl",
        "upl-ckpt-v2",
        "ljl",
    ] = "lpl",
    with_text_lora: bool = False,
):
    """Merge LoRA weights with another LoRA or with a Stable Diffusion pipeline.

    Args:
        path_1: First input. A LoRA file in ``lpl``/``ljl`` modes; in the
            ``upl`` modes, whatever ``StableDiffusionPipeline.from_pretrained``
            accepts.
        path_2: Second input, always a LoRA file.
        output_path: Destination file or folder, depending on ``mode``.
        alpha_1: Weight of the first LoRA (``lpl``) or the ratio passed to
            ``collapse_lora`` (``upl`` / ``upl-ckpt-v2``).
        alpha_2: Weight of the second LoRA; only used in ``lpl`` mode.
        mode: One of

            * ``"lpl"``: weighted sum of two LoRAs, both ``.pt`` or both
              ``.safetensors``.
            * ``"upl"``: collapse the LoRA into the pipeline and save it with
              ``save_pretrained``.
            * ``"upl-ckpt-v2"``: like ``"upl"`` but converted to a ``.ckpt``
              file, plus a textual-embedding ``.pt`` saved next to it.
            * ``"ljl"``: join two ``.safetensors`` LoRAs with ``lora_join``;
              the alphas are ignored.
        with_text_lora: In ``lpl`` mode with ``.pt`` inputs, also merge the
            companion ``*.text_encoder.pt`` files when they exist.

    Raises:
        ValueError: if ``mode`` is unknown, or if ``lpl`` mode is given inputs
            that are not both ``.pt`` or both ``.safetensors``.
        AssertionError: if a mode-specific file-extension requirement is not met.
    """
    print("Lora Add, mode " + mode)

    if mode == "lpl":
        if path_1.endswith(".pt") and path_2.endswith(".pt"):
            _merge_pt_loras(
                path_1, path_2, output_path, alpha_1, alpha_2, with_text_lora
            )
        elif path_1.endswith(".safetensors") and path_2.endswith(".safetensors"):
            _merge_safetensors_loras(path_1, path_2, output_path, alpha_1, alpha_2)
        else:
            # Previously this case fell through silently and wrote nothing.
            raise ValueError(
                "lpl mode requires both inputs to be .pt or both to be "
                f".safetensors, got {path_1} and {path_2}"
            )

    elif mode == "upl":
        print(
            f"Merging UNET/CLIP from {path_1} with LoRA from {path_2} to {output_path}. Merging ratio : {alpha_1}."
        )
        loaded_pipeline, _ = _load_collapsed_pipeline(path_1, path_2, alpha_1)
        loaded_pipeline.save_pretrained(output_path)

    elif mode == "upl-ckpt-v2":
        assert output_path.endswith(".ckpt"), "Only .ckpt files are supported"
        # Output file name without its ".ckpt" extension.
        name = os.path.basename(output_path)[0:-5]

        print(
            f"You will be using {name} as the token in A1111 webui. Make sure {name} is unique enough token."
        )

        loaded_pipeline, tok_dict = _load_collapsed_pipeline(
            path_1, path_2, alpha_1, patch_ti=False
        )

        _tmp_output = output_path + ".tmp"
        try:
            loaded_pipeline.save_pretrained(_tmp_output)
            convert_to_ckpt(_tmp_output, output_path, as_half=True)
        finally:
            # Remove the temporary folder even when saving/conversion fails.
            if os.path.isdir(_tmp_output):
                shutil.rmtree(_tmp_output)

        # Stack the token tensors in sorted-key order into a single tensor.
        keys = sorted(tok_dict.keys())
        tok_catted = torch.stack([tok_dict[k] for k in keys])
        # NOTE(review): the "*" placeholder and token id 265 are kept from the
        # original; presumably the layout the A1111 embedding loader expects.
        ret = {
            "string_to_token": {"*": torch.tensor(265)},
            "string_to_param": {"*": tok_catted},
            "name": name,
        }

        torch.save(ret, output_path[:-5] + ".pt")
        print(
            f"Textual embedding saved as {output_path[:-5]}.pt, put it in the embedding folder and use it as {name} in A1111 repo, "
        )

    elif mode == "ljl":
        print("Using Join mode : alpha will not have an effect here.")
        assert path_1.endswith(".safetensors") and path_2.endswith(
            ".safetensors"
        ), "Only .safetensors files are supported"

        with safe_open(
            path_1, framework="pt", device="cpu"
        ) as safeloras_1, safe_open(
            path_2, framework="pt", device="cpu"
        ) as safeloras_2:
            total_tensor, total_metadata, _, _ = lora_join(
                [safeloras_1, safeloras_2]
            )
            save_file(total_tensor, output_path, total_metadata)

    else:
        print("Unknown mode", mode)
        raise ValueError(f"Unknown mode {mode}")
def main():
    """Command-line entry point: hand the ``add`` function to ``fire.Fire``."""
    fire.Fire(component=add)