Upload swap_vae.py with huggingface_hub
Browse files- swap_vae.py +45 -0
swap_vae.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import click
|
3 |
+
|
4 |
+
def overwrite_first_stage(model_state_dict, vae_state_dict):
|
5 |
+
"""
|
6 |
+
Overwrite the First Stage Decoders.
|
7 |
+
|
8 |
+
From the new repo:
|
9 |
+
To keep compatibility with existing models,
|
10 |
+
only the decoder part was finetuned;
|
11 |
+
the checkpoints can be used as a drop-in replacement
|
12 |
+
for the existing autoencoder.
|
13 |
+
|
14 |
+
Sounds like we only need to change the decoder weights.
|
15 |
+
"""
|
16 |
+
|
17 |
+
target = "first_stage_model."
|
18 |
+
for key in model_state_dict.keys():
|
19 |
+
if target in key and ("decoder" in key or "encoder" in key):
|
20 |
+
matching_name = key.split(target)[1]
|
21 |
+
|
22 |
+
# double check this weight exists in the new vae
|
23 |
+
if matching_name in vae_state_dict:
|
24 |
+
model_state_dict[key] = vae_state_dict[matching_name]
|
25 |
+
else:
|
26 |
+
print(f"{key} Does not exist in the new VAE weights!")
|
27 |
+
|
28 |
+
return model_state_dict
|
29 |
+
|
30 |
+
@click.command()
|
31 |
+
@click.option("--base-model", type=str, default="sd-v1-5.ckpt")
|
32 |
+
@click.option("--vae", type=str, default="new_vae.ckpt")
|
33 |
+
@click.option("--output-name", type=str, default="sd-v1-5-new-vae.ckpt")
|
34 |
+
def main(base_model, vae, output_name):
|
35 |
+
print("hello")
|
36 |
+
model = torch.load(base_model)
|
37 |
+
new_vae = torch.load(vae)
|
38 |
+
|
39 |
+
model["state_dict"] = overwrite_first_stage(model["state_dict"], new_vae["state_dict"])
|
40 |
+
|
41 |
+
print(f"Saving to {output_name}")
|
42 |
+
torch.save(model, output_name)
|
43 |
+
|
44 |
+
|
45 |
+
main()
|