from lora_diffusion.cli_lora_add import *
from lora_diffusion.lora import *
from lora_diffusion.to_ckpt_v2 import *

def monkeypatch_or_replace_safeloras(models, safeloras):
    loras = parse_safeloras(safeloras)

    for name, (lora, ranks, target) in loras.items():
        model = getattr(models, name, None)

        if not model:
            print(f"No model provided for {name}, contained in Lora")
            continue

        monkeypatch_or_replace_lora_extended(model, lora, target, ranks)
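
# Usage sketch: the module names stored in the safetensors metadata (typically
# "unet" and "text_encoder") are looked up as attributes on `models`, so passing
# a diffusers StableDiffusionPipeline works, e.g.:
#
#   monkeypatch_or_replace_safeloras(pipe, safeloras)
#
# where `safeloras` is the {"metadata": ..., "weights": ...} dict described in
# parse_safeloras below.
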
def parse_safeloras(
    safeloras,
) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
    """
    Converts a loaded safetensor file that contains a set of module Loras
    into Parameters and other information.

    Output is a dictionary of {
        "module name": (
            [list of weights],
            [list of ranks],
            target_replacement_modules
        )
    }
    """
    loras = {}

    metadata = safeloras["metadata"]
    safeloras_ = safeloras["weights"]
    get_name = lambda k: k.split(":")[0]

    keys = list(safeloras_.keys())
    keys.sort(key=get_name)

    for name, module_keys in groupby(keys, get_name):
        info = metadata.get(name)

        if not info:
            raise ValueError(
                f"Tensor {name} has no metadata - is this a Lora safetensor?"
            )

        # Skip Textual Inversion embeds; those are handled by parse_safeloras_embeds
        if info == EMBED_FLAG:
            continue

        # The metadata value is a JSON list of target replacement modules
        target = json.loads(info)

        # Each index has an up/down pair of tensors, so preallocate the lists
        module_keys = list(module_keys)
        ranks = [4] * (len(module_keys) // 2)
        weights = [None] * len(module_keys)

        for key in module_keys:
            # Keys look like "<module name>:<index>:<up|down>"
            _, idx, direction = key.split(":")
            idx = int(idx)

            # Read the rank for this index from the metadata
            ranks[idx] = int(metadata[f"{name}:{idx}:rank"])

            # Interleave the weights as [up, down, up, down, ...]
            idx = idx * 2 + (1 if direction == "down" else 0)
            weights[idx] = nn.parameter.Parameter(safeloras_[key])

        loras[name] = (weights, ranks, target)

    return loras
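
# Sketch of building the `safeloras` dict these helpers expect from a LoRA
# .safetensors file (assumes a file produced by lora_diffusion and the
# `safetensors` package; the path is a placeholder):
#
#   from safetensors import safe_open
#
#   with safe_open("lora_weight.safetensors", framework="pt", device="cpu") as f:
#       safeloras = {
#           "metadata": f.metadata() or {},
#           "weights": {key: f.get_tensor(key) for key in f.keys()},
#       }
#
#   loras = parse_safeloras(safeloras)
#   embeds = parse_safeloras_embeds(safeloras)
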
def parse_safeloras_embeds(
    safeloras,
) -> Dict[str, torch.Tensor]:
    """
    Converts a loaded safetensor file that contains Textual Inversion embeds into
    a dictionary of embed_token: Tensor
    """
    embeds = {}
    metadata = safeloras["metadata"]
    safeloras_ = safeloras["weights"]

    for key in safeloras_.keys():
        # Only handle tensors flagged as Textual Inversion embeds
        meta = None
        if key in metadata:
            meta = metadata[key]
        if not meta or meta != EMBED_FLAG:
            continue

        embeds[key] = safeloras_[key]

    return embeds
def patch_pipe(
    pipe,
    maybe_unet_path,
    token: Optional[str] = None,
    r: int = 4,
    patch_unet=True,
    patch_text=True,
    patch_ti=True,
    idempotent_token=True,
    unet_target_replace_module=DEFAULT_TARGET_REPLACE,
    text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
):
    # In this modified helper, maybe_unet_path is already a loaded safeloras dict
    # ({"metadata": ..., "weights": ...}), not a filesystem path.
    safeloras = maybe_unet_path
    monkeypatch_or_replace_safeloras(pipe, safeloras)
    tok_dict = parse_safeloras_embeds(safeloras)

    if patch_ti:
        apply_learned_embed_in_clip(
            tok_dict,
            pipe.text_encoder,
            pipe.tokenizer,
            token=token,
            idempotent=idempotent_token,
        )
    return tok_dict
def lora_convert(model_path, as_half):
    """
    Modified version of lora_diffusion.to_ckpt_v2.convert_to_ckpt
    """
    assert model_path is not None, "Must provide a model path!"

    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
    text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")

    # Convert the UNet model
    unet_state_dict = torch.load(unet_path, map_location="cpu")
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {
        "model.diffusion_model." + k: v for k, v in unet_state_dict.items()
    }

    # Convert the VAE model
    vae_state_dict = torch.load(vae_path, map_location="cpu")
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Convert the text encoder model
    text_enc_dict = torch.load(text_enc_path, map_location="cpu")
    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
    text_enc_dict = {
        "cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()
    }

    # Put everything together in the new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if as_half:
        state_dict = {k: v.half() for k, v in state_dict.items()}

    return state_dict
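
# Sketch of writing the returned state dict out as a .ckpt (assumes the common
# {"state_dict": ...} wrapper; paths are placeholders):
#
#   state_dict = lora_convert("path/to/diffusers_model", as_half=True)
#   torch.save({"state_dict": state_dict}, "merged_model.ckpt")
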
def merge(
    path_1: str,
    path_2: str,
    alpha_1: float = 0.5,
):
    loaded_pipeline = StableDiffusionPipeline.from_pretrained(
        path_1,
    ).to("cpu")

    # Apply the LoRA to the pipeline, collapse it into the base weights at
    # strength alpha_1, then strip the LoRA modules again.
    tok_dict = patch_pipe(loaded_pipeline, path_2, patch_ti=False)
    collapse_lora(loaded_pipeline.unet, alpha_1)
    collapse_lora(loaded_pipeline.text_encoder, alpha_1)

    monkeypatch_remove_lora(loaded_pipeline.unet)
    monkeypatch_remove_lora(loaded_pipeline.text_encoder)

    # Round-trip through a temporary diffusers checkpoint to build the
    # LDM-format state dict.
    _tmp_output = "./merge.tmp"

    loaded_pipeline.save_pretrained(_tmp_output)
    state_dict = lora_convert(_tmp_output, as_half=True)

    shutil.rmtree(_tmp_output)

    # Pack any learned embeds into a textual-inversion style payload
    keys = sorted(tok_dict.keys())
    tok_catted = torch.stack([tok_dict[k] for k in keys])
    ret = {
        "string_to_token": {"*": torch.tensor(265)},
        "string_to_param": {"*": tok_catted},
        "name": "",
    }

    return state_dict, ret
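
# End-to-end sketch (paths and names are placeholders). Note that merge() hands
# path_2 straight to patch_pipe()/parse_safeloras(), so despite the `str`
# annotation it must already be the loaded {"metadata": ..., "weights": ...}
# safeloras dict built as in the sketch after parse_safeloras:
#
#   state_dict, embeds = merge("runwayml/stable-diffusion-v1-5", safeloras)
#   torch.save({"state_dict": state_dict}, "merged.ckpt")
#   torch.save(embeds, "merged_embedding.pt")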