LucidDreamer / lora_diffusion /cli_pt_to_safetensors.py
haodongli's picture
init
916b126
import os
import fire
import torch
from lora_diffusion import (
DEFAULT_TARGET_REPLACE,
TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
UNET_DEFAULT_TARGET_REPLACE,
convert_loras_to_safeloras_with_embeds,
safetensors_available,
)
# Default target-module set to use for a Lora, keyed by the model name that
# is parsed out of the input file name. Names that are not listed here fall
# back to DEFAULT_TARGET_REPLACE at the lookup site.
_target_by_name = dict(
    unet=UNET_DEFAULT_TARGET_REPLACE,
    text_encoder=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
)
def _lora_name_from_path(path):
    """Return the model name encoded in a Lora file name ("unet" if absent)."""
    name_parts = os.path.basename(path).split(".")
    # "lora_weight.pt" has no name segment and therefore means the unet, while
    # "lora_weight.text_encoder.pt" carries the name right before the extension.
    return name_parts[-2] if len(name_parts) > 2 else "unet"


def convert(*paths, outpath, overwrite=False, **settings):
    """
    Convert one or more pytorch Lora and/or Textual Embedding files into a
    single safetensors file.

    Pass all the input paths as positional arguments. Each file is loaded and
    classified automatically: a file whose content is a dict is treated as
    Textual Embeddings, anything else is treated as a Lora model.

    For Lora models, their name is taken from the file name, i.e.

        "lora_weight.pt" => unet
        "lora_weight.text_encoder.pt" => text_encoder

    You can also set target_modules and/or rank by providing an argument
    prefixed by the name (rank defaults to 4). Settings whose prefix matches
    none of the loaded models are ignored.

    So a complete example might be something like:

    ```
    python -m lora_diffusion.cli_pt_to_safetensors lora_weight.* --outpath lora_weight.safetensor --unet.rank 8
    ```

    Args:
        *paths: Input ``.pt`` files (Lora weights and/or textual embeddings).
        outpath: Destination path of the safetensors file (keyword-only).
        overwrite: Allow replacing ``outpath`` when it already exists.
        **settings: Per-model overrides named ``"<name>.<setting>"``, e.g.
            ``"unet.rank"`` or ``"text_encoder.target_modules"``.

    Raises:
        ValueError: If ``outpath`` already exists and ``overwrite`` is not
            true, or if no input paths were given.
    """
    if os.path.exists(outpath) and not overwrite:
        raise ValueError(
            f"Output path {outpath} already exists, and overwrite is not True"
        )

    # Fail early instead of silently writing an output file with no content.
    if not paths:
        raise ValueError("No input paths given, nothing to convert")

    modelmap = {}
    embeds = {}

    for path in paths:
        # NOTE(security): torch.load is pickle-based and can execute arbitrary
        # code while loading -- only convert files from a trusted source.
        # map_location="cpu" lets checkpoints that were saved from a GPU be
        # converted on a machine that does not have that device.
        data = torch.load(path, map_location="cpu")

        if isinstance(data, dict):
            print(f"Loading textual inversion embeds {data.keys()} from {path}")
            embeds.update(data)
            continue

        name = _lora_name_from_path(path)
        prefix = f"{name}."

        # Defaults first, then any "<name>.<setting>" overrides on top.
        model_settings = {
            "target_modules": _target_by_name.get(name, DEFAULT_TARGET_REPLACE),
            "rank": 4,
            **{
                key[len(prefix):]: value
                for key, value in settings.items()
                if key.startswith(prefix)
            },
        }

        print(f"Loading Lora for {name} from {path} with settings {model_settings}")

        # Only one Lora per name fits in the output; make the replacement
        # visible rather than silently dropping the earlier file.
        if name in modelmap:
            print(
                f"Warning: Lora for {name} from {modelmap[name][0]} "
                f"is replaced by {path}"
            )

        modelmap[name] = (
            path,
            model_settings["target_modules"],
            model_settings["rank"],
        )

    convert_loras_to_safeloras_with_embeds(modelmap, embeds, outpath)
def main():
    """Command-line entry point: hand `convert` to python-fire."""
    fire.Fire(component=convert)
# Run the CLI only when this file is executed as a script / with `python -m`,
# not when it is imported as a module.
if __name__ == "__main__":
    main()