File size: 5,202 Bytes
cff1674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import os, sys

# Put the current working directory at the front of the module search path.
# NOTE(review): presumably so that `lycoris` can be imported from a local
# checkout when the script is launched from the repository root -- confirm.
sys.path.insert(0, os.getcwd())
import argparse


def get_args(argv=None) -> argparse.Namespace:
    """Build the command-line parser for the extraction script and parse it.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which is the
            behaviour existing callers rely on.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse when the arguments are invalid
            (including an unsupported ``--mode``) or when ``-h/--help`` is
            requested.
    """
    parser = argparse.ArgumentParser()

    # The three positionals are required. argparse never applies ``default``
    # to a required positional, so no defaults are declared for them.
    parser.add_argument(
        "base_model",
        help="The model which use it to train the dreambooth model",
        type=str,
    )
    parser.add_argument(
        "db_model",
        help="the dreambooth model you want to extract the locon",
        type=str,
    )
    parser.add_argument("output_name", help="the output model", type=str)

    # Model family switches.
    parser.add_argument(
        "--is_v2",
        help="Your base/db model is sd v2 or not",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--is_sdxl",
        help="Your base/db model is sdxl or not",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--device",
        help="Which device you want to use to extract the locon",
        default="cpu",
        type=str,
    )

    # Restricting the choices makes a mistyped mode fail here, with a usage
    # message, instead of as a KeyError after the checkpoints were loaded.
    parser.add_argument(
        "--mode",
        help=(
            'extraction mode, can be "full", "fixed", "threshold", "ratio", "quantile". '
            'If not "fixed", linear_dim and conv_dim will be ignored'
        ),
        default="fixed",
        choices=["full", "fixed", "threshold", "ratio", "quantile"],
        type=str,
    )
    parser.add_argument(
        "--safetensors",
        help="use safetensors to save locon model",
        default=False,
        action="store_true",
    )

    # Per-mode parameters, one pair (linear / conv) for each mode.
    parser.add_argument(
        "--linear_dim",
        help="network dim for linear layer in fixed mode",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--conv_dim",
        help="network dim for conv layer in fixed mode",
        default=1,
        type=int,
    )
    parser.add_argument(
        "--linear_threshold",
        help="singular value threshold for linear layer in threshold mode",
        default=0.0,
        type=float,
    )
    parser.add_argument(
        "--conv_threshold",
        help="singular value threshold for conv layer in threshold mode",
        default=0.0,
        type=float,
    )
    parser.add_argument(
        "--linear_ratio",
        help="singular ratio for linear layer in ratio mode",
        default=0.0,
        type=float,
    )
    parser.add_argument(
        "--conv_ratio",
        help="singular ratio for conv layer in ratio mode",
        default=0.0,
        type=float,
    )
    parser.add_argument(
        "--linear_quantile",
        help="singular value quantile for linear layer quantile mode",
        default=1.0,
        type=float,
    )
    parser.add_argument(
        "--conv_quantile",
        help="singular value quantile for conv layer quantile mode",
        default=1.0,
        type=float,
    )

    # Extra extraction options.
    parser.add_argument(
        "--use_sparse_bias",
        help="enable sparse bias",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--sparsity", help="sparsity for sparse bias", default=0.98, type=float
    )
    parser.add_argument(
        "--disable_cp",
        help="don't use cp decomposition",
        default=False,
        action="store_true",
    )
    return parser.parse_args(argv)


# Parse the command line at module import time, before the lycoris / torch
# imports further down are executed.
# NOTE(review): presumably done this early so that `--help` and argument
# errors return without paying for those imports -- confirm.
ARGS = get_args()


from lycoris.utils import extract_diff
from lycoris.kohya.model_utils import load_models_from_stable_diffusion_checkpoint
from lycoris.kohya.sdxl_model_util import load_models_from_sdxl_checkpoint

import torch
from safetensors.torch import save_file


def _pick_mode_params(args):
    """Return the ``(linear, conv)`` extraction parameters for ``args.mode``.

    Raises:
        KeyError: If ``args.mode`` is not one of the supported modes.
    """
    per_mode = {
        "fixed": (args.linear_dim, args.conv_dim),
        "threshold": (args.linear_threshold, args.conv_threshold),
        "ratio": (args.linear_ratio, args.conv_ratio),
        "quantile": (args.linear_quantile, args.conv_quantile),
        "full": (None, None),
    }
    return per_mode[args.mode]


def _split_models(models, is_sdxl):
    """Split a checkpoint loader result into ``(text_encoders, unet)``.

    For SDXL the first two entries are used as text encoders and entry 3 as
    the UNet; otherwise entry 0 is the single text encoder and entry 2 the
    UNet.
    """
    if is_sdxl:
        return [models[0], models[1]], models[3]
    return [models[0]], models[2]


def main():
    """Extract the difference between the base and dreambooth models and save it.

    Reads its configuration from the module-level ``ARGS`` namespace, loads
    both checkpoints, calls ``extract_diff`` and writes the resulting state
    dict to ``args.output_name`` (with safetensors when ``--safetensors`` is
    set, otherwise with ``torch.save``).

    Raises:
        KeyError: If ``args.mode`` is not a supported extraction mode.
    """
    args = ARGS

    # Resolve the mode parameters before touching the checkpoints so that an
    # unsupported mode fails immediately rather than after both models have
    # been loaded.
    linear_mode_param, conv_mode_param = _pick_mode_params(args)

    if args.is_sdxl:
        base = load_models_from_sdxl_checkpoint(None, args.base_model, args.device)
        db = load_models_from_sdxl_checkpoint(None, args.db_model, args.device)
    else:
        # NOTE(review): args.device is not forwarded to this loader, unlike the
        # SDXL branch -- confirm whether that is intentional.
        base = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.base_model)
        db = load_models_from_stable_diffusion_checkpoint(args.is_v2, args.db_model)

    base_tes, base_unet = _split_models(base, args.is_sdxl)
    db_tes, db_unet = _split_models(db, args.is_sdxl)

    state_dict = extract_diff(
        base_tes,
        db_tes,
        base_unet,
        db_unet,
        args.mode,
        linear_mode_param,
        conv_mode_param,
        args.device,
        args.use_sparse_bias,
        args.sparsity,
        not args.disable_cp,  # CP decomposition is on unless explicitly disabled
    )

    if args.safetensors:
        save_file(state_dict, args.output_name)
    else:
        torch.save(state_dict, args.output_name)


# Run the extraction only when this file is executed as a script. Note that
# the command line itself is already parsed at import time (see ARGS).
if __name__ == "__main__":
    main()