File size: 3,357 Bytes
cff1674 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import os, sys
sys.path.insert(0, os.getcwd())
import argparse
def get_args(argv=None):
    """Parse the command-line options of the merge script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``base_model``,
        ``lycoris_model``, ``output_name``, ``is_v2``, ``is_sdxl``,
        ``device``, ``dtype`` and ``weight``.
    """
    parser = argparse.ArgumentParser()

    # Positional arguments are always required by argparse, so they carry
    # no ``default`` (one would never be used).
    parser.add_argument(
        "base_model", help="The model you want to merge with loha", type=str
    )
    parser.add_argument(
        "lycoris_model",
        help="the lyco model you want to merge into sd model",
        type=str,
    )
    parser.add_argument("output_name", help="the output model", type=str)

    # ``store_true`` flags default to False on their own.
    parser.add_argument(
        "--is_v2",
        help="Your base model is sd v2 or not",
        action="store_true",
    )
    parser.add_argument(
        "--is_sdxl",
        help="Your base/db model is sdxl or not",
        action="store_true",
    )
    parser.add_argument(
        "--device",
        help="Which device you want to use to merge the weight",
        default="cpu",
        type=str,
    )
    parser.add_argument("--dtype", help="dtype to save", default="float", type=str)
    parser.add_argument(
        "--weight",
        help="weight for the lyco model to merge",
        # A real float, matching ``type=float``, instead of a string that
        # argparse has to convert implicitly.
        default=1.0,
        type=float,
    )
    return parser.parse_args(argv)
# Parse the command line once, at import time, and expose the resulting
# namespace under both module-level names (``ARGS`` and ``args``).
# NOTE(review): this runs before the lycoris/torch imports below, presumably
# so that ``--help`` and usage errors do not pay for those imports -- confirm.
ARGS = get_args()
args = ARGS
from lycoris.utils import merge
from lycoris.kohya.model_utils import (
load_models_from_stable_diffusion_checkpoint,
save_stable_diffusion_checkpoint,
load_file,
)
from lycoris.kohya.sdxl_model_util import (
load_models_from_sdxl_checkpoint,
save_stable_diffusion_checkpoint as save_sdxl_checkpoint,
)
import torch
def _resolve_dtype(name):
    """Translate a user-supplied dtype name into a ``torch.dtype``.

    Accepts the long names (``float``, ``float16``, ``float32``,
    ``float64``, ``bfloat``, ``bfloat16``) as well as the ``fp``/``bf``
    abbreviations (``fp16``, ``fp32``, ``fp64``, ``bf``, ``bf16``).

    Raises:
        ValueError: if ``name`` does not map to a known dtype.
    """
    canonical = name.replace("fp", "float")
    # Only expand the "bf" abbreviation; blindly replacing it would corrupt
    # an already spelled-out "bfloat16" into "bfloatloat16".
    if "bfloat" not in canonical:
        canonical = canonical.replace("bf", "bfloat")

    known = {
        "float": torch.float,
        "float16": torch.float16,
        "float32": torch.float32,
        "float64": torch.float64,
        "bfloat": torch.bfloat16,
        "bfloat16": torch.bfloat16,
    }
    dtype = known.get(canonical)
    if dtype is None:
        # Report what the user actually typed, not the failed lookup result.
        raise ValueError(f'Cannot Find the dtype "{name}"')
    return dtype


@torch.no_grad()
def main():
    """Merge the LyCORIS weights into the base checkpoint and save the result.

    Reads all settings from the module-level ``ARGS`` namespace: loads the
    base model (SDXL or regular Stable Diffusion loader depending on
    ``--is_sdxl``), loads the LyCORIS state dict, calls ``merge`` with the
    requested weight and device, and writes the merged checkpoint to
    ``ARGS.output_name`` in the requested dtype.

    Raises:
        ValueError: if ``--dtype`` does not name a supported dtype.
    """
    # Validate the dtype before any model is loaded so a typo fails fast.
    dtype = _resolve_dtype(ARGS.dtype)

    if ARGS.is_sdxl:
        base = load_models_from_sdxl_checkpoint(
            None, ARGS.base_model, map_location=ARGS.device
        )
    else:
        base = load_models_from_stable_diffusion_checkpoint(
            ARGS.is_v2, ARGS.base_model
        )

    if ARGS.lycoris_model.rsplit(".", 1)[-1] == "safetensors":
        lyco = load_file(ARGS.lycoris_model)
    else:
        # NOTE(review): torch.load unpickles the file; only use it on
        # LyCORIS files from a trusted source.
        # map_location keeps a file saved from GPU tensors loadable on a
        # host without that GPU.
        lyco = torch.load(ARGS.lycoris_model, map_location=ARGS.device)

    # The two loaders return differently laid-out tuples: the SDXL one has
    # two text encoders at indices 0 and 1 and the UNet at index 3, the
    # regular one has a single text encoder at 0 and the UNet at 2.
    if ARGS.is_sdxl:
        base_tes = [base[0], base[1]]
        base_unet = base[3]
    else:
        base_tes = [base[0]]
        base_unet = base[2]

    merge(base_tes, base_unet, lyco, ARGS.weight, ARGS.device)

    if ARGS.is_sdxl:
        # NOTE(review): the positional 0, 0, None presumably stand for
        # epoch/step counters and checkpoint info, and base[2] for the VAE
        # -- confirm against sdxl_model_util.save_stable_diffusion_checkpoint.
        save_sdxl_checkpoint(
            ARGS.output_name,
            base[0].cpu(),
            base[1].cpu(),
            base[3].cpu(),
            0,
            0,
            None,
            base[2],
            getattr(base[1], "logit_scale", None),
            dtype,
        )
    else:
        # NOTE(review): None, 0, 0 presumably are the source checkpoint path
        # and epoch/step counters, and base[1] the VAE -- confirm against
        # model_utils.save_stable_diffusion_checkpoint.
        save_stable_diffusion_checkpoint(
            ARGS.is_v2,
            ARGS.output_name,
            base[0].cpu(),
            base[2].cpu(),
            None,
            0,
            0,
            dtype,
            base[1],
        )
# Script entry point: run the merge only when this file is executed directly.
if __name__ == "__main__":
    main()
|