Upload 63 files
This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
- diffsynth/__init__.py +6 -0
- diffsynth/controlnets/__init__.py +2 -0
- diffsynth/controlnets/controlnet_unit.py +53 -0
- diffsynth/controlnets/processors.py +51 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/video.py +148 -0
- diffsynth/extensions/ESRGAN/__init__.py +118 -0
- diffsynth/extensions/FastBlend/__init__.py +63 -0
- diffsynth/extensions/FastBlend/api.py +397 -0
- diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
- diffsynth/extensions/FastBlend/data.py +146 -0
- diffsynth/extensions/FastBlend/patch_match.py +298 -0
- diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
- diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
- diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
- diffsynth/extensions/FastBlend/runners/fast.py +141 -0
- diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
- diffsynth/extensions/RIFE/__init__.py +241 -0
- diffsynth/models/__init__.py +482 -0
- diffsynth/models/attention.py +89 -0
- diffsynth/models/hunyuan_dit.py +451 -0
- diffsynth/models/hunyuan_dit_text_encoder.py +161 -0
- diffsynth/models/sd_controlnet.py +587 -0
- diffsynth/models/sd_ipadapter.py +56 -0
- diffsynth/models/sd_lora.py +60 -0
- diffsynth/models/sd_motion.py +198 -0
- diffsynth/models/sd_text_encoder.py +320 -0
- diffsynth/models/sd_unet.py +0 -0
- diffsynth/models/sd_vae_decoder.py +332 -0
- diffsynth/models/sd_vae_encoder.py +278 -0
- diffsynth/models/sdxl_ipadapter.py +121 -0
- diffsynth/models/sdxl_motion.py +103 -0
- diffsynth/models/sdxl_text_encoder.py +757 -0
- diffsynth/models/sdxl_unet.py +0 -0
- diffsynth/models/sdxl_vae_decoder.py +15 -0
- diffsynth/models/sdxl_vae_encoder.py +15 -0
- diffsynth/models/svd_image_encoder.py +504 -0
- diffsynth/models/svd_unet.py +0 -0
- diffsynth/models/svd_vae_decoder.py +577 -0
- diffsynth/models/svd_vae_encoder.py +138 -0
- diffsynth/models/tiler.py +106 -0
- diffsynth/pipelines/__init__.py +6 -0
- diffsynth/pipelines/dancer.py +174 -0
- diffsynth/pipelines/hunyuan_dit.py +298 -0
- diffsynth/pipelines/stable_diffusion.py +167 -0
- diffsynth/pipelines/stable_diffusion_video.py +356 -0
- diffsynth/pipelines/stable_diffusion_xl.py +175 -0
- diffsynth/pipelines/stable_diffusion_xl_video.py +190 -0
- diffsynth/pipelines/stable_video_diffusion.py +307 -0
- diffsynth/processors/FastBlend.py +142 -0
diffsynth/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .data import *
from .models import *
from .prompts import *
from .schedulers import *
from .pipelines import *
from .controlnets import *
diffsynth/controlnets/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
from .processors import Annotator
diffsynth/controlnets/controlnet_unit.py
ADDED
@@ -0,0 +1,53 @@
import torch
import numpy as np
from .processors import Processor_id


class ControlNetConfigUnit:
    def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
        self.processor_id = processor_id
        self.model_path = model_path
        self.scale = scale


class ControlNetUnit:
    def __init__(self, processor, model, scale=1.0):
        self.processor = processor
        self.model = model
        self.scale = scale


class MultiControlNetManager:
    def __init__(self, controlnet_units=[]):
        self.processors = [unit.processor for unit in controlnet_units]
        self.models = [unit.model for unit in controlnet_units]
        self.scales = [unit.scale for unit in controlnet_units]

    def process_image(self, image, processor_id=None):
        if processor_id is None:
            processed_image = [processor(image) for processor in self.processors]
        else:
            processed_image = [self.processors[processor_id](image)]
        processed_image = torch.concat([
            torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
            for image_ in processed_image
        ], dim=0)
        return processed_image

    def __call__(
        self,
        sample, timestep, encoder_hidden_states, conditionings,
        tiled=False, tile_size=64, tile_stride=32
    ):
        res_stack = None
        for conditioning, model, scale in zip(conditionings, self.models, self.scales):
            res_stack_ = model(
                sample, timestep, encoder_hidden_states, conditioning,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
            )
            res_stack_ = [res * scale for res in res_stack_]
            if res_stack is None:
                res_stack = res_stack_
            else:
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
        return res_stack
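A minimal usage sketch for the classes above, assuming a ControlNet model object has already been loaded elsewhere in the pipeline; the variable names and file path below are illustrative, not part of this commit:

    # Hypothetical wiring; `controlnet_model` and "input.png" are placeholders.
    from PIL import Image

    annotator = Annotator("canny")                      # preprocessor for the condition image
    unit = ControlNetUnit(processor=annotator, model=controlnet_model, scale=0.8)
    manager = MultiControlNetManager([unit])

    # (num_units, C, H, W) tensor of conditioning images in [0, 1]
    conditionings = manager.process_image(Image.open("input.png"))
    # Inside a denoising step the manager is then called like a single ControlNet:
    # res_stack = manager(sample, timestep, encoder_hidden_states, conditionings)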
diffsynth/controlnets/processors.py
ADDED
@@ -0,0 +1,51 @@
from typing_extensions import Literal, TypeAlias
import warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from controlnet_aux.processor import (
        CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
    )


Processor_id: TypeAlias = Literal[
    "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
]

class Annotator:
    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
        if processor_id == "canny":
            self.processor = CannyDetector()
        elif processor_id == "depth":
            self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
        elif processor_id == "softedge":
            self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
        elif processor_id == "lineart":
            self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
        elif processor_id == "lineart_anime":
            self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
        elif processor_id == "openpose":
            self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
        elif processor_id == "tile":
            self.processor = None
        else:
            raise ValueError(f"Unsupported processor_id: {processor_id}")

        self.processor_id = processor_id
        self.detect_resolution = detect_resolution

    def __call__(self, image):
        width, height = image.size
        if self.processor_id == "openpose":
            kwargs = {
                "include_body": True,
                "include_hand": True,
                "include_face": True
            }
        else:
            kwargs = {}
        if self.processor is not None:
            detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
            image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
        image = image.resize((width, height))
        return image
diffsynth/data/__init__.py
ADDED
@@ -0,0 +1 @@
from .video import VideoData, save_video, save_frames
diffsynth/data/video.py
ADDED
@@ -0,0 +1,148 @@
import imageio, os
import numpy as np
from PIL import Image
from tqdm import tqdm


class LowMemoryVideo:
    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        return self.reader.count_frames()

    def __getitem__(self, item):
        return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")

    def __del__(self):
        self.reader.close()


def split_file_name(file_name):
    result = []
    number = -1
    for i in file_name:
        if ord(i)>=ord("0") and ord(i)<=ord("9"):
            if number == -1:
                number = 0
            number = number*10 + ord(i) - ord("0")
        else:
            if number != -1:
                result.append(number)
                number = -1
            result.append(i)
    if number != -1:
        result.append(number)
    result = tuple(result)
    return result


def search_for_images(folder):
    file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
    file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
    file_list = [i[1] for i in sorted(file_list)]
    file_list = [os.path.join(folder, i) for i in file_list]
    return file_list


class LowMemoryImageFolder:
    def __init__(self, folder, file_list=None):
        if file_list is None:
            self.file_list = search_for_images(folder)
        else:
            self.file_list = [os.path.join(folder, file_name) for file_name in file_list]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        return Image.open(self.file_list[item]).convert("RGB")

    def __del__(self):
        pass


def crop_and_resize(image, height, width):
    image = np.array(image)
    image_height, image_width, _ = image.shape
    if image_height / image_width < height / width:
        croped_width = int(image_height / height * width)
        left = (image_width - croped_width) // 2
        image = image[:, left: left+croped_width]
        image = Image.fromarray(image).resize((width, height))
    else:
        croped_height = int(image_width / width * height)
        left = (image_height - croped_height) // 2
        image = image[left: left+croped_height, :]
        image = Image.fromarray(image).resize((width, height))
    return image


class VideoData:
    def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None
        self.set_shape(height, width)

    def raw_data(self):
        frames = []
        for i in range(self.__len__()):
            frames.append(self.__getitem__(i))
        return frames

    def set_length(self, length):
        self.length = length

    def set_shape(self, height, width):
        self.height = height
        self.width = width

    def __len__(self):
        if self.length is None:
            return len(self.data)
        else:
            return self.length

    def shape(self):
        if self.height is not None and self.width is not None:
            return self.height, self.width
        else:
            height, width, _ = self.__getitem__(0).shape
            return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        width, height = frame.size
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = crop_and_resize(frame, self.height, self.width)
        return frame

    def __del__(self):
        pass

    def save_images(self, folder):
        os.makedirs(folder, exist_ok=True)
        for i in tqdm(range(self.__len__()), desc="Saving images"):
            frame = self.__getitem__(i)
            frame.save(os.path.join(folder, f"{i}.png"))


def save_video(frames, save_path, fps, quality=9):
    writer = imageio.get_writer(save_path, fps=fps, quality=quality)
    for frame in tqdm(frames, desc="Saving video"):
        frame = np.array(frame)
        writer.append_data(frame)
    writer.close()

def save_frames(frames, save_path):
    os.makedirs(save_path, exist_ok=True)
    for i, frame in enumerate(tqdm(frames, desc="Saving images")):
        frame.save(os.path.join(save_path, f"{i}.png"))
diffsynth/extensions/ESRGAN/__init__.py
ADDED
@@ -0,0 +1,118 @@
import torch
from einops import repeat
from PIL import Image
import numpy as np


class ResidualDenseBlock(torch.nn.Module):

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(torch.nn.Module):

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return out * 0.2 + x


class RRDBNet(torch.nn.Module):

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
        self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        feat = x
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample
        feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
        feat = self.lrelu(self.conv_up1(feat))
        feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
        feat = self.lrelu(self.conv_up2(feat))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out


class ESRGAN(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @staticmethod
    def from_pretrained(model_path):
        model = RRDBNet()
        state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
        model.load_state_dict(state_dict)
        model.eval()
        return ESRGAN(model)

    def process_image(self, image):
        image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
        return image

    def process_images(self, images):
        images = [self.process_image(image) for image in images]
        images = torch.stack(images)
        return images

    def decode_images(self, images):
        images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
        images = [Image.fromarray(image) for image in images]
        return images

    @torch.no_grad()
    def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
        # Preprocess
        input_tensor = self.process_images(images)

        # Interpolate
        output_tensor = []
        for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
            batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
            batch_input_tensor = input_tensor[batch_id: batch_id_]
            batch_input_tensor = batch_input_tensor.to(
                device=self.model.conv_first.weight.device,
                dtype=self.model.conv_first.weight.dtype)
            batch_output_tensor = self.model(batch_input_tensor)
            output_tensor.append(batch_output_tensor.cpu())

        # Output
        output_tensor = torch.concat(output_tensor, dim=0)

        # To images
        output_images = self.decode_images(output_tensor)
        return output_images
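A usage sketch for the ESRGAN wrapper, assuming a RealESRGAN-style RRDBNet checkpoint with a "params_ema" entry is available locally; the checkpoint and image paths are placeholders:

    import torch
    from PIL import Image

    # Hypothetical checkpoint path; any RRDBNet checkpoint exposing "params_ema" should fit.
    esrgan = ESRGAN.from_pretrained("models/ESRGAN/ESRGAN_x4.pth")
    esrgan.to(device="cuda", dtype=torch.float16)
    upscaled = esrgan.upscale([Image.open("frame.png")], batch_size=4)
    upscaled[0].save("frame_upscaled.png")       # 4x the input size (two nearest-neighbor doublings)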
diffsynth/extensions/FastBlend/__init__.py
ADDED
@@ -0,0 +1,63 @@
from .runners.fast import TableManager, PyramidPatchMatcher
from PIL import Image
import numpy as np
import cupy as cp


class FastBlendSmoother:
    def __init__(self):
        self.batch_size = 8
        self.window_size = 64
        self.ebsynth_config = {
            "minimum_patch_size": 5,
            "threads_per_block": 8,
            "num_iter": 5,
            "gpu_id": 0,
            "guide_weight": 10.0,
            "initialize": "identity",
            "tracking_window_size": 0,
        }

    @staticmethod
    def from_model_manager(model_manager):
        # TODO: fetch GPU ID from model_manager
        return FastBlendSmoother()

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
        frames_guide = [np.array(frame) for frame in frames_guide]
        frames_style = [np.array(frame) for frame in frames_style]
        table_manager = TableManager()
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # left part
        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
        table_l = table_manager.remapping_table_to_blending_table(table_l)
        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
        # right part
        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
        table_r = table_manager.remapping_table_to_blending_table(table_r)
        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
        # merge
        frames = []
        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
            weight_m = -1
            weight = weight_l + weight_m + weight_r
            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
            frames.append(frame)
        frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
        return frames

    def __call__(self, rendered_frames, original_frames=None, **kwargs):
        frames = self.run(
            original_frames, rendered_frames,
            self.batch_size, self.window_size, self.ebsynth_config
        )
        mempool = cp.get_default_memory_pool()
        pinned_mempool = cp.get_default_pinned_memory_pool()
        mempool.free_all_blocks()
        pinned_mempool.free_all_blocks()
        return frames
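A sketch of how FastBlendSmoother might be applied as a post-processor, assuming two equally sized, equally long lists of PIL frames are already in memory; the variable names are illustrative:

    # Hypothetical inputs: rendered_frames (stylized output) and original_frames (guide video),
    # both lists of PIL images of identical size and length.
    smoother = FastBlendSmoother()
    smoother.window_size = 32                    # smaller window = less smoothing, less VRAM
    blended_frames = smoother(rendered_frames, original_frames=original_frames)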
diffsynth/extensions/FastBlend/api.py
ADDED
@@ -0,0 +1,397 @@
from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
from .data import VideoData, get_video_fps, save_video, search_for_images
import os
import gradio as gr


def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
    frames_guide = VideoData(video_guide, video_guide_folder)
    frames_style = VideoData(video_style, video_style_folder)
    message = ""
    if len(frames_guide) < len(frames_style):
        message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
        frames_style.set_length(len(frames_guide))
    elif len(frames_guide) > len(frames_style):
        message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
        frames_guide.set_length(len(frames_style))
    height_guide, width_guide = frames_guide.shape()
    height_style, width_style = frames_style.shape()
    if height_guide != height_style or width_guide != width_style:
        message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
        frames_style.set_shape(height_guide, width_guide)
    return frames_guide, frames_style, message


def smooth_video(
    video_guide,
    video_guide_folder,
    video_style,
    video_style_folder,
    mode,
    window_size,
    batch_size,
    tracking_window_size,
    output_path,
    fps,
    minimum_patch_size,
    num_iter,
    guide_weight,
    initialize,
    progress = None,
):
    # input
    frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
    if len(message) > 0:
        print(message)
    # output
    if output_path == "":
        if video_style is None:
            output_path = os.path.join(video_style_folder, "output")
        else:
            output_path = os.path.join(os.path.split(video_style)[0], "output")
        os.makedirs(output_path, exist_ok=True)
        print("No valid output_path. Your video will be saved here:", output_path)
    elif not os.path.exists(output_path):
        os.makedirs(output_path, exist_ok=True)
        print("Your video will be saved here:", output_path)
    frames_path = os.path.join(output_path, "frames")
    video_path = os.path.join(output_path, "video.mp4")
    os.makedirs(frames_path, exist_ok=True)
    # process
    if mode == "Fast" or mode == "Balanced":
        tracking_window_size = 0
    ebsynth_config = {
        "minimum_patch_size": minimum_patch_size,
        "threads_per_block": 8,
        "num_iter": num_iter,
        "gpu_id": 0,
        "guide_weight": guide_weight,
        "initialize": initialize,
        "tracking_window_size": tracking_window_size,
    }
    if mode == "Fast":
        FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    elif mode == "Balanced":
        BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    elif mode == "Accurate":
        AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    # output
    try:
        fps = int(fps)
    except:
        fps = get_video_fps(video_style) if video_style is not None else 30
    print("Fps:", fps)
    print("Saving video...")
    video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
    print("Success!")
    print("Your frames are here:", frames_path)
    print("Your video is here:", video_path)
    return output_path, fps, video_path


class KeyFrameMatcher:
    def __init__(self):
        pass

    def extract_number_from_filename(self, file_name):
        result = []
        number = -1
        for i in file_name:
            if ord(i)>=ord("0") and ord(i)<=ord("9"):
                if number == -1:
                    number = 0
                number = number*10 + ord(i) - ord("0")
            else:
                if number != -1:
                    result.append(number)
                    number = -1
        if number != -1:
            result.append(number)
        result = tuple(result)
        return result

    def extract_number_from_filenames(self, file_names):
        numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
        min_length = min(len(i) for i in numbers)
        for i in range(min_length-1, -1, -1):
            if len(set(number[i] for number in numbers))==len(file_names):
                return [number[i] for number in numbers]
        return list(range(len(file_names)))

    def match_using_filename(self, file_names_a, file_names_b):
        file_names_b_set = set(file_names_b)
        matched_file_name = []
        for file_name in file_names_a:
            if file_name not in file_names_b_set:
                matched_file_name.append(None)
            else:
                matched_file_name.append(file_name)
        return matched_file_name

    def match_using_numbers(self, file_names_a, file_names_b):
        numbers_a = self.extract_number_from_filenames(file_names_a)
        numbers_b = self.extract_number_from_filenames(file_names_b)
        numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
        matched_file_name = []
        for number in numbers_a:
            if number in numbers_b_dict:
                matched_file_name.append(numbers_b_dict[number])
            else:
                matched_file_name.append(None)
        return matched_file_name

    def match_filenames(self, file_names_a, file_names_b):
        matched_file_name = self.match_using_filename(file_names_a, file_names_b)
        if sum([i is not None for i in matched_file_name]) > 0:
            return matched_file_name
        matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
        return matched_file_name


def detect_frames(frames_path, keyframes_path):
    if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
        return "Please input the directory of guide video and rendered frames"
    elif not os.path.exists(frames_path):
        return "Please input the directory of guide video"
    elif not os.path.exists(keyframes_path):
        return "Please input the directory of rendered frames"
    frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
    keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
    if len(frames)==0:
        return f"No images detected in {frames_path}"
    if len(keyframes)==0:
        return f"No images detected in {keyframes_path}"
    matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
    max_filename_length = max([len(i) for i in frames])
    if sum([i is not None for i in matched_keyframes])==0:
        message = ""
        for frame, matched_keyframe in zip(frames, matched_keyframes):
            message += frame + " " * (max_filename_length - len(frame) + 1)
            message += "--> No matched keyframes\n"
    else:
        message = ""
        for frame, matched_keyframe in zip(frames, matched_keyframes):
            message += frame + " " * (max_filename_length - len(frame) + 1)
            if matched_keyframe is None:
                message += "--> [to be rendered]\n"
            else:
                message += f"--> {matched_keyframe}\n"
    return message


def check_input_for_interpolating(frames_path, keyframes_path):
    # search for images
    frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
    keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
    # match frames
    matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
    file_list = [file_name for file_name in matched_keyframes if file_name is not None]
    index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
    frames_guide = VideoData(None, frames_path)
    frames_style = VideoData(None, keyframes_path, file_list=file_list)
    # match shape
    message = ""
    height_guide, width_guide = frames_guide.shape()
    height_style, width_style = frames_style.shape()
    if height_guide != height_style or width_guide != width_style:
        message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
        frames_style.set_shape(height_guide, width_guide)
    return frames_guide, frames_style, index_style, message


def interpolate_video(
    frames_path,
    keyframes_path,
    output_path,
    fps,
    batch_size,
    tracking_window_size,
    minimum_patch_size,
    num_iter,
    guide_weight,
    initialize,
    progress = None,
):
    # input
    frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
    if len(message) > 0:
        print(message)
    # output
    if output_path == "":
        output_path = os.path.join(keyframes_path, "output")
        os.makedirs(output_path, exist_ok=True)
        print("No valid output_path. Your video will be saved here:", output_path)
    elif not os.path.exists(output_path):
        os.makedirs(output_path, exist_ok=True)
        print("Your video will be saved here:", output_path)
    output_frames_path = os.path.join(output_path, "frames")
    output_video_path = os.path.join(output_path, "video.mp4")
    os.makedirs(output_frames_path, exist_ok=True)
    # process
    ebsynth_config = {
        "minimum_patch_size": minimum_patch_size,
        "threads_per_block": 8,
        "num_iter": num_iter,
        "gpu_id": 0,
        "guide_weight": guide_weight,
        "initialize": initialize,
        "tracking_window_size": tracking_window_size
    }
    if len(index_style)==1:
        InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
    else:
        InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
    try:
        fps = int(fps)
    except:
        fps = 30
    print("Fps:", fps)
    print("Saving video...")
    video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
    print("Success!")
    print("Your frames are here:", output_frames_path)
    print("Your video is here:", video_path)
    return output_path, fps, video_path


def on_ui_tabs():
    with gr.Blocks(analytics_enabled=False) as ui_component:
        with gr.Tab("Blend"):
            gr.Markdown("""
# Blend

Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution.
            """)
            with gr.Row():
                with gr.Column():
                    with gr.Tab("Guide video"):
                        video_guide = gr.Video(label="Guide video")
                    with gr.Tab("Guide video (images format)"):
                        video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
                with gr.Column():
                    with gr.Tab("Style video"):
                        video_style = gr.Video(label="Style video")
                    with gr.Tab("Style video (images format)"):
                        video_style_folder = gr.Textbox(label="Style video (images format)", value="")
                with gr.Column():
                    output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
                    fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
                    video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
                    btn = gr.Button(value="Blend")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Settings")
                    mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
                    window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
                    batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
                    tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
                    gr.Markdown("## Advanced Settings")
                    minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
                    num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
                    guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
                    initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
                with gr.Column():
                    gr.Markdown("""
# Reference

* Output directory: the directory to save the video.
* Inference mode

|Mode|Time|Memory|Quality|Frame by frame output|Description|
|-|-|-|-|-|-|
|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
|Balanced|■■|■|■■|Yes|Blend the frames naively.|
|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is the best.|

* Sliding window size: our algorithm will blend the frames in a sliding windows. If the size is n, each frame will be blended with the last n frames and the next n frames. A large sliding window can make the video fluent but sometimes smoggy.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
    * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
    * Number of iterations: the number of iterations of patch matching. (Default: 5)
    * Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
    * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
                    """)
            btn.click(
                smooth_video,
                inputs=[
                    video_guide,
                    video_guide_folder,
                    video_style,
                    video_style_folder,
                    mode,
                    window_size,
                    batch_size,
                    tracking_window_size,
                    output_path,
                    fps,
                    minimum_patch_size,
                    num_iter,
                    guide_weight,
                    initialize
                ],
                outputs=[output_path, fps, video_output]
            )
        with gr.Tab("Interpolate"):
            gr.Markdown("""
# Interpolate

Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see the example. The algorithm is experimental and is only tested for 512*512 resolution.
            """)
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        with gr.Column():
                            video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
                        with gr.Column():
                            rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
                    with gr.Row():
                        detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
                    video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
                    rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
                with gr.Column():
                    output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
                    fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
                    video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
                    btn_ = gr.Button(value="Interpolate")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Settings")
                    batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
                    tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
                    gr.Markdown("## Advanced Settings")
                    minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
                    num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
                    guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
                    initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
                with gr.Column():
                    gr.Markdown("""
# Reference

* Output directory: the directory to save the video.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
    * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
    * Number of iterations: the number of iterations of patch matching. (Default: 5)
    * Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
    * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
                    """)
            btn_.click(
                interpolate_video,
                inputs=[
                    video_guide_folder_,
                    rendered_keyframes_,
                    output_path_,
                    fps_,
                    batch_size_,
                    tracking_window_size_,
                    minimum_patch_size_,
                    num_iter_,
                    guide_weight_,
                    initialize_,
                ],
                outputs=[output_path_, fps_, video_output_]
            )

    return [(ui_component, "FastBlend", "FastBlend_ui")]
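For reference, one way to mount this UI outside the WebUI extension context might look like the sketch below; launching the Blocks standalone is an assumption, not something this diff provides:

    # Hypothetical standalone launch of the FastBlend UI defined in on_ui_tabs().
    ui_component, tab_title, tab_id = on_ui_tabs()[0]
    ui_component.launch()                        # serves the "Blend" and "Interpolate" tabs locally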
diffsynth/extensions/FastBlend/cupy_kernels.py
ADDED
@@ -0,0 +1,119 @@
import cupy as cp

remapping_kernel = cp.RawKernel(r'''
extern "C" __global__
void remap(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_style,
    const int* nnf,
    float* target_style
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    if (x >= height or y >= width) return;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
    const int min_px = x < r ? -x : -r;
    const int max_px = x + r > height - 1 ? height - 1 - x : r;
    const int min_py = y < r ? -y : -r;
    const int max_py = y + r > width - 1 ? width - 1 - y : r;
    int num = 0;
    for (int px = min_px; px <= max_px; px++){
        for (int py = min_py; py <= max_py; py++){
            const int nid = (x + px) * width + y + py;
            const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
            const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
            if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
            const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
            num++;
            for (int c = 0; c < channel; c++){
                target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
            }
        }
    }
    for (int c = 0; c < channel; c++){
        target_style[z + pid * channel + c] /= num;
    }
}
''', 'remap')


patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source,
    const int* nnf,
    const float* target,
    float* error
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    if (x >= height or y >= width) return;
    const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
    const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
    float e = 0;
    for (int px = -r; px <= r; px++){
        for (int py = -r; py <= r; py++){
            const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
            const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
            for (int c = 0; c < channel; c++){
                const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
                e += diff * diff;
            }
        }
    }
    error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'patch_error')


pairwise_patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void pairwise_patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_a,
    const int* nnf_a,
    const float* source_b,
    const int* nnf_b,
    float* error
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    if (x >= height or y >= width) return;
    const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
    const int x_a = nnf_a[z_nnf + 0];
    const int y_a = nnf_a[z_nnf + 1];
    const int x_b = nnf_b[z_nnf + 0];
    const int y_b = nnf_b[z_nnf + 1];
    float e = 0;
    for (int px = -r; px <= r; px++){
        for (int py = -r; py <= r; py++){
            const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
            const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
            for (int c = 0; c < channel; c++){
                const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
                e += diff * diff;
            }
        }
    }
    error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'pairwise_patch_error')
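A toy launch of remapping_kernel, to illustrate the grid/block convention used by PatchMatcher further down in this diff; the concrete sizes are made up, the identity NNF maps every pixel to itself, and the explicit np.int32 casts are a defensive choice rather than something the source requires:

    import cupy as cp
    import numpy as np

    height, width, channel, patch_size, threads_per_block = 64, 64, 3, 5, 8
    pad_size = patch_size // 2
    grid = ((height + threads_per_block - 1) // threads_per_block,
            (width + threads_per_block - 1) // threads_per_block, 1)      # (gx, gy, batch)
    block = (threads_per_block, threads_per_block)

    # Padded source/target buffers of shape (B, H+2p, W+2p, C), NNF of shape (B, H, W, 2).
    source = cp.random.rand(1, height + pad_size * 2, width + pad_size * 2, channel).astype(cp.float32)
    identity = np.stack(np.meshgrid(np.arange(height), np.arange(width), indexing="ij"), axis=-1)
    nnf = cp.array(identity[np.newaxis], dtype=cp.int32)
    target = cp.zeros_like(source)

    remapping_kernel(grid, block, (np.int32(height), np.int32(width), np.int32(channel),
                                   np.int32(patch_size), np.int32(pad_size), source, nnf, target))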
diffsynth/extensions/FastBlend/data.py
ADDED
@@ -0,0 +1,146 @@
import imageio, os
import numpy as np
from PIL import Image


def read_video(file_name):
    reader = imageio.get_reader(file_name)
    video = []
    for frame in reader:
        frame = np.array(frame)
        video.append(frame)
    reader.close()
    return video


def get_video_fps(file_name):
    reader = imageio.get_reader(file_name)
    fps = reader.get_meta_data()["fps"]
    reader.close()
    return fps


def save_video(frames_path, video_path, num_frames, fps):
    writer = imageio.get_writer(video_path, fps=fps, quality=9)
    for i in range(num_frames):
        frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
        writer.append_data(frame)
    writer.close()
    return video_path


class LowMemoryVideo:
    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        return self.reader.count_frames()

    def __getitem__(self, item):
        return np.array(self.reader.get_data(item))

    def __del__(self):
        self.reader.close()


def split_file_name(file_name):
    result = []
    number = -1
    for i in file_name:
        if ord(i)>=ord("0") and ord(i)<=ord("9"):
            if number == -1:
                number = 0
            number = number*10 + ord(i) - ord("0")
        else:
            if number != -1:
                result.append(number)
                number = -1
            result.append(i)
    if number != -1:
        result.append(number)
    result = tuple(result)
    return result


def search_for_images(folder):
    file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
    file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
    file_list = [i[1] for i in sorted(file_list)]
    file_list = [os.path.join(folder, i) for i in file_list]
    return file_list


def read_images(folder):
    file_list = search_for_images(folder)
    frames = [np.array(Image.open(i)) for i in file_list]
    return frames


class LowMemoryImageFolder:
    def __init__(self, folder, file_list=None):
        if file_list is None:
            self.file_list = search_for_images(folder)
        else:
            self.file_list = [os.path.join(folder, file_name) for file_name in file_list]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        return np.array(Image.open(self.file_list[item]))

    def __del__(self):
        pass


class VideoData:
    def __init__(self, video_file, image_folder, **kwargs):
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None
        self.height = None
        self.width = None

    def raw_data(self):
        frames = []
        for i in range(self.__len__()):
            frames.append(self.__getitem__(i))
        return frames

    def set_length(self, length):
        self.length = length

    def set_shape(self, height, width):
        self.height = height
        self.width = width

    def __len__(self):
        if self.length is None:
            return len(self.data)
        else:
            return self.length

    def shape(self):
        if self.height is not None and self.width is not None:
            return self.height, self.width
        else:
            height, width, _ = self.__getitem__(0).shape
            return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        height, width, _ = frame.shape
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = Image.fromarray(frame).resize((self.width, self.height))
                frame = np.array(frame)
        return frame

    def __del__(self):
        pass
diffsynth/extensions/FastBlend/patch_match.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
+import numpy as np
+import cupy as cp
+import cv2
+
+
+class PatchMatcher:
+    def __init__(
+        self, height, width, channel, minimum_patch_size,
+        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
+        random_search_steps=3, random_search_range=4,
+        use_mean_target_style=False, use_pairwise_patch_error=False,
+        tracking_window_size=0
+    ):
+        self.height = height
+        self.width = width
+        self.channel = channel
+        self.minimum_patch_size = minimum_patch_size
+        self.threads_per_block = threads_per_block
+        self.num_iter = num_iter
+        self.gpu_id = gpu_id
+        self.guide_weight = guide_weight
+        self.random_search_steps = random_search_steps
+        self.random_search_range = random_search_range
+        self.use_mean_target_style = use_mean_target_style
+        self.use_pairwise_patch_error = use_pairwise_patch_error
+        self.tracking_window_size = tracking_window_size
+
+        self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
+        self.pad_size = self.patch_size_list[0] // 2
+        self.grid = (
+            (height + threads_per_block - 1) // threads_per_block,
+            (width + threads_per_block - 1) // threads_per_block
+        )
+        self.block = (threads_per_block, threads_per_block)
+
+    def pad_image(self, image):
+        return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
+
+    def unpad_image(self, image):
+        return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
+
+    def apply_nnf_to_image(self, nnf, source):
+        batch_size = source.shape[0]
+        target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
+        remapping_kernel(
+            self.grid + (batch_size,),
+            self.block,
+            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
+        )
+        return target
+
+    def get_patch_error(self, source, nnf, target):
+        batch_size = source.shape[0]
+        error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
+        patch_error_kernel(
+            self.grid + (batch_size,),
+            self.block,
+            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
+        )
+        return error
+
+    def get_pairwise_patch_error(self, source, nnf):
+        batch_size = source.shape[0]//2
+        error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
+        source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
+        source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
+        pairwise_patch_error_kernel(
+            self.grid + (batch_size,),
+            self.block,
+            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
+        )
+        error = error.repeat(2, axis=0)
+        return error
+
+    def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
+        error_guide = self.get_patch_error(source_guide, nnf, target_guide)
+        if self.use_mean_target_style:
+            target_style = self.apply_nnf_to_image(nnf, source_style)
+            target_style = target_style.mean(axis=0, keepdims=True)
+            target_style = target_style.repeat(source_guide.shape[0], axis=0)
+        if self.use_pairwise_patch_error:
+            error_style = self.get_pairwise_patch_error(source_style, nnf)
+        else:
+            error_style = self.get_patch_error(source_style, nnf, target_style)
+        error = error_guide * self.guide_weight + error_style
+        return error
+
+    def clamp_bound(self, nnf):
+        nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
+        nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
+        return nnf
+
+    def random_step(self, nnf, r):
+        batch_size = nnf.shape[0]
+        step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
+        upd_nnf = self.clamp_bound(nnf + step)
+        return upd_nnf
+
+    def neighboor_step(self, nnf, d):
+        if d==0:
+            upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
+            upd_nnf[:, :, :, 0] += 1
+        elif d==1:
+            upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
+            upd_nnf[:, :, :, 1] += 1
+        elif d==2:
+            upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
+            upd_nnf[:, :, :, 0] -= 1
+        elif d==3:
+            upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
+            upd_nnf[:, :, :, 1] -= 1
+        upd_nnf = self.clamp_bound(upd_nnf)
+        return upd_nnf
+
+    def shift_nnf(self, nnf, d):
+        if d>0:
+            d = min(nnf.shape[0], d)
+            upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
+        else:
+            d = max(-nnf.shape[0], d)
+            upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
+        return upd_nnf
+
+    def track_step(self, nnf, d):
+        if self.use_pairwise_patch_error:
+            upd_nnf = cp.zeros_like(nnf)
+            upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
+            upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
+        else:
+            upd_nnf = self.shift_nnf(nnf, d)
+        return upd_nnf
+
+    def C(self, n, m):
+        # not used
+        c = 1
+        for i in range(1, n+1):
+            c *= i
+        for i in range(1, m+1):
+            c //= i
+        for i in range(1, n-m+1):
+            c //= i
+        return c
+
+    def bezier_step(self, nnf, r):
+        # not used
+        n = r * 2 - 1
+        upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
+        for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
+            if d>0:
+                ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
+            elif d<0:
+                ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
+            upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
+        upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
+        return upd_nnf
+
+    def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
+        upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
+        upd_idx = (upd_err < err)
+        nnf[upd_idx] = upd_nnf[upd_idx]
+        err[upd_idx] = upd_err[upd_idx]
+        return nnf, err
+
+    def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
+        for d in cp.random.permutation(4):
+            upd_nnf = self.neighboor_step(nnf, d)
+            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
+        return nnf, err
+
+    def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
+        for i in range(self.random_search_steps):
+            upd_nnf = self.random_step(nnf, self.random_search_range)
+            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
+        return nnf, err
+
+    def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
+        for d in range(1, self.tracking_window_size + 1):
+            upd_nnf = self.track_step(nnf, d)
+            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
+            upd_nnf = self.track_step(nnf, -d)
+            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
+        return nnf, err
+
+    def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
+        nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
+        nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
+        nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
+        return nnf, err
+
+    def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
+        with cp.cuda.Device(self.gpu_id):
+            source_guide = self.pad_image(source_guide)
+            target_guide = self.pad_image(target_guide)
+            source_style = self.pad_image(source_style)
+            for it in range(self.num_iter):
+                self.patch_size = self.patch_size_list[it]
+                target_style = self.apply_nnf_to_image(nnf, source_style)
+                err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
+                nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
+            target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
+        return nnf, target_style
+
+
+class PyramidPatchMatcher:
+    def __init__(
+        self, image_height, image_width, channel, minimum_patch_size,
+        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
+        use_mean_target_style=False, use_pairwise_patch_error=False,
+        tracking_window_size=0,
+        initialize="identity"
+    ):
+        maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
+        self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
+        self.pyramid_heights = []
+        self.pyramid_widths = []
+        self.patch_matchers = []
+        self.minimum_patch_size = minimum_patch_size
+        self.num_iter = num_iter
+        self.gpu_id = gpu_id
+        self.initialize = initialize
+        for level in range(self.pyramid_level):
+            height = image_height//(2**(self.pyramid_level - 1 - level))
+            width = image_width//(2**(self.pyramid_level - 1 - level))
+            self.pyramid_heights.append(height)
+            self.pyramid_widths.append(width)
+            self.patch_matchers.append(PatchMatcher(
+                height, width, channel, minimum_patch_size=minimum_patch_size,
+                threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
+                use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
+                tracking_window_size=tracking_window_size
+            ))
+
+    def resample_image(self, images, level):
+        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
+        images = images.get()
+        images_resample = []
+        for image in images:
+            image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
+            images_resample.append(image_resample)
+        images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
+        return images_resample
+
+    def initialize_nnf(self, batch_size):
+        if self.initialize == "random":
+            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
+            nnf = cp.stack([
+                cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
+                cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
+            ], axis=3)
+        elif self.initialize == "identity":
+            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
+            nnf = cp.stack([
+                cp.repeat(cp.arange(height), width).reshape(height, width),
+                cp.tile(cp.arange(width), height).reshape(height, width)
+            ], axis=2)
+            nnf = cp.stack([nnf] * batch_size)
+        else:
+            raise NotImplementedError()
+        return nnf
+
+    def update_nnf(self, nnf, level):
+        # upscale
+        nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
+        nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
+        nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
+        # check if scale is 2
+        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
+        if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
+            nnf = nnf.get().astype(np.float32)
+            nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
+            nnf = cp.array(np.stack(nnf), dtype=cp.int32)
+            nnf = self.patch_matchers[level].clamp_bound(nnf)
+        return nnf
+
+    def apply_nnf_to_image(self, nnf, image):
+        with cp.cuda.Device(self.gpu_id):
+            image = self.patch_matchers[-1].pad_image(image)
+            image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
+        return image
+
+    def estimate_nnf(self, source_guide, target_guide, source_style):
+        with cp.cuda.Device(self.gpu_id):
+            if not isinstance(source_guide, cp.ndarray):
+                source_guide = cp.array(source_guide, dtype=cp.float32)
+            if not isinstance(target_guide, cp.ndarray):
+                target_guide = cp.array(target_guide, dtype=cp.float32)
+            if not isinstance(source_style, cp.ndarray):
+                source_style = cp.array(source_style, dtype=cp.float32)
+            for level in range(self.pyramid_level):
+                nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
+                source_guide_ = self.resample_image(source_guide, level)
+                target_guide_ = self.resample_image(target_guide, level)
+                source_style_ = self.resample_image(source_style, level)
+                nnf, target_style = self.patch_matchers[level].estimate_nnf(
+                    source_guide_, target_guide_, source_style_, nnf
+                )
+        return nnf.get(), target_style.get()
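For orientation, a hedged sketch of driving the PyramidPatchMatcher above directly (not part of the diff); it assumes a CUDA device with cupy installed, and uses random arrays purely as placeholders:

import numpy as np
from diffsynth.extensions.FastBlend.patch_match import PyramidPatchMatcher

# One source/target pair; arrays are float32 images in [0, 255], shape (batch, H, W, C).
source_guide = (np.random.rand(1, 512, 512, 3) * 255).astype(np.float32)
target_guide = (np.random.rand(1, 512, 512, 3) * 255).astype(np.float32)
source_style = (np.random.rand(1, 512, 512, 3) * 255).astype(np.float32)

matcher = PyramidPatchMatcher(image_height=512, image_width=512, channel=3, minimum_patch_size=5)
nnf, target_style = matcher.estimate_nnf(source_guide, target_guide, source_style)
print(nnf.shape, target_style.shape)  # (1, 512, 512, 2) and (1, 512, 512, 3)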
diffsynth/extensions/FastBlend/runners/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .accurate import AccurateModeRunner
+from .fast import FastModeRunner
+from .balanced import BalancedModeRunner
+from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
diffsynth/extensions/FastBlend/runners/accurate.py
ADDED
@@ -0,0 +1,35 @@
+from ..patch_match import PyramidPatchMatcher
+import os
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+
+class AccurateModeRunner:
+    def __init__(self):
+        pass
+
+    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frames_style[0].shape[0],
+            image_width=frames_style[0].shape[1],
+            channel=3,
+            use_mean_target_style=True,
+            **ebsynth_config
+        )
+        # run
+        n = len(frames_style)
+        for target in tqdm(range(n), desc=desc):
+            l, r = max(target - window_size, 0), min(target + window_size + 1, n)
+            remapped_frames = []
+            for i in range(l, r, batch_size):
+                j = min(i + batch_size, r)
+                source_guide = np.stack([frames_guide[source] for source in range(i, j)])
+                target_guide = np.stack([frames_guide[target]] * (j - i))
+                source_style = np.stack([frames_style[source] for source in range(i, j)])
+                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+                remapped_frames.append(target_style)
+            frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
+            frame = frame.clip(0, 255).astype("uint8")
+            if save_path is not None:
+                Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
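A hedged invocation sketch for AccurateModeRunner (not part of the diff); the frame lists, config values, and output directory are illustrative, and the ebsynth_config keys simply mirror PyramidPatchMatcher's keyword arguments:

import numpy as np
from diffsynth.extensions.FastBlend.runners.accurate import AccurateModeRunner

frames_guide = [np.zeros((512, 512, 3), dtype=np.float32) for _ in range(8)]  # placeholder guide frames
frames_style = [np.zeros((512, 512, 3), dtype=np.float32) for _ in range(8)]  # placeholder styled frames
ebsynth_config = {"minimum_patch_size": 5, "guide_weight": 10.0, "num_iter": 5, "gpu_id": 0}

AccurateModeRunner().run(
    frames_guide, frames_style,
    batch_size=8, window_size=3,
    ebsynth_config=ebsynth_config,
    save_path="output_frames",  # assumed to be an existing directory
)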
diffsynth/extensions/FastBlend/runners/balanced.py
ADDED
@@ -0,0 +1,46 @@
+from ..patch_match import PyramidPatchMatcher
+import os
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+
+class BalancedModeRunner:
+    def __init__(self):
+        pass
+
+    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frames_style[0].shape[0],
+            image_width=frames_style[0].shape[1],
+            channel=3,
+            **ebsynth_config
+        )
+        # tasks
+        n = len(frames_style)
+        tasks = []
+        for target in range(n):
+            for source in range(target - window_size, target + window_size + 1):
+                if source >= 0 and source < n and source != target:
+                    tasks.append((source, target))
+        # run
+        frames = [(None, 1) for i in range(n)]
+        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
+            tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
+            source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
+            target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
+            source_style = np.stack([frames_style[source] for source, target in tasks_batch])
+            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+            for (source, target), result in zip(tasks_batch, target_style):
+                frame, weight = frames[target]
+                if frame is None:
+                    frame = frames_style[target]
+                frames[target] = (
+                    frame * (weight / (weight + 1)) + result / (weight + 1),
+                    weight + 1
+                )
+                if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
+                    frame = frame.clip(0, 255).astype("uint8")
+                    if save_path is not None:
+                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
+                    frames[target] = (None, 1)
diffsynth/extensions/FastBlend/runners/fast.py
ADDED
@@ -0,0 +1,141 @@
+from ..patch_match import PyramidPatchMatcher
+import functools, os
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+
+class TableManager:
+    def __init__(self):
+        pass
+
+    def task_list(self, n):
+        tasks = []
+        max_level = 1
+        while (1<<max_level)<=n:
+            max_level += 1
+        for i in range(n):
+            j = i
+            for level in range(max_level):
+                if i&(1<<level):
+                    continue
+                j |= 1<<level
+                if j>=n:
+                    break
+                meta_data = {
+                    "source": i,
+                    "target": j,
+                    "level": level + 1
+                }
+                tasks.append(meta_data)
+        tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
+        return tasks
+
+    def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
+        n = len(frames_guide)
+        tasks = self.task_list(n)
+        remapping_table = [[(frames_style[i], 1)] for i in range(n)]
+        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
+            tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
+            source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
+            target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
+            source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
+            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+            for task, result in zip(tasks_batch, target_style):
+                target, level = task["target"], task["level"]
+                if len(remapping_table[target])==level:
+                    remapping_table[target].append((result, 1))
+                else:
+                    frame, weight = remapping_table[target][level]
+                    remapping_table[target][level] = (
+                        frame * (weight / (weight + 1)) + result / (weight + 1),
+                        weight + 1
+                    )
+        return remapping_table
+
+    def remapping_table_to_blending_table(self, table):
+        for i in range(len(table)):
+            for j in range(1, len(table[i])):
+                frame_1, weight_1 = table[i][j-1]
+                frame_2, weight_2 = table[i][j]
+                frame = (frame_1 + frame_2) / 2
+                weight = weight_1 + weight_2
+                table[i][j] = (frame, weight)
+        return table
+
+    def tree_query(self, leftbound, rightbound):
+        node_list = []
+        node_index = rightbound
+        while node_index>=leftbound:
+            node_level = 0
+            while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
+                node_level += 1
+            node_list.append((node_index, node_level))
+            node_index -= 1<<node_level
+        return node_list
+
+    def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
+        n = len(blending_table)
+        tasks = []
+        frames_result = []
+        for target in range(n):
+            node_list = self.tree_query(max(target-window_size, 0), target)
+            for source, level in node_list:
+                if source!=target:
+                    meta_data = {
+                        "source": source,
+                        "target": target,
+                        "level": level
+                    }
+                    tasks.append(meta_data)
+                else:
+                    frames_result.append(blending_table[target][level])
+        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
+            tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
+            source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
+            target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
+            source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
+            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+            for task, frame_2 in zip(tasks_batch, target_style):
+                source, target, level = task["source"], task["target"], task["level"]
+                frame_1, weight_1 = frames_result[target]
+                weight_2 = blending_table[source][level][1]
+                weight = weight_1 + weight_2
+                frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
+                frames_result[target] = (frame, weight)
+        return frames_result
+
+
+class FastModeRunner:
+    def __init__(self):
+        pass
+
+    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
+        frames_guide = frames_guide.raw_data()
+        frames_style = frames_style.raw_data()
+        table_manager = TableManager()
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frames_style[0].shape[0],
+            image_width=frames_style[0].shape[1],
+            channel=3,
+            **ebsynth_config
+        )
+        # left part
+        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
+        table_l = table_manager.remapping_table_to_blending_table(table_l)
+        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
+        # right part
+        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
+        table_r = table_manager.remapping_table_to_blending_table(table_r)
+        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
+        # merge
+        frames = []
+        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
+            weight_m = -1
+            weight = weight_l + weight_m + weight_r
+            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
+            frames.append(frame)
+        frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
+        if save_path is not None:
+            for target, frame in enumerate(frames):
+                Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
diffsynth/extensions/FastBlend/runners/interpolation.py
ADDED
@@ -0,0 +1,121 @@
+from ..patch_match import PyramidPatchMatcher
+import os
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+
+
+class InterpolationModeRunner:
+    def __init__(self):
+        pass
+
+    def get_index_dict(self, index_style):
+        index_dict = {}
+        for i, index in enumerate(index_style):
+            index_dict[index] = i
+        return index_dict
+
+    def get_weight(self, l, m, r):
+        weight_l, weight_r = abs(m - r), abs(m - l)
+        if weight_l + weight_r == 0:
+            weight_l, weight_r = 0.5, 0.5
+        else:
+            weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
+        return weight_l, weight_r
+
+    def get_task_group(self, index_style, n):
+        task_group = []
+        index_style = sorted(index_style)
+        # first frame
+        if index_style[0]>0:
+            tasks = []
+            for m in range(index_style[0]):
+                tasks.append((index_style[0], m, index_style[0]))
+            task_group.append(tasks)
+        # middle frames
+        for l, r in zip(index_style[:-1], index_style[1:]):
+            tasks = []
+            for m in range(l, r):
+                tasks.append((l, m, r))
+            task_group.append(tasks)
+        # last frame
+        tasks = []
+        for m in range(index_style[-1], n):
+            tasks.append((index_style[-1], m, index_style[-1]))
+        task_group.append(tasks)
+        return task_group
+
+    def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frames_style[0].shape[0],
+            image_width=frames_style[0].shape[1],
+            channel=3,
+            use_mean_target_style=False,
+            use_pairwise_patch_error=True,
+            **ebsynth_config
+        )
+        # task
+        index_dict = self.get_index_dict(index_style)
+        task_group = self.get_task_group(index_style, len(frames_guide))
+        # run
+        for tasks in task_group:
+            index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
+            for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
+                tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
+                source_guide, target_guide, source_style = [], [], []
+                for l, m, r in tasks_batch:
+                    # l -> m
+                    source_guide.append(frames_guide[l])
+                    target_guide.append(frames_guide[m])
+                    source_style.append(frames_style[index_dict[l]])
+                    # r -> m
+                    source_guide.append(frames_guide[r])
+                    target_guide.append(frames_guide[m])
+                    source_style.append(frames_style[index_dict[r]])
+                source_guide = np.stack(source_guide)
+                target_guide = np.stack(target_guide)
+                source_style = np.stack(source_style)
+                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+                if save_path is not None:
+                    for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
+                        weight_l, weight_r = self.get_weight(l, m, r)
+                        frame = frame_l * weight_l + frame_r * weight_r
+                        frame = frame.clip(0, 255).astype("uint8")
+                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
+
+
+class InterpolationModeSingleFrameRunner:
+    def __init__(self):
+        pass
+
+    def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
+        # check input
+        tracking_window_size = ebsynth_config["tracking_window_size"]
+        if tracking_window_size * 2 >= batch_size:
+            raise ValueError("batch_size should be larger than tracking_window_size * 2")
+        frame_style = frames_style[0]
+        frame_guide = frames_guide[index_style[0]]
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frame_style.shape[0],
+            image_width=frame_style.shape[1],
+            channel=3,
+            **ebsynth_config
+        )
+        # run
+        frame_id, n = 0, len(frames_guide)
+        for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
+            if i + batch_size > n:
+                l, r = max(n - batch_size, 0), n
+            else:
+                l, r = i, i + batch_size
+            source_guide = np.stack([frame_guide] * (r-l))
+            target_guide = np.stack([frames_guide[i] for i in range(l, r)])
+            source_style = np.stack([frame_style] * (r-l))
+            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+            for i, frame in zip(range(l, r), target_style):
+                if i==frame_id:
+                    frame = frame.clip(0, 255).astype("uint8")
+                    Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
+                    frame_id += 1
+                if r < n and r-frame_id <= tracking_window_size:
+                    break
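A hedged sketch of InterpolationModeRunner (not part of the diff), which propagates a few stylized keyframes to all frames in between; frame data, keyframe indices, and paths are illustrative assumptions:

import numpy as np
from diffsynth.extensions.FastBlend.runners.interpolation import InterpolationModeRunner

frames_guide = [np.zeros((512, 512, 3), dtype=np.float32) for _ in range(11)]  # placeholder guide frames
index_style = [0, 10]                                                          # stylized keyframe indices
frames_style = [frames_guide[i].copy() for i in index_style]                   # placeholder stylized keyframes
ebsynth_config = {"minimum_patch_size": 5, "guide_weight": 10.0, "tracking_window_size": 0}

InterpolationModeRunner().run(
    frames_guide, frames_style, index_style,
    batch_size=8, ebsynth_config=ebsynth_config,
    save_path="output_frames",  # assumed to be an existing directory
)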
diffsynth/extensions/RIFE/__init__.py
ADDED
@@ -0,0 +1,241 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from PIL import Image
+
+
+def warp(tenInput, tenFlow, device):
+    backwarp_tenGrid = {}
+    k = (str(tenFlow.device), str(tenFlow.size()))
+    if k not in backwarp_tenGrid:
+        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
+            1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
+        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
+            1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
+        backwarp_tenGrid[k] = torch.cat(
+            [tenHorizontal, tenVertical], 1).to(device)
+
+    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
+                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
+
+    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
+    return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
+
+
+def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+    return nn.Sequential(
+        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                  padding=padding, dilation=dilation, bias=True),
+        nn.PReLU(out_planes)
+    )
+
+
+class IFBlock(nn.Module):
+    def __init__(self, in_planes, c=64):
+        super(IFBlock, self).__init__()
+        self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
+        self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
+        self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
+        self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
+        self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
+        self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
+        self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))
+
+    def forward(self, x, flow, scale=1):
+        x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+        flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
+        feat = self.conv0(torch.cat((x, flow), 1))
+        feat = self.convblock0(feat) + feat
+        feat = self.convblock1(feat) + feat
+        feat = self.convblock2(feat) + feat
+        feat = self.convblock3(feat) + feat
+        flow = self.conv1(feat)
+        mask = self.conv2(feat)
+        flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
+        mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+        return flow, mask
+
+
+class IFNet(nn.Module):
+    def __init__(self):
+        super(IFNet, self).__init__()
+        self.block0 = IFBlock(7+4, c=90)
+        self.block1 = IFBlock(7+4, c=90)
+        self.block2 = IFBlock(7+4, c=90)
+        self.block_tea = IFBlock(10+4, c=90)
+
+    def forward(self, x, scale_list=[4, 2, 1], training=False):
+        if training == False:
+            channel = x.shape[1] // 2
+            img0 = x[:, :channel]
+            img1 = x[:, channel:]
+        flow_list = []
+        merged = []
+        mask_list = []
+        warped_img0 = img0
+        warped_img1 = img1
+        flow = (x[:, :4]).detach() * 0
+        mask = (x[:, :1]).detach() * 0
+        block = [self.block0, self.block1, self.block2]
+        for i in range(3):
+            f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
+            f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
+            flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
+            mask = mask + (m0 + (-m1)) / 2
+            mask_list.append(mask)
+            flow_list.append(flow)
+            warped_img0 = warp(img0, flow[:, :2], device=x.device)
+            warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
+            merged.append((warped_img0, warped_img1))
+        '''
+        c0 = self.contextnet(img0, flow[:, :2])
+        c1 = self.contextnet(img1, flow[:, 2:4])
+        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
+        res = tmp[:, 1:4] * 2 - 1
+        '''
+        for i in range(3):
+            mask_list[i] = torch.sigmoid(mask_list[i])
+            merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
+        return flow_list, mask_list[2], merged
+
+    def state_dict_converter(self):
+        return IFNetStateDictConverter()
+
+
+class IFNetStateDictConverter:
+    def __init__(self):
+        pass
+
+    def from_diffusers(self, state_dict):
+        state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
+        return state_dict_
+
+    def from_civitai(self, state_dict):
+        return self.from_diffusers(state_dict)
+
+
+class RIFEInterpolater:
+    def __init__(self, model, device="cuda"):
+        self.model = model
+        self.device = device
+        # IFNet does not support float16, so keep it in float32.
+        self.torch_dtype = torch.float32
+
+    @staticmethod
+    def from_model_manager(model_manager):
+        return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)
+
+    def process_image(self, image):
+        width, height = image.size
+        if width % 32 != 0 or height % 32 != 0:
+            # Round each dimension up to the nearest multiple of 32 before resizing.
+            width = (width + 31) // 32 * 32
+            height = (height + 31) // 32 * 32
+            image = image.resize((width, height))
+        image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
+        return image
+
+    def process_images(self, images):
+        images = [self.process_image(image) for image in images]
+        images = torch.stack(images)
+        return images
+
+    def decode_images(self, images):
+        images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
+        images = [Image.fromarray(image) for image in images]
+        return images
+
+    def add_interpolated_images(self, images, interpolated_images):
+        output_images = []
+        for image, interpolated_image in zip(images, interpolated_images):
+            output_images.append(image)
+            output_images.append(interpolated_image)
+        output_images.append(images[-1])
+        return output_images
+
+
+    @torch.no_grad()
+    def interpolate_(self, images, scale=1.0):
+        input_tensor = self.process_images(images)
+        input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
+        input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
+        flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
+        output_images = self.decode_images(merged[2].cpu())
+        if output_images[0].size != images[0].size:
+            output_images = [image.resize(images[0].size) for image in output_images]
+        return output_images
+
+
+    @torch.no_grad()
+    def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
+        # Preprocess
+        processed_images = self.process_images(images)
+
+        for iter in range(num_iter):
+            # Input
+            input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)
+
+            # Interpolate
+            output_tensor = []
+            for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
+                batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
+                batch_input_tensor = input_tensor[batch_id: batch_id_]
+                batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
+                flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
+                output_tensor.append(merged[2].cpu())
+
+            # Output
+            output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
+            processed_images = self.add_interpolated_images(processed_images, output_tensor)
+            processed_images = torch.stack(processed_images)
+
+        # To images
+        output_images = self.decode_images(processed_images)
+        if output_images[0].size != images[0].size:
+            output_images = [image.resize(images[0].size) for image in output_images]
+        return output_images
+
+
+class RIFESmoother(RIFEInterpolater):
+    def __init__(self, model, device="cuda"):
+        super(RIFESmoother, self).__init__(model, device=device)
+
+    @staticmethod
+    def from_model_manager(model_manager):
+        return RIFESmoother(model_manager.RIFE, device=model_manager.device)
+
+    def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
+        output_tensor = []
+        for batch_id in range(0, input_tensor.shape[0], batch_size):
+            batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
+            batch_input_tensor = input_tensor[batch_id: batch_id_]
+            batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
+            flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
+            output_tensor.append(merged[2].cpu())
+        output_tensor = torch.concat(output_tensor, dim=0)
+        return output_tensor
+
+    @torch.no_grad()
+    def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
+        # Preprocess
+        processed_images = self.process_images(rendered_frames)
+
+        for iter in range(num_iter):
+            # Input
+            input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
+
+            # Interpolate
+            output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
+
+            # Blend
+            input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
+            output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
+
+            # Add to frames
+            processed_images[1:-1] = output_tensor
+
+        # To images
+        output_images = self.decode_images(processed_images)
+        if output_images[0].size != rendered_frames[0].size:
+            output_images = [image.resize(rendered_frames[0].size) for image in output_images]
+        return output_images
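A hedged sketch of frame interpolation with the classes above (not part of the diff); the checkpoint path and frame file names are assumptions, not part of this repository:

import torch
from PIL import Image
from diffsynth.extensions.RIFE import IFNet, RIFEInterpolater

model = IFNet().eval()
state_dict = torch.load("flownet.pkl", map_location="cpu")       # hypothetical RIFE checkpoint
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model = model.to(torch.float32).to("cuda")

frames = [Image.open("frames/%05d.png" % i) for i in range(16)]  # hypothetical input frames
interpolater = RIFEInterpolater(model, device="cuda")
frames_2x = interpolater.interpolate(frames, scale=1.0, batch_size=4, num_iter=1)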
diffsynth/models/__init__.py
ADDED
@@ -0,0 +1,482 @@
1 |
+
import torch, os
|
2 |
+
from safetensors import safe_open
|
3 |
+
|
4 |
+
from .sd_text_encoder import SDTextEncoder
|
5 |
+
from .sd_unet import SDUNet
|
6 |
+
from .sd_vae_encoder import SDVAEEncoder
|
7 |
+
from .sd_vae_decoder import SDVAEDecoder
|
8 |
+
from .sd_lora import SDLoRA
|
9 |
+
|
10 |
+
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
11 |
+
from .sdxl_unet import SDXLUNet
|
12 |
+
from .sdxl_vae_decoder import SDXLVAEDecoder
|
13 |
+
from .sdxl_vae_encoder import SDXLVAEEncoder
|
14 |
+
|
15 |
+
from .sd_controlnet import SDControlNet
|
16 |
+
|
17 |
+
from .sd_motion import SDMotionModel
|
18 |
+
from .sdxl_motion import SDXLMotionModel
|
19 |
+
|
20 |
+
from .svd_image_encoder import SVDImageEncoder
|
21 |
+
from .svd_unet import SVDUNet
|
22 |
+
from .svd_vae_decoder import SVDVAEDecoder
|
23 |
+
from .svd_vae_encoder import SVDVAEEncoder
|
24 |
+
|
25 |
+
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
26 |
+
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
27 |
+
|
28 |
+
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
29 |
+
from .hunyuan_dit import HunyuanDiT
|
30 |
+
|
31 |
+
|
32 |
+
class ModelManager:
|
33 |
+
def __init__(self, torch_dtype=torch.float16, device="cuda"):
|
34 |
+
self.torch_dtype = torch_dtype
|
35 |
+
self.device = device
|
36 |
+
self.model = {}
|
37 |
+
self.model_path = {}
|
38 |
+
self.textual_inversion_dict = {}
|
39 |
+
|
40 |
+
def is_stable_video_diffusion(self, state_dict):
|
41 |
+
param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
|
42 |
+
return param_name in state_dict
|
43 |
+
|
44 |
+
def is_RIFE(self, state_dict):
|
45 |
+
param_name = "block_tea.convblock3.0.1.weight"
|
46 |
+
return param_name in state_dict or ("module." + param_name) in state_dict
|
47 |
+
|
48 |
+
def is_beautiful_prompt(self, state_dict):
|
49 |
+
param_name = "transformer.h.9.self_attention.query_key_value.weight"
|
50 |
+
return param_name in state_dict
|
51 |
+
|
52 |
+
def is_stabe_diffusion_xl(self, state_dict):
|
53 |
+
param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
|
54 |
+
return param_name in state_dict
|
55 |
+
|
56 |
+
def is_stable_diffusion(self, state_dict):
|
57 |
+
if self.is_stabe_diffusion_xl(state_dict):
|
58 |
+
return False
|
59 |
+
param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
|
60 |
+
return param_name in state_dict
|
61 |
+
|
62 |
+
def is_controlnet(self, state_dict):
|
63 |
+
param_name = "control_model.time_embed.0.weight"
|
64 |
+
param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format
|
65 |
+
return param_name in state_dict or param_name_2 in state_dict
|
66 |
+
|
67 |
+
def is_animatediff(self, state_dict):
|
68 |
+
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
|
69 |
+
return param_name in state_dict
|
70 |
+
|
71 |
+
def is_animatediff_xl(self, state_dict):
|
72 |
+
param_name = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight"
|
73 |
+
return param_name in state_dict
|
74 |
+
|
75 |
+
def is_sd_lora(self, state_dict):
|
76 |
+
param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
|
77 |
+
return param_name in state_dict
|
78 |
+
|
79 |
+
def is_translator(self, state_dict):
|
80 |
+
param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
|
81 |
+
return param_name in state_dict and len(state_dict) == 254
|
82 |
+
|
83 |
+
def is_ipadapter(self, state_dict):
|
84 |
+
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([3072, 1024])
|
85 |
+
|
86 |
+
def is_ipadapter_image_encoder(self, state_dict):
|
87 |
+
param_name = "vision_model.encoder.layers.31.self_attn.v_proj.weight"
|
88 |
+
return param_name in state_dict and len(state_dict) == 521
|
89 |
+
|
90 |
+
def is_ipadapter_xl(self, state_dict):
|
91 |
+
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([8192, 1280])
|
92 |
+
|
93 |
+
def is_ipadapter_xl_image_encoder(self, state_dict):
|
94 |
+
param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
|
95 |
+
return param_name in state_dict and len(state_dict) == 777
|
96 |
+
|
97 |
+
def is_hunyuan_dit_clip_text_encoder(self, state_dict):
|
98 |
+
param_name = "bert.encoder.layer.23.attention.output.dense.weight"
|
99 |
+
return param_name in state_dict
|
100 |
+
|
101 |
+
def is_hunyuan_dit_t5_text_encoder(self, state_dict):
|
102 |
+
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
|
103 |
+
return param_name in state_dict
|
104 |
+
|
105 |
+
def is_hunyuan_dit(self, state_dict):
|
106 |
+
param_name = "final_layer.adaLN_modulation.1.weight"
|
107 |
+
return param_name in state_dict
|
108 |
+
|
109 |
+
def is_diffusers_vae(self, state_dict):
|
110 |
+
param_name = "quant_conv.weight"
|
111 |
+
return param_name in state_dict
|
112 |
+
|
113 |
+
def is_ExVideo_StableVideoDiffusion(self, state_dict):
|
114 |
+
param_name = "blocks.185.positional_embedding.embeddings"
|
115 |
+
return param_name in state_dict
|
116 |
+
|
117 |
+
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
|
118 |
+
component_dict = {
|
119 |
+
"image_encoder": SVDImageEncoder,
|
120 |
+
"unet": SVDUNet,
|
121 |
+
"vae_decoder": SVDVAEDecoder,
|
122 |
+
"vae_encoder": SVDVAEEncoder,
|
123 |
+
}
|
124 |
+
if components is None:
|
125 |
+
components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
|
126 |
+
for component in components:
|
127 |
+
if component == "unet":
|
128 |
+
self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
|
129 |
+
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
|
130 |
+
else:
|
131 |
+
self.model[component] = component_dict[component]()
|
132 |
+
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
133 |
+
self.model[component].to(self.torch_dtype).to(self.device)
|
134 |
+
self.model_path[component] = file_path
|
135 |
+
|
136 |
+
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
|
137 |
+
component_dict = {
|
138 |
+
"text_encoder": SDTextEncoder,
|
139 |
+
"unet": SDUNet,
|
140 |
+
"vae_decoder": SDVAEDecoder,
|
141 |
+
"vae_encoder": SDVAEEncoder,
|
142 |
+
"refiner": SDXLUNet,
|
143 |
+
}
|
144 |
+
if components is None:
|
145 |
+
components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
|
146 |
+
for component in components:
|
147 |
+
if component == "text_encoder":
|
148 |
+
# Add additional token embeddings to text encoder
|
149 |
+
token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]]
|
150 |
+
for keyword in self.textual_inversion_dict:
|
151 |
+
_, embeddings = self.textual_inversion_dict[keyword]
|
152 |
+
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
|
153 |
+
token_embeddings = torch.concat(token_embeddings, dim=0)
|
154 |
+
state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
|
155 |
+
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
|
156 |
+
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
157 |
+
self.model[component].to(self.torch_dtype).to(self.device)
|
158 |
+
else:
|
159 |
+
self.model[component] = component_dict[component]()
|
160 |
+
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
161 |
+
self.model[component].to(self.torch_dtype).to(self.device)
|
162 |
+
self.model_path[component] = file_path
|
163 |
+
|
164 |
+
def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
|
165 |
+
component_dict = {
|
166 |
+
"text_encoder": SDXLTextEncoder,
|
167 |
+
"text_encoder_2": SDXLTextEncoder2,
|
168 |
+
"unet": SDXLUNet,
|
169 |
+
"vae_decoder": SDXLVAEDecoder,
|
170 |
+
"vae_encoder": SDXLVAEEncoder,
|
171 |
+
}
|
172 |
+
if components is None:
|
173 |
+
components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
|
174 |
+
for component in components:
|
175 |
+
self.model[component] = component_dict[component]()
|
176 |
+
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
|
177 |
+
if component in ["vae_decoder", "vae_encoder"]:
|
178 |
+
# These two model will output nan when float16 is enabled.
|
179 |
+
# The precision problem happens in the last three resnet blocks.
|
180 |
+
# I do not know how to solve this problem.
|
181 |
+
self.model[component].to(torch.float32).to(self.device)
|
182 |
+
else:
|
183 |
+
self.model[component].to(self.torch_dtype).to(self.device)
|
184 |
+
self.model_path[component] = file_path
|
185 |
+
|
    def load_controlnet(self, state_dict, file_path=""):
        component = "controlnet"
        if component not in self.model:
            self.model[component] = []
            self.model_path[component] = []
        model = SDControlNet()
        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
        model.to(self.torch_dtype).to(self.device)
        self.model[component].append(model)
        self.model_path[component].append(file_path)

    def load_animatediff(self, state_dict, file_path=""):
        component = "motion_modules"
        model = SDMotionModel()
        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
        model.to(self.torch_dtype).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path

    def load_animatediff_xl(self, state_dict, file_path=""):
        component = "motion_modules_xl"
        model = SDXLMotionModel()
        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
        model.to(self.torch_dtype).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path

    def load_beautiful_prompt(self, state_dict, file_path=""):
        component = "beautiful_prompt"
        from transformers import AutoModelForCausalLM
        model_folder = os.path.dirname(file_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
        ).to(self.device).eval()
        self.model[component] = model
        self.model_path[component] = file_path

    def load_RIFE(self, state_dict, file_path=""):
        component = "RIFE"
        from ..extensions.RIFE import IFNet
        model = IFNet().eval()
        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
        model.to(torch.float32).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path

    def load_sd_lora(self, state_dict, alpha):
        SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
        SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)

    def load_translator(self, state_dict, file_path=""):
        # This model is lightweight, so we do not place it on the GPU.
        component = "translator"
        from transformers import AutoModelForSeq2SeqLM
        model_folder = os.path.dirname(file_path)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
        self.model[component] = model
        self.model_path[component] = file_path

    def load_ipadapter(self, state_dict, file_path=""):
        component = "ipadapter"
        model = SDIpAdapter()
        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
        model.to(self.torch_dtype).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path

    def load_ipadapter_image_encoder(self, state_dict, file_path=""):
        component = "ipadapter_image_encoder"
        model = IpAdapterCLIPImageEmbedder()
        model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
        model.to(self.torch_dtype).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path

    def load_ipadapter_xl(self, state_dict, file_path=""):
        component = "ipadapter_xl"
        model = SDXLIpAdapter()
        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
        model.to(self.torch_dtype).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path

    def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
        component = "ipadapter_xl_image_encoder"
        model = IpAdapterXLCLIPImageEmbedder()
        model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
        model.to(self.torch_dtype).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path

    def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
        component = "hunyuan_dit_clip_text_encoder"
        model = HunyuanDiTCLIPTextEncoder()
        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
        model.to(self.torch_dtype).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path

    def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
        component = "hunyuan_dit_t5_text_encoder"
        model = HunyuanDiTT5TextEncoder()
        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
        model.to(self.torch_dtype).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path

    def load_hunyuan_dit(self, state_dict, file_path=""):
        component = "hunyuan_dit"
        model = HunyuanDiT()
        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
        model.to(self.torch_dtype).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path

    def load_diffusers_vae(self, state_dict, file_path=""):
        # TODO: detect SD and SDXL
        component = "vae_encoder"
        model = SDXLVAEEncoder()
        model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
        model.to(self.torch_dtype).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path
        component = "vae_decoder"
        model = SDXLVAEDecoder()
        model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
        model.to(self.torch_dtype).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path

    def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
        unet_state_dict = self.model["unet"].state_dict()
        self.model["unet"].to("cpu")
        del self.model["unet"]
        add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
        self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
        self.model["unet"].load_state_dict(unet_state_dict, strict=False)
        self.model["unet"].load_state_dict(state_dict, strict=False)
        self.model["unet"].to(self.torch_dtype).to(self.device)

    def search_for_embeddings(self, state_dict):
        embeddings = []
        for k in state_dict:
            if isinstance(state_dict[k], torch.Tensor):
                embeddings.append(state_dict[k])
            elif isinstance(state_dict[k], dict):
                embeddings += self.search_for_embeddings(state_dict[k])
        return embeddings

    def load_textual_inversions(self, folder):
        # Store additional tokens here
        self.textual_inversion_dict = {}

        # Load every textual inversion file
        for file_name in os.listdir(folder):
            if file_name.endswith(".txt"):
                continue
            keyword = os.path.splitext(file_name)[0]
            state_dict = load_state_dict(os.path.join(folder, file_name))

            # Search for embeddings
            for embeddings in self.search_for_embeddings(state_dict):
                if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
                    tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
                    self.textual_inversion_dict[keyword] = (tokens, embeddings)
                    break
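load_textual_inversions keeps any 2-D tensor whose second dimension is 768 (the SD 1.x token-embedding width) and assigns one pseudo-token per row. A standalone sketch of that filter over a fake embedding file; the keyword and tensor are illustrative only:

    import torch

    keyword = "my_style"                               # hypothetical file name without extension
    state_dict = {"emb_params": torch.randn(4, 768)}   # 4 learned vectors, SD 1.x width

    for embeddings in state_dict.values():
        if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
            tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
            print(tokens)  # ['my_style_0', 'my_style_1', 'my_style_2', 'my_style_3']
            break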
353 |
+
def load_model(self, file_path, components=None, lora_alphas=[]):
|
354 |
+
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
|
355 |
+
if self.is_stable_video_diffusion(state_dict):
|
356 |
+
self.load_stable_video_diffusion(state_dict, file_path=file_path)
|
357 |
+
elif self.is_animatediff(state_dict):
|
358 |
+
self.load_animatediff(state_dict, file_path=file_path)
|
359 |
+
elif self.is_animatediff_xl(state_dict):
|
360 |
+
self.load_animatediff_xl(state_dict, file_path=file_path)
|
361 |
+
elif self.is_controlnet(state_dict):
|
362 |
+
self.load_controlnet(state_dict, file_path=file_path)
|
363 |
+
elif self.is_stabe_diffusion_xl(state_dict):
|
364 |
+
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
|
365 |
+
elif self.is_stable_diffusion(state_dict):
|
366 |
+
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
|
367 |
+
elif self.is_sd_lora(state_dict):
|
368 |
+
self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
|
369 |
+
elif self.is_beautiful_prompt(state_dict):
|
370 |
+
self.load_beautiful_prompt(state_dict, file_path=file_path)
|
371 |
+
elif self.is_RIFE(state_dict):
|
372 |
+
self.load_RIFE(state_dict, file_path=file_path)
|
373 |
+
elif self.is_translator(state_dict):
|
374 |
+
self.load_translator(state_dict, file_path=file_path)
|
375 |
+
elif self.is_ipadapter(state_dict):
|
376 |
+
self.load_ipadapter(state_dict, file_path=file_path)
|
377 |
+
elif self.is_ipadapter_image_encoder(state_dict):
|
378 |
+
self.load_ipadapter_image_encoder(state_dict, file_path=file_path)
|
379 |
+
elif self.is_ipadapter_xl(state_dict):
|
380 |
+
self.load_ipadapter_xl(state_dict, file_path=file_path)
|
381 |
+
elif self.is_ipadapter_xl_image_encoder(state_dict):
|
382 |
+
self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
|
383 |
+
elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
|
384 |
+
self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
|
385 |
+
elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
|
386 |
+
self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
|
387 |
+
elif self.is_hunyuan_dit(state_dict):
|
388 |
+
self.load_hunyuan_dit(state_dict, file_path=file_path)
|
389 |
+
elif self.is_diffusers_vae(state_dict):
|
390 |
+
self.load_diffusers_vae(state_dict, file_path=file_path)
|
391 |
+
elif self.is_ExVideo_StableVideoDiffusion(state_dict):
|
392 |
+
self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
|
393 |
+
|
394 |
+
def load_models(self, file_path_list, lora_alphas=[]):
|
395 |
+
for file_path in file_path_list:
|
396 |
+
self.load_model(file_path, lora_alphas=lora_alphas)
|
397 |
+
|
398 |
+
def to(self, device):
|
399 |
+
for component in self.model:
|
400 |
+
if isinstance(self.model[component], list):
|
401 |
+
for model in self.model[component]:
|
402 |
+
model.to(device)
|
403 |
+
else:
|
404 |
+
self.model[component].to(device)
|
405 |
+
torch.cuda.empty_cache()
|
406 |
+
|
407 |
+
def get_model_with_model_path(self, model_path):
|
408 |
+
for component in self.model_path:
|
409 |
+
if isinstance(self.model_path[component], str):
|
410 |
+
if os.path.samefile(self.model_path[component], model_path):
|
411 |
+
return self.model[component]
|
412 |
+
elif isinstance(self.model_path[component], list):
|
413 |
+
for i, model_path_ in enumerate(self.model_path[component]):
|
414 |
+
if os.path.samefile(model_path_, model_path):
|
415 |
+
return self.model[component][i]
|
416 |
+
raise ValueError(f"Please load model {model_path} before you use it.")
|
417 |
+
|
418 |
+
def __getattr__(self, __name):
|
419 |
+
if __name in self.model:
|
420 |
+
return self.model[__name]
|
421 |
+
else:
|
422 |
+
return super.__getattribute__(__name)
|
423 |
+
|
424 |
+
|
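For reference, a usage sketch of the loading entry points above, assuming the enclosing class is the ModelManager this package exports and that its constructor accepts torch_dtype and device (neither is shown in this diff); the checkpoint paths are placeholders:

    import torch
    from diffsynth import ModelManager  # assumed import path

    model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
    model_manager.load_models([
        "models/stable_diffusion/v1-5-pruned-emaonly.safetensors",  # placeholder path
        "models/ControlNet/control_v11p_sd15_canny.pth",            # placeholder path
    ])
    # Retrieve a loaded model either by its file path or, for single components, by attribute.
    controlnet = model_manager.get_model_with_model_path("models/ControlNet/control_v11p_sd15_canny.pth")
    unet = model_manager.unet  # resolved through __getattr__ above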
def load_state_dict(file_path, torch_dtype=None):
    if file_path.endswith(".safetensors"):
        return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
    else:
        return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)


def load_state_dict_from_safetensors(file_path, torch_dtype=None):
    state_dict = {}
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)
            if torch_dtype is not None:
                state_dict[k] = state_dict[k].to(torch_dtype)
    return state_dict


def load_state_dict_from_bin(file_path, torch_dtype=None):
    state_dict = torch.load(file_path, map_location="cpu")
    if torch_dtype is not None:
        for i in state_dict:
            if isinstance(state_dict[i], torch.Tensor):
                state_dict[i] = state_dict[i].to(torch_dtype)
    return state_dict


def search_parameter(param, state_dict):
    for name, param_ in state_dict.items():
        if param.numel() == param_.numel():
            if param.shape == param_.shape:
                if torch.dist(param, param_) < 1e-6:
                    return name
            else:
                if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
                    return name
    return None


def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
    matched_keys = set()
    with torch.no_grad():
        for name in source_state_dict:
            rename = search_parameter(source_state_dict[name], target_state_dict)
            if rename is not None:
                print(f'"{name}": "{rename}",')
                matched_keys.add(rename)
            elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
                length = source_state_dict[name].shape[0] // 3
                rename = []
                for i in range(3):
                    rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
                if None not in rename:
                    print(f'"{name}": {rename},')
                    for rename_ in rename:
                        matched_keys.add(rename_)
    for name in target_state_dict:
        if name not in matched_keys:
            print("Cannot find", name, target_state_dict[name].shape)
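load_state_dict routes .safetensors files through safe_open and optionally casts every tensor to the requested dtype. A standalone sketch of that round trip using the safetensors library directly (not calling the functions above); the file name is a throwaway:

    import torch
    from safetensors.torch import save_file
    from safetensors import safe_open

    save_file({"w": torch.randn(2, 2)}, "demo.safetensors")

    state_dict = {}
    with safe_open("demo.safetensors", framework="pt", device="cpu") as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k).to(torch.float16)  # optional dtype cast, as above
    print(state_dict["w"].dtype)  # torch.float16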
diffsynth/models/attention.py
ADDED
@@ -0,0 +1,89 @@
import torch
from einops import rearrange


def low_version_attention(query, key, value, attn_bias=None):
    scale = 1 / query.shape[-1] ** 0.5
    query = query * scale
    attn = torch.matmul(query, key.transpose(-2, -1))
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = attn.softmax(-1)
    return attn @ value


class Attention(torch.nn.Module):

    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        super().__init__()
        dim_inner = head_dim * num_heads
        kv_dim = kv_dim if kv_dim is not None else q_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)

    def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
        batch_size = q.shape[0]
        ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
        hidden_states = hidden_states + scale * ip_hidden_states
        return hidden_states

    def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        batch_size = encoder_hidden_states.shape[0]

        q = self.to_q(hidden_states)
        k = self.to_k(encoder_hidden_states)
        v = self.to_v(encoder_hidden_states)

        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if qkv_preprocessor is not None:
            q, k, v = qkv_preprocessor(q, k, v)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        if ipadapter_kwargs is not None:
            hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)

        hidden_states = self.to_out(hidden_states)

        return hidden_states

    def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        q = self.to_q(hidden_states)
        k = self.to_k(encoder_hidden_states)
        v = self.to_v(encoder_hidden_states)

        q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
        k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
        v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)

        if attn_mask is not None:
            hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
        else:
            import xformers.ops as xops
            hidden_states = xops.memory_efficient_attention(q, k, v)
        hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)

        hidden_states = hidden_states.to(q.dtype)
        hidden_states = self.to_out(hidden_states)

        return hidden_states

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
        return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
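A quick shape check of the Attention module above, assuming the package is installed so this module is importable as diffsynth.models.attention; the sizes are arbitrary:

    import torch
    from diffsynth.models.attention import Attention  # assumed module path

    attn = Attention(q_dim=320, num_heads=8, head_dim=40, kv_dim=768, bias_q=True, bias_kv=True, bias_out=True)
    x = torch.randn(2, 64, 320)        # (batch, query tokens, q_dim)
    context = torch.randn(2, 77, 768)  # (batch, context tokens, kv_dim)
    out = attn(x, encoder_hidden_states=context)
    print(out.shape)  # torch.Size([2, 64, 320])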
diffsynth/models/hunyuan_dit.py
ADDED
@@ -0,0 +1,451 @@
from .attention import Attention
from .tiler import TileWorker
from einops import repeat, rearrange
import math
import torch


class HunyuanDiTRotaryEmbedding(torch.nn.Module):

    def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
        super().__init__()
        self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
        self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
        self.rotary_emb_on_k = rotary_emb_on_k
        self.k_cache, self.v_cache = [], []

    def reshape_for_broadcast(self, freqs_cis, x):
        ndim = x.ndim
        shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)

    def rotate_half(self, x):
        x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
        return torch.stack([-x_imag, x_real], dim=-1).flatten(3)

    def apply_rotary_emb(self, xq, xk, freqs_cis):
        xk_out = None
        cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
        if xk is not None:
            xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
        return xq_out, xk_out

    def forward(self, q, k, v, freqs_cis_img, to_cache=False):
        # norm
        q = self.q_norm(q)
        k = self.k_norm(k)

        # RoPE
        if self.rotary_emb_on_k:
            q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
        else:
            q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)

        if to_cache:
            self.k_cache.append(k)
            self.v_cache.append(v)
        elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
            k = torch.concat([k] + self.k_cache, dim=2)
            v = torch.concat([v] + self.v_cache, dim=2)
            self.k_cache, self.v_cache = [], []
        return q, k, v


class FP32_Layernorm(torch.nn.LayerNorm):
    def forward(self, inputs):
        origin_dtype = inputs.dtype
        return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)


class FP32_SiLU(torch.nn.SiLU):
    def forward(self, inputs):
        origin_dtype = inputs.dtype
        return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)


class HunyuanDiTFinalLayer(torch.nn.Module):
    def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
        super().__init__()
        self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = torch.nn.Sequential(
            FP32_SiLU(),
            torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
        )

    def modulate(self, x, shift, scale):
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def forward(self, hidden_states, condition_emb):
        shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
        hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
        hidden_states = self.linear(hidden_states)
        return hidden_states


class HunyuanDiTBlock(torch.nn.Module):

    def __init__(
        self,
        hidden_dim=1408,
        condition_dim=1408,
        num_heads=16,
        mlp_ratio=4.3637,
        text_dim=1024,
        skip_connection=False
    ):
        super().__init__()
        self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
        self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
        self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
        self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
        self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
        self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
        self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
        self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
        )
        if skip_connection:
            self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
            self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
        else:
            self.skip_norm, self.skip_linear = None, None

    def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
        # Long Skip Connection
        if self.skip_norm is not None and self.skip_linear is not None:
            hidden_states = torch.cat([hidden_states, residual], dim=-1)
            hidden_states = self.skip_norm(hidden_states)
            hidden_states = self.skip_linear(hidden_states)

        # Self-Attention
        shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
        attn_input = self.norm1(hidden_states) + shift_msa
        hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))

        # Cross-Attention
        attn_input = self.norm3(hidden_states)
        hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))

        # FFN Layer
        mlp_input = self.norm2(hidden_states)
        hidden_states = hidden_states + self.mlp(mlp_input)
        return hidden_states


class AttentionPool(torch.nn.Module):
    def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
        super().__init__()
        self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.permute(1, 0, 2)  # NLC -> LNC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
        x, _ = torch.nn.functional.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x.squeeze(0)


class PatchEmbed(torch.nn.Module):
    def __init__(
        self,
        patch_size=(2, 2),
        in_chans=4,
        embed_dim=1408,
        bias=True,
    ):
        super().__init__()
        self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return x


def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)  # size: [dim/2], an exponentially decaying curve
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
    else:
        embedding = repeat(t, "b -> b d", d=dim)
    return embedding


class TimestepEmbedder(torch.nn.Module):
    def __init__(self, hidden_size=1408, frequency_embedding_size=256):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    def forward(self, t):
        t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class HunyuanDiT(torch.nn.Module):
    def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
        super().__init__()

        # Embedders
        self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
        self.t5_embedder = torch.nn.Sequential(
            torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
            FP32_SiLU(),
            torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
        )
        self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
        self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
        self.patch_embedder = PatchEmbed(in_chans=in_channels)
        self.timestep_embedder = TimestepEmbedder()
        self.extra_embedder = torch.nn.Sequential(
            torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
            FP32_SiLU(),
            torch.nn.Linear(hidden_dim * 4, hidden_dim),
        )

        # Transformer blocks
        self.num_layers_down = num_layers_down
        self.num_layers_up = num_layers_up
        self.blocks = torch.nn.ModuleList(
            [HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
            [HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
        )

        # Output layers
        self.final_layer = HunyuanDiTFinalLayer()
        self.out_channels = out_channels

    def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
        text_emb_mask = text_emb_mask.bool()
        text_emb_mask_t5 = text_emb_mask_t5.bool()
        text_emb_t5 = self.t5_embedder(text_emb_t5)
        text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
        text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
        text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
        return text_emb

    def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
        # Text embedding
        pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)

        # Timestep embedding
        timestep_emb = self.timestep_embedder(timestep)

        # Size embedding
        size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
        size_emb = size_emb.view(-1, 6 * 256)

        # Style embedding
        style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)

        # Concatenate all extra vectors
        extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
        condition_emb = timestep_emb + self.extra_embedder(extra_emb)

        return condition_emb

    def unpatchify(self, x, h, w):
        return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)

    def build_mask(self, data, is_bound):
        _, _, H, W = data.shape
        h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
        w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
        border_width = (H + W) // 4
        pad = torch.ones_like(h) * border_width
        mask = torch.stack([
            pad if is_bound[0] else h + 1,
            pad if is_bound[1] else H - h,
            pad if is_bound[2] else w + 1,
            pad if is_bound[3] else W - w
        ]).min(dim=0).values
        mask = mask.clip(1, border_width)
        mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
        mask = rearrange(mask, "H W -> 1 H W")
        return mask

    def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
        B, C, H, W = hidden_states.shape

        weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
        values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)

        # Split tasks
        tasks = []
        for h in range(0, H, tile_stride):
            for w in range(0, W, tile_stride):
                if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
                    continue
                h_, w_ = h + tile_size, w + tile_size
                if h_ > H: h, h_ = H - tile_size, H
                if w_ > W: w, w_ = W - tile_size, W
                tasks.append((h, h_, w, w_))

        # Run
        for hl, hr, wl, wr in tasks:
            hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
            hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
            if residual is not None:
                residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
                residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
            else:
                residual_batch = None

            # Forward
            hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
            hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)

            mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
            values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
            weight[:, :, hl:hr, wl:wr] += mask
        values /= weight
        return values

    def forward(
        self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
        tiled=False, tile_size=64, tile_stride=32,
        to_cache=False,
        use_gradient_checkpointing=False,
    ):
        # Embeddings
        text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
        condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])

        # Input
        height, width = hidden_states.shape[-2], hidden_states.shape[-1]
        hidden_states = self.patch_embedder(hidden_states)

        # Blocks
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward
        if tiled:
            hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
            residuals = []
            for block_id, block in enumerate(self.blocks):
                residual = residuals.pop() if block_id >= self.num_layers_down else None
                hidden_states = self.tiled_block_forward(
                    block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
                    torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
                    tile_size=tile_size, tile_stride=tile_stride
                )
                if block_id < self.num_layers_down - 2:
                    residuals.append(hidden_states)
            hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
        else:
            residuals = []
            for block_id, block in enumerate(self.blocks):
                residual = residuals.pop() if block_id >= self.num_layers_down else None
                if self.training and use_gradient_checkpointing:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states, condition_emb, text_emb, freq_cis_img, residual,
                        use_reentrant=False,
                    )
                else:
                    hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
                if block_id < self.num_layers_down - 2:
                    residuals.append(hidden_states)

        # Output
        hidden_states = self.final_layer(hidden_states, condition_emb)
        hidden_states = self.unpatchify(hidden_states, height//2, width//2)
        hidden_states, _ = hidden_states.chunk(2, dim=1)
        return hidden_states

    def state_dict_converter(self):
        return HunyuanDiTStateDictConverter()


class HunyuanDiTStateDictConverter():
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {}
        for name, param in state_dict.items():
            name_ = name
            name_ = name_.replace(".default_modulation.", ".modulation.")
            name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
            name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
            name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
            name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
            name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
            name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
            name_ = name_.replace(".q_proj.", ".to_q.")
            name_ = name_.replace(".out_proj.", ".to_out.")
            name_ = name_.replace("text_embedding_padding", "text_emb_padding")
            name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
            name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
            name_ = name_.replace("pooler.", "t5_pooler.")
            name_ = name_.replace("x_embedder.", "patch_embedder.")
            name_ = name_.replace("t_embedder.", "timestep_embedder.")
            name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
            name_ = name_.replace("style_embedder.weight", "style_embedder")
            if ".kv_proj." in name_:
                param_k = param[:param.shape[0]//2]
                param_v = param[param.shape[0]//2:]
                state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
                state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
            elif ".Wqkv." in name_:
                param_q = param[:param.shape[0]//3]
                param_k = param[param.shape[0]//3:param.shape[0]//3*2]
                param_v = param[param.shape[0]//3*2:]
                state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
                state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
                state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
            elif "style_embedder" in name_:
                state_dict_[name_] = param.squeeze()
            else:
                state_dict_[name_] = param
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
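timestep_embedding above follows the standard GLIDE/DDPM sinusoidal recipe: half of the channels are cosines and half are sines of the timestep scaled by exponentially decaying frequencies. A self-contained restatement of the even-dimension case, kept separate from the module so it can be run on its own:

    import math
    import torch

    def sinusoidal_embedding(t, dim, max_period=10000):
        # Mirrors the formula used by timestep_embedding for even dim.
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
        args = t[:, None].float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    emb = sinusoidal_embedding(torch.tensor([0, 250, 999]), 256)
    print(emb.shape)  # torch.Size([3, 256])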
diffsynth/models/hunyuan_dit_text_encoder.py
ADDED
@@ -0,0 +1,161 @@
from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
import torch


class HunyuanDiTCLIPTextEncoder(BertModel):
    def __init__(self):
        config = BertConfig(
            _name_or_path = "",
            architectures = ["BertModel"],
            attention_probs_dropout_prob = 0.1,
            bos_token_id = 0,
            classifier_dropout = None,
            directionality = "bidi",
            eos_token_id = 2,
            hidden_act = "gelu",
            hidden_dropout_prob = 0.1,
            hidden_size = 1024,
            initializer_range = 0.02,
            intermediate_size = 4096,
            layer_norm_eps = 1e-12,
            max_position_embeddings = 512,
            model_type = "bert",
            num_attention_heads = 16,
            num_hidden_layers = 24,
            output_past = True,
            pad_token_id = 0,
            pooler_fc_size = 768,
            pooler_num_attention_heads = 12,
            pooler_num_fc_layers = 3,
            pooler_size_per_head = 128,
            pooler_type = "first_token_transform",
            position_embedding_type = "absolute",
            torch_dtype = "float32",
            transformers_version = "4.37.2",
            type_vocab_size = 2,
            use_cache = True,
            vocab_size = 47020
        )
        super().__init__(config, add_pooling_layer=False)
        self.eval()

    def forward(self, input_ids, attention_mask, clip_skip=1):
        input_shape = input_ids.size()

        batch_size, seq_length = input_shape
        device = input_ids.device

        past_key_values_length = 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=None,
            token_type_ids=None,
            inputs_embeds=None,
            past_key_values_length=0,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_values=None,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )
        all_hidden_states = encoder_outputs.hidden_states
        prompt_emb = all_hidden_states[-clip_skip]
        if clip_skip > 1:
            mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
            prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
        return prompt_emb

    def state_dict_converter(self):
        return HunyuanDiTCLIPTextEncoderStateDictConverter()


class HunyuanDiTT5TextEncoder(T5EncoderModel):
    def __init__(self):
        config = T5Config(
            _name_or_path = "../HunyuanDiT/t2i/mt5",
            architectures = ["MT5ForConditionalGeneration"],
            classifier_dropout = 0.0,
            d_ff = 5120,
            d_kv = 64,
            d_model = 2048,
            decoder_start_token_id = 0,
            dense_act_fn = "gelu_new",
            dropout_rate = 0.1,
            eos_token_id = 1,
            feed_forward_proj = "gated-gelu",
            initializer_factor = 1.0,
            is_encoder_decoder = True,
            is_gated_act = True,
            layer_norm_epsilon = 1e-06,
            model_type = "t5",
            num_decoder_layers = 24,
            num_heads = 32,
            num_layers = 24,
            output_past = True,
            pad_token_id = 0,
            relative_attention_max_distance = 128,
            relative_attention_num_buckets = 32,
            tie_word_embeddings = False,
            tokenizer_class = "T5Tokenizer",
            transformers_version = "4.37.2",
            use_cache = True,
            vocab_size = 250112
        )
        super().__init__(config)
        self.eval()

    def forward(self, input_ids, attention_mask, clip_skip=1):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        prompt_emb = outputs.hidden_states[-clip_skip]
        if clip_skip > 1:
            mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
            prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
        return prompt_emb

    def state_dict_converter(self):
        return HunyuanDiTT5TextEncoderStateDictConverter()


class HunyuanDiTCLIPTextEncoderStateDictConverter():
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)


class HunyuanDiTT5TextEncoderStateDictConverter():
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
        state_dict_["shared.weight"] = state_dict["shared.weight"]
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
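Both encoders implement clip_skip by taking an earlier hidden state and rescaling it to the mean and standard deviation of the final layer, so downstream layers still see activations in the expected range. A standalone sketch of that rescaling step with stand-in layer outputs:

    import torch

    hidden_states = [torch.randn(1, 77, 1024) * s for s in (0.5, 1.0, 2.0)]  # stand-in per-layer outputs
    clip_skip = 2

    prompt_emb = hidden_states[-clip_skip]
    mean, std = hidden_states[-1].mean(), hidden_states[-1].std()
    prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
    print(prompt_emb.mean().item(), prompt_emb.std().item())  # roughly the final-layer statistics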
diffsynth/models/sd_controlnet.py
ADDED
@@ -0,0 +1,587 @@
import torch
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
from .tiler import TileWorker


class ControlNetConditioningLayer(torch.nn.Module):
    def __init__(self, channels = (3, 16, 32, 96, 256, 320)):
        super().__init__()
        self.blocks = torch.nn.ModuleList([])
        self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
        self.blocks.append(torch.nn.SiLU())
        for i in range(1, len(channels) - 2):
            self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1))
            self.blocks.append(torch.nn.SiLU())
            self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2))
            self.blocks.append(torch.nn.SiLU())
        self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1))

    def forward(self, conditioning):
        for block in self.blocks:
            conditioning = block(conditioning)
        return conditioning


class SDControlNet(torch.nn.Module):
    def __init__(self, global_pool=False):
        super().__init__()
        self.time_proj = Timesteps(320)
        self.time_embedding = torch.nn.Sequential(
            torch.nn.Linear(320, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280)
        )
        self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)

        self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))

        self.blocks = torch.nn.ModuleList([
            # CrossAttnDownBlock2D
            ResnetBlock(320, 320, 1280),
            AttentionBlock(8, 40, 320, 1, 768),
            PushBlock(),
            ResnetBlock(320, 320, 1280),
            AttentionBlock(8, 40, 320, 1, 768),
            PushBlock(),
            DownSampler(320),
            PushBlock(),
            # CrossAttnDownBlock2D
            ResnetBlock(320, 640, 1280),
            AttentionBlock(8, 80, 640, 1, 768),
            PushBlock(),
            ResnetBlock(640, 640, 1280),
            AttentionBlock(8, 80, 640, 1, 768),
            PushBlock(),
            DownSampler(640),
            PushBlock(),
            # CrossAttnDownBlock2D
            ResnetBlock(640, 1280, 1280),
            AttentionBlock(8, 160, 1280, 1, 768),
            PushBlock(),
            ResnetBlock(1280, 1280, 1280),
            AttentionBlock(8, 160, 1280, 1, 768),
            PushBlock(),
            DownSampler(1280),
            PushBlock(),
            # DownBlock2D
            ResnetBlock(1280, 1280, 1280),
            PushBlock(),
            ResnetBlock(1280, 1280, 1280),
            PushBlock(),
            # UNetMidBlock2DCrossAttn
            ResnetBlock(1280, 1280, 1280),
            AttentionBlock(8, 160, 1280, 1, 768),
            ResnetBlock(1280, 1280, 1280),
            PushBlock()
        ])

        self.controlnet_blocks = torch.nn.ModuleList([
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
        ])

        self.global_pool = global_pool

    def forward(
        self,
        sample, timestep, encoder_hidden_states, conditioning,
        tiled=False, tile_size=64, tile_stride=32,
    ):
        # 1. time
        time_emb = self.time_proj(timestep[None]).to(sample.dtype)
        time_emb = self.time_embedding(time_emb)
        time_emb = time_emb.repeat(sample.shape[0], 1)

        # 2. pre-process
        height, width = sample.shape[2], sample.shape[3]
        hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
        text_emb = encoder_hidden_states
        res_stack = [hidden_states]

        # 3. blocks
        for i, block in enumerate(self.blocks):
            if tiled and not isinstance(block, PushBlock):
                _, _, inter_height, _ = hidden_states.shape
                resize_scale = inter_height / height
                hidden_states = TileWorker().tiled_forward(
                    lambda x: block(x, time_emb, text_emb, res_stack)[0],
                    hidden_states,
                    int(tile_size * resize_scale),
                    int(tile_stride * resize_scale),
                    tile_device=hidden_states.device,
                    tile_dtype=hidden_states.dtype
                )
            else:
                hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)

        # 4. ControlNet blocks
        controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]

        # pool
        if self.global_pool:
            controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]

        return controlnet_res_stack

    def state_dict_converter(self):
        return SDControlNetStateDictConverter()


class SDControlNetStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # architecture
        block_types = [
            'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
            'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
            'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
            'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
            'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
            'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
            'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
            'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
            'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
        ]

        # controlnet_rename_dict
        controlnet_rename_dict = {
            "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
            "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
            "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
            "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
            "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
            "controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
            "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
            "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
            "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
            "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
            "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
            "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
            "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
            "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
            "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
            "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
        }

        # Rename each parameter
        name_list = sorted([name for name in state_dict])
        rename_dict = {}
        block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
        last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
        for name in name_list:
            names = name.split(".")
            if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
                pass
            elif name in controlnet_rename_dict:
                names = controlnet_rename_dict[name].split(".")
            elif names[0] == "controlnet_down_blocks":
                names[0] = "controlnet_blocks"
            elif names[0] == "controlnet_mid_block":
                names = ["controlnet_blocks", "12", names[-1]]
            elif names[0] in ["time_embedding", "add_embedding"]:
                if names[0] == "add_embedding":
                    names[0] = "add_time_embedding"
                names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
            elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
                if names[0] == "mid_block":
                    names.insert(1, "0")
                block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
                block_type_with_id = ".".join(names[:4])
                if block_type_with_id != last_block_type_with_id[block_type]:
                    block_id[block_type] += 1
                    last_block_type_with_id[block_type] = block_type_with_id
                while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
                    block_id[block_type] += 1
                block_type_with_id = ".".join(names[:4])
                names = ["blocks", str(block_id[block_type])] + names[4:]
                if "ff" in names:
                    ff_index = names.index("ff")
                    component = ".".join(names[ff_index:ff_index+3])
                    component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
                    names = names[:ff_index] + [component] + names[ff_index+3:]
                if "to_out" in names:
                    names.pop(names.index("to_out") + 1)
            else:
                raise ValueError(f"Unknown parameters: {name}")
            rename_dict[name] = ".".join(names)

        # Convert state_dict
        state_dict_ = {}
        for name, param in state_dict.items():
            if ".proj_in." in name or ".proj_out." in name:
                param = param.squeeze()
            if rename_dict[name] in [
                "controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
                "controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
            ]:
                continue
            state_dict_[rename_dict[name]] = param
        return state_dict_

    def from_civitai(self, state_dict):
        if "mid_block.resnets.1.time_emb_proj.weight" in state_dict:
            # For controlnets in diffusers format
            return self.from_diffusers(state_dict)
        rename_dict = {
            "control_model.time_embed.0.weight": "time_embedding.0.weight",
            "control_model.time_embed.0.bias": "time_embedding.0.bias",
            "control_model.time_embed.2.weight": "time_embedding.2.weight",
            "control_model.time_embed.2.bias": "time_embedding.2.bias",
            "control_model.input_blocks.0.0.weight": "conv_in.weight",
            "control_model.input_blocks.0.0.bias": "conv_in.bias",
            "control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
|
246 |
+
"control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
|
247 |
+
"control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
|
248 |
+
"control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
|
249 |
+
"control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
|
250 |
+
"control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
|
251 |
+
"control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
|
252 |
+
"control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
|
253 |
+
"control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
|
254 |
+
"control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
|
255 |
+
"control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight",
|
256 |
+
"control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias",
|
257 |
+
"control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
|
258 |
+
"control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias",
|
259 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight",
|
260 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight",
|
261 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight",
|
262 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight",
|
263 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias",
|
264 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight",
|
265 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias",
|
266 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight",
|
267 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias",
|
268 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight",
|
269 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight",
|
270 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight",
|
271 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight",
|
272 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias",
|
273 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight",
|
274 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias",
|
275 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight",
|
276 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias",
|
277 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight",
|
278 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias",
|
279 |
+
"control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight",
|
280 |
+
"control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias",
|
281 |
+
"control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight",
|
282 |
+
"control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias",
|
283 |
+
"control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight",
|
284 |
+
"control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias",
|
285 |
+
"control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight",
|
286 |
+
"control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias",
|
287 |
+
"control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight",
|
288 |
+
"control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias",
|
289 |
+
"control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight",
|
290 |
+
"control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias",
|
291 |
+
"control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight",
|
292 |
+
"control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias",
|
293 |
+
"control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight",
|
294 |
+
"control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias",
|
295 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight",
|
296 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight",
|
297 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight",
|
298 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight",
|
299 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias",
|
300 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight",
|
301 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias",
|
302 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight",
|
303 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias",
|
304 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight",
|
305 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight",
|
306 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight",
|
307 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight",
|
308 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias",
|
309 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight",
|
310 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias",
|
311 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight",
|
312 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias",
|
313 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight",
|
314 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias",
|
315 |
+
"control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight",
|
316 |
+
"control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias",
|
317 |
+
"control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight",
|
318 |
+
"control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias",
|
319 |
+
"control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight",
|
320 |
+
"control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias",
|
321 |
+
"control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight",
|
322 |
+
"control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias",
|
323 |
+
"control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight",
|
324 |
+
"control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias",
|
325 |
+
"control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight",
|
326 |
+
"control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias",
|
327 |
+
"control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight",
|
328 |
+
"control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias",
|
329 |
+
"control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight",
|
330 |
+
"control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias",
|
331 |
+
"control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight",
|
332 |
+
"control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias",
|
333 |
+
"control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight",
|
334 |
+
"control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias",
|
335 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight",
|
336 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight",
|
337 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight",
|
338 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight",
|
339 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias",
|
340 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight",
|
341 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias",
|
342 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight",
|
343 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias",
|
344 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight",
|
345 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight",
|
346 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight",
|
347 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight",
|
348 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias",
|
349 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight",
|
350 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias",
|
351 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight",
|
352 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias",
|
353 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight",
|
354 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias",
|
355 |
+
"control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight",
|
356 |
+
"control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias",
|
357 |
+
"control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight",
|
358 |
+
"control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias",
|
359 |
+
"control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight",
|
360 |
+
"control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias",
|
361 |
+
"control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight",
|
362 |
+
"control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias",
|
363 |
+
"control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight",
|
364 |
+
"control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias",
|
365 |
+
"control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight",
|
366 |
+
"control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias",
|
367 |
+
"control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight",
|
368 |
+
"control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias",
|
369 |
+
"control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight",
|
370 |
+
"control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias",
|
371 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight",
|
372 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight",
|
373 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight",
|
374 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight",
|
375 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias",
|
376 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight",
|
377 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias",
|
378 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight",
|
379 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias",
|
380 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight",
|
381 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight",
|
382 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight",
|
383 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight",
|
384 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias",
|
385 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight",
|
386 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias",
|
387 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight",
|
388 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias",
|
389 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight",
|
390 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias",
|
391 |
+
"control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight",
|
392 |
+
"control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias",
|
393 |
+
"control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight",
|
394 |
+
"control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias",
|
395 |
+
"control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight",
|
396 |
+
"control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias",
|
397 |
+
"control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight",
|
398 |
+
"control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias",
|
399 |
+
"control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight",
|
400 |
+
"control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias",
|
401 |
+
"control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight",
|
402 |
+
"control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias",
|
403 |
+
"control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight",
|
404 |
+
"control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias",
|
405 |
+
"control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight",
|
406 |
+
"control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias",
|
407 |
+
"control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight",
|
408 |
+
"control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias",
|
409 |
+
"control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight",
|
410 |
+
"control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias",
|
411 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight",
|
412 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight",
|
413 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight",
|
414 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight",
|
415 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias",
|
416 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight",
|
417 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias",
|
418 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight",
|
419 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias",
|
420 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight",
|
421 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight",
|
422 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight",
|
423 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight",
|
424 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias",
|
425 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight",
|
426 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias",
|
427 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight",
|
428 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias",
|
429 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight",
|
430 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias",
|
431 |
+
"control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight",
|
432 |
+
"control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias",
|
433 |
+
"control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight",
|
434 |
+
"control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias",
|
435 |
+
"control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight",
|
436 |
+
"control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias",
|
437 |
+
"control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight",
|
438 |
+
"control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias",
|
439 |
+
"control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight",
|
440 |
+
"control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias",
|
441 |
+
"control_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight",
|
442 |
+
"control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias",
|
443 |
+
"control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight",
|
444 |
+
"control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias",
|
445 |
+
"control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight",
|
446 |
+
"control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias",
|
447 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight",
|
448 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight",
|
449 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight",
|
450 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight",
|
451 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias",
|
452 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight",
|
453 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias",
|
454 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight",
|
455 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias",
|
456 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight",
|
457 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight",
|
458 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight",
|
459 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight",
|
460 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias",
|
461 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight",
|
462 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias",
|
463 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight",
|
464 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias",
|
465 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight",
|
466 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias",
|
467 |
+
"control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight",
|
468 |
+
"control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias",
|
469 |
+
"control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight",
|
470 |
+
"control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias",
|
471 |
+
"control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight",
|
472 |
+
"control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias",
|
473 |
+
"control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight",
|
474 |
+
"control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias",
|
475 |
+
"control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
|
476 |
+
"control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
|
477 |
+
"control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight",
|
478 |
+
"control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias",
|
479 |
+
"control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight",
|
480 |
+
"control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias",
|
481 |
+
"control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight",
|
482 |
+
"control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias",
|
483 |
+
"control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight",
|
484 |
+
"control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias",
|
485 |
+
"control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight",
|
486 |
+
"control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias",
|
487 |
+
"control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight",
|
488 |
+
"control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias",
|
489 |
+
"control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight",
|
490 |
+
"control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias",
|
491 |
+
"control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight",
|
492 |
+
"control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias",
|
493 |
+
"control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight",
|
494 |
+
"control_model.zero_convs.1.0.bias": "controlnet_blocks.0.bias",
|
495 |
+
"control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight",
|
496 |
+
"control_model.zero_convs.2.0.bias": "controlnet_blocks.0.bias",
|
497 |
+
"control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight",
|
498 |
+
"control_model.zero_convs.3.0.bias": "controlnet_blocks.0.bias",
|
499 |
+
"control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight",
|
500 |
+
"control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias",
|
501 |
+
"control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight",
|
502 |
+
"control_model.zero_convs.5.0.bias": "controlnet_blocks.4.bias",
|
503 |
+
"control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight",
|
504 |
+
"control_model.zero_convs.6.0.bias": "controlnet_blocks.4.bias",
|
505 |
+
"control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight",
|
506 |
+
"control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias",
|
507 |
+
"control_model.zero_convs.8.0.weight": "controlnet_blocks.8.weight",
|
508 |
+
"control_model.zero_convs.8.0.bias": "controlnet_blocks.7.bias",
|
509 |
+
"control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight",
|
510 |
+
"control_model.zero_convs.9.0.bias": "controlnet_blocks.7.bias",
|
511 |
+
"control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight",
|
512 |
+
"control_model.zero_convs.10.0.bias": "controlnet_blocks.7.bias",
|
513 |
+
"control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight",
|
514 |
+
"control_model.zero_convs.11.0.bias": "controlnet_blocks.7.bias",
|
515 |
+
"control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight",
|
516 |
+
"control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias",
|
517 |
+
"control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight",
|
518 |
+
"control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias",
|
519 |
+
"control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight",
|
520 |
+
"control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias",
|
521 |
+
"control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight",
|
522 |
+
"control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias",
|
523 |
+
"control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight",
|
524 |
+
"control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias",
|
525 |
+
"control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight",
|
526 |
+
"control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias",
|
527 |
+
"control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight",
|
528 |
+
"control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias",
|
529 |
+
"control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight",
|
530 |
+
"control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias",
|
531 |
+
"control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight",
|
532 |
+
"control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias",
|
533 |
+
"control_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight",
|
534 |
+
"control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias",
|
535 |
+
"control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight",
|
536 |
+
"control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias",
|
537 |
+
"control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight",
|
538 |
+
"control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias",
|
539 |
+
"control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight",
|
540 |
+
"control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias",
|
541 |
+
"control_model.middle_block.1.norm.weight": "blocks.29.norm.weight",
|
542 |
+
"control_model.middle_block.1.norm.bias": "blocks.29.norm.bias",
|
543 |
+
"control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
|
544 |
+
"control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
|
545 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight",
|
546 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight",
|
547 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight",
|
548 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight",
|
549 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias",
|
550 |
+
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight",
|
551 |
+
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias",
|
552 |
+
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight",
|
553 |
+
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias",
|
554 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight",
|
555 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight",
|
556 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight",
|
557 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight",
|
558 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias",
|
559 |
+
"control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight",
|
560 |
+
"control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias",
|
561 |
+
"control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight",
|
562 |
+
"control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias",
|
563 |
+
"control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight",
|
564 |
+
"control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias",
|
565 |
+
"control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
|
566 |
+
"control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias",
|
567 |
+
"control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight",
|
568 |
+
"control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias",
|
569 |
+
"control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight",
|
570 |
+
"control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias",
|
571 |
+
"control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
|
572 |
+
"control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
|
573 |
+
"control_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight",
|
574 |
+
"control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias",
|
575 |
+
"control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight",
|
576 |
+
"control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias",
|
577 |
+
"control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight",
|
578 |
+
"control_model.middle_block_out.0.bias": "controlnet_blocks.7.bias",
|
579 |
+
}
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if ".proj_in." in name or ".proj_out." in name:
                    param = param.squeeze()
                state_dict_[rename_dict[name]] = param
        return state_dict_
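A minimal usage sketch for SDControlNetStateDictConverter above, assuming a ControlNet checkpoint stored as a safetensors file; the path is a placeholder and actual model loading is left to the surrounding model-management code.

from safetensors.torch import load_file
from diffsynth.models.sd_controlnet import SDControlNetStateDictConverter

raw_state_dict = load_file("controlnet_canny.safetensors")   # placeholder path
# from_civitai() falls back to from_diffusers() when it detects diffusers-style keys,
# otherwise it applies the "control_model.*" rename table above.
converted = SDControlNetStateDictConverter().from_civitai(raw_state_dict)
for name in list(converted)[:5]:                             # peek at the renamed parameters
    print(name, tuple(converted[name].shape))

In the full pipeline it is this converted dictionary, not the raw checkpoint, that gets loaded into SDControlNet.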
diffsynth/models/sd_ipadapter.py
ADDED
@@ -0,0 +1,56 @@
from .svd_image_encoder import SVDImageEncoder
from .sdxl_ipadapter import IpAdapterImageProjModel, IpAdapterModule, SDXLIpAdapterStateDictConverter
from transformers import CLIPImageProcessor
import torch


class IpAdapterCLIPImageEmbedder(SVDImageEncoder):
    def __init__(self):
        super().__init__()
        self.image_processor = CLIPImageProcessor()

    def forward(self, image):
        pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
        return super().forward(pixel_values)


class SDIpAdapter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        shape_list = [(768, 320)] * 2 + [(768, 640)] * 2 + [(768, 1280)] * 5 + [(768, 640)] * 3 + [(768, 320)] * 3 + [(768, 1280)] * 1
        self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
        self.image_proj = IpAdapterImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
        self.set_full_adapter()

    def set_full_adapter(self):
        block_ids = [1, 4, 9, 12, 17, 20, 40, 43, 46, 50, 53, 56, 60, 63, 66, 29]
        self.call_block_id = {(i, 0): j for j, i in enumerate(block_ids)}

    def set_less_adapter(self):
        # IP-Adapter for SD v1.5 doesn't support this feature.
        self.set_full_adapter()

    def forward(self, hidden_states, scale=1.0):
        hidden_states = self.image_proj(hidden_states)
        hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
        ip_kv_dict = {}
        for (block_id, transformer_id) in self.call_block_id:
            ipadapter_id = self.call_block_id[(block_id, transformer_id)]
            ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
            if block_id not in ip_kv_dict:
                ip_kv_dict[block_id] = {}
            ip_kv_dict[block_id][transformer_id] = {
                "ip_k": ip_k,
                "ip_v": ip_v,
                "scale": scale
            }
        return ip_kv_dict

    def state_dict_converter(self):
        return SDIpAdapterStateDictConverter()


class SDIpAdapterStateDictConverter(SDXLIpAdapterStateDictConverter):
    def __init__(self):
        pass
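A minimal sketch of what SDIpAdapter.forward produces, not a full pipeline run; the randomly initialised modules and the stand-in image embedding are assumptions for illustration only.

import torch
from diffsynth.models.sd_ipadapter import SDIpAdapter

ip_adapter = SDIpAdapter()             # randomly initialised here; real use loads converted IP-Adapter weights
image_emb = torch.randn(1, 1024)       # stand-in for a CLIP image embedding (image_proj expects clip_embeddings_dim=1024)
with torch.no_grad():
    ip_kv_dict = ip_adapter(image_emb, scale=0.8)
# ip_kv_dict maps UNet block id -> {transformer id -> {"ip_k", "ip_v", "scale"}},
# i.e. extra key/value tensors the attention blocks can consume.
print(sorted(ip_kv_dict))              # the 16 block ids registered by set_full_adapter()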
diffsynth/models/sd_lora.py
ADDED
@@ -0,0 +1,60 @@
import torch
from .sd_unet import SDUNetStateDictConverter, SDUNet
from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder


class SDLoRA:
    def __init__(self):
        pass

    def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
        special_keys = {
            "down.blocks": "down_blocks",
            "up.blocks": "up_blocks",
            "mid.block": "mid_block",
            "proj.in": "proj_in",
            "proj.out": "proj_out",
            "transformer.blocks": "transformer_blocks",
            "to.q": "to_q",
            "to.k": "to_k",
            "to.v": "to_v",
            "to.out": "to_out",
        }
        state_dict_ = {}
        for key in state_dict:
            if ".lora_up" not in key:
                continue
            if not key.startswith(lora_prefix):
                continue
            weight_up = state_dict[key].to(device=device, dtype=torch.float16)
            weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device=device, dtype=torch.float16)
            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
                weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
            for special_key in special_keys:
                target_name = target_name.replace(special_key, special_keys[special_key])
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_

    def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
        state_dict_unet = unet.state_dict()
        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
        state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
        if len(state_dict_lora) > 0:
            for name in state_dict_lora:
                state_dict_unet[name] += state_dict_lora[name].to(device=device)
            unet.load_state_dict(state_dict_unet)

    def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"):
        state_dict_text_encoder = text_encoder.state_dict()
        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
        state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
        if len(state_dict_lora) > 0:
            for name in state_dict_lora:
                state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
            text_encoder.load_state_dict(state_dict_text_encoder)
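A minimal, self-contained sketch of what convert_state_dict does to a Kohya-style key, run on the CPU with a synthetic 1x1-conv LoRA pair; the key name and tensor shapes are made up for illustration.

import torch
from diffsynth.models.sd_lora import SDLoRA

fake_lora = {
    "lora_unet_down_blocks_0_downsamplers_0_conv.lora_up.weight": torch.randn(640, 8, 1, 1),
    "lora_unet_down_blocks_0_downsamplers_0_conv.lora_down.weight": torch.randn(8, 320, 1, 1),
}
merged = SDLoRA().convert_state_dict(fake_lora, lora_prefix="lora_unet_", alpha=1.0, device="cpu")
# The "lora_unet_*" key is translated to the UNet parameter name, and up @ down
# (scaled by alpha) becomes the delta that add_lora_to_unet() adds onto that weight.
print(list(merged))                                                # ['down_blocks.0.downsamplers.0.conv.weight']
print(merged["down_blocks.0.downsamplers.0.conv.weight"].shape)    # torch.Size([640, 320, 1, 1])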
diffsynth/models/sd_motion.py
ADDED
@@ -0,0 +1,198 @@
from .sd_unet import SDUNet, Attention, GEGLU
import torch
from einops import rearrange, repeat


class TemporalTransformerBlock(torch.nn.Module):

    def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
        super().__init__()

        # 1. Self-Attn
        self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
        self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)

        # 2. Cross-Attn
        self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
        self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)

        # 3. Feed-forward
        self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.act_fn = GEGLU(dim, dim * 4)
        self.ff = torch.nn.Linear(dim * 4, dim)


    def forward(self, hidden_states, batch_size=1):

        # 1. Self-Attention
        norm_hidden_states = self.norm1(hidden_states)
        norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
        attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
        attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
        attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
        attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
        hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)
        ff_output = self.act_fn(norm_hidden_states)
        ff_output = self.ff(ff_output)
        hidden_states = ff_output + hidden_states

        return hidden_states


class TemporalBlock(torch.nn.Module):

    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
        self.proj_in = torch.nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = torch.nn.ModuleList([
            TemporalTransformerBlock(
                inner_dim,
                num_attention_heads,
                attention_head_dim
            )
            for d in range(num_layers)
        ])

        self.proj_out = torch.nn.Linear(inner_dim, in_channels)

    def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1):
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        hidden_states = self.proj_in(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                batch_size=batch_size
            )

        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
        hidden_states = hidden_states + residual

        return hidden_states, time_emb, text_emb, res_stack


class SDMotionModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.motion_modules = torch.nn.ModuleList([
            TemporalBlock(8, 40, 320, eps=1e-6),
            TemporalBlock(8, 40, 320, eps=1e-6),
            TemporalBlock(8, 80, 640, eps=1e-6),
            TemporalBlock(8, 80, 640, eps=1e-6),
            TemporalBlock(8, 160, 1280, eps=1e-6),
            TemporalBlock(8, 160, 1280, eps=1e-6),
            TemporalBlock(8, 160, 1280, eps=1e-6),
            TemporalBlock(8, 160, 1280, eps=1e-6),
            TemporalBlock(8, 160, 1280, eps=1e-6),
            TemporalBlock(8, 160, 1280, eps=1e-6),
            TemporalBlock(8, 160, 1280, eps=1e-6),
            TemporalBlock(8, 160, 1280, eps=1e-6),
            TemporalBlock(8, 160, 1280, eps=1e-6),
            TemporalBlock(8, 160, 1280, eps=1e-6),
            TemporalBlock(8, 160, 1280, eps=1e-6),
            TemporalBlock(8, 80, 640, eps=1e-6),
            TemporalBlock(8, 80, 640, eps=1e-6),
            TemporalBlock(8, 80, 640, eps=1e-6),
            TemporalBlock(8, 40, 320, eps=1e-6),
            TemporalBlock(8, 40, 320, eps=1e-6),
            TemporalBlock(8, 40, 320, eps=1e-6),
        ])
        self.call_block_id = {
            1: 0,
            4: 1,
            9: 2,
            12: 3,
            17: 4,
            20: 5,
            24: 6,
            26: 7,
            29: 8,
            32: 9,
            34: 10,
            36: 11,
            40: 12,
            43: 13,
            46: 14,
            50: 15,
            53: 16,
            56: 17,
            60: 18,
            63: 19,
            66: 20
        }

    def forward(self):
        pass

    def state_dict_converter(self):
        return SDMotionModelStateDictConverter()


class SDMotionModelStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "norm": "norm",
            "proj_in": "proj_in",
            "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
            "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
            "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
            "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
            "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
            "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
            "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
            "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
            "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
            "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
            "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
            "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
            "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
            "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
            "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
            "proj_out": "proj_out",
        }
        name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
        name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
        name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
        state_dict_ = {}
        last_prefix, module_id = "", -1
        for name in name_list:
            names = name.split(".")
            prefix_index = names.index("temporal_transformer") + 1
            prefix = ".".join(names[:prefix_index])
            if prefix != last_prefix:
                last_prefix = prefix
                module_id += 1
            middle_name = ".".join(names[prefix_index:-1])
            suffix = names[-1]
            if "pos_encoder" in names:
                rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
            else:
                rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
            state_dict_[rename] = state_dict[name]
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
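A small self-contained sketch of the einops reshaping that TemporalTransformerBlock relies on: tokens arrive as (batch*frames, h*w, channels), and temporal attention runs after swapping the roles of the frame axis and the spatial axis. The tensor sizes are arbitrary example values.

import torch
from einops import rearrange

batch_size, frames, tokens, channels = 2, 16, 64, 320
x = torch.randn(batch_size * frames, tokens, channels)           # "(b f) h c"

x_t = rearrange(x, "(b f) h c -> (b h) f c", b=batch_size)       # attention over this layout mixes frames
assert x_t.shape == (batch_size * tokens, frames, channels)

x_back = rearrange(x_t, "(b h) f c -> (b f) h c", b=batch_size)  # restore the spatial layout
assert torch.equal(x, x_back)                                    # pure permutation, exact round trip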
diffsynth/models/sd_text_encoder.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
from .attention import Attention


class CLIPEncoderLayer(torch.nn.Module):
    def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
        super().__init__()
        self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
        self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
        self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
        self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
        self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)

        self.use_quick_gelu = use_quick_gelu

    def quickGELU(self, x):
        return x * torch.sigmoid(1.702 * x)

    def forward(self, hidden_states, attn_mask=None):
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.fc1(hidden_states)
        if self.use_quick_gelu:
            hidden_states = self.quickGELU(hidden_states)
        else:
            hidden_states = torch.nn.functional.gelu(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class SDTextEncoder(torch.nn.Module):
    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
        super().__init__()

        # token_embedding
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # position_embeds (This is a fixed tensor)
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # encoders
        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])

        # attn_mask
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # final_layer_norm
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)

    def attention_mask(self, length):
        mask = torch.empty(length, length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, input_ids, clip_skip=1):
        embeds = self.token_embedding(input_ids) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
            if encoder_id + clip_skip == len(self.encoders):
                break
        embeds = self.final_layer_norm(embeds)
        return embeds

    def state_dict_converter(self):
        return SDTextEncoderStateDictConverter()


class SDTextEncoderStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "text_model.embeddings.position_embedding.weight": "position_embeds",
            "text_model.final_layer_norm.weight": "final_layer_norm.weight",
            "text_model.final_layer_norm.bias": "final_layer_norm.bias"
        }
        attn_rename_dict = {
            "self_attn.q_proj": "attn.to_q",
            "self_attn.k_proj": "attn.to_k",
            "self_attn.v_proj": "attn.to_v",
            "self_attn.out_proj": "attn.to_out",
            "layer_norm1": "layer_norm1",
            "layer_norm2": "layer_norm2",
            "mlp.fc1": "fc1",
            "mlp.fc2": "fc2",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == "text_model.embeddings.position_embedding.weight":
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
            elif name.startswith("text_model.encoder.layers."):
                param = state_dict[name]
                names = name.split(".")
                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
                state_dict_[name_] = param
        return state_dict_

    def from_civitai(self, state_dict):
        prefix = "cond_stage_model.transformer.text_model."
        rename_dict = {
            prefix + "embeddings.token_embedding.weight": "token_embedding.weight",
            prefix + "embeddings.position_embedding.weight": "position_embeds",
            prefix + "final_layer_norm.bias": "final_layer_norm.bias",
            prefix + "final_layer_norm.weight": "final_layer_norm.weight",
        }
        # Per-layer parameters follow a fixed naming pattern, so the per-layer part of
        # the mapping table is generated here instead of being written out for all 12 layers.
        layer_rename_dict = {
            "layer_norm1": "layer_norm1",
            "layer_norm2": "layer_norm2",
            "mlp.fc1": "fc1",
            "mlp.fc2": "fc2",
            "self_attn.k_proj": "attn.to_k",
            "self_attn.out_proj": "attn.to_out",
            "self_attn.q_proj": "attn.to_q",
            "self_attn.v_proj": "attn.to_v",
        }
        for layer_id in range(12):
            for source_name, target_name in layer_rename_dict.items():
                for suffix in ["bias", "weight"]:
                    source = f"{prefix}encoder.layers.{layer_id}.{source_name}.{suffix}"
                    rename_dict[source] = f"encoders.{layer_id}.{target_name}.{suffix}"
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
        return state_dict_
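A minimal usage sketch for the text encoder above, assuming the package is installed with the layout shown in this upload (it relies on the companion diffsynth/models/attention.py). The dummy token ids stand in for a tokenized, padded prompt; real weights would be loaded through the converter shown above, e.g. SDTextEncoder().load_state_dict(SDTextEncoderStateDictConverter().from_diffusers(clip_state_dict)).

import torch
from diffsynth.models.sd_text_encoder import SDTextEncoder

text_encoder = SDTextEncoder()
input_ids = torch.zeros((1, 77), dtype=torch.long)   # dummy, already padded token ids
prompt_emb = text_encoder(input_ids, clip_skip=1)    # clip_skip=2 would stop one layer earlier
print(prompt_emb.shape)                              # torch.Size([1, 77, 768])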
diffsynth/models/sd_unet.py
ADDED
The diff for this file is too large to render.
See raw diff
diffsynth/models/sd_vae_decoder.py
ADDED
@@ -0,0 +1,332 @@
import torch
from .attention import Attention
from .sd_unet import ResnetBlock, UpSampler
from .tiler import TileWorker


class VAEAttentionBlock(torch.nn.Module):

    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)

        self.transformer_blocks = torch.nn.ModuleList([
            Attention(
                inner_dim,
                num_attention_heads,
                attention_head_dim,
                bias_q=True,
                bias_kv=True,
                bias_out=True
            )
            for d in range(num_layers)
        ])

    def forward(self, hidden_states, time_emb, text_emb, res_stack):
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states)

        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
        hidden_states = hidden_states + residual

        return hidden_states, time_emb, text_emb, res_stack


class SDVAEDecoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.18215
        self.post_quant_conv = torch.nn.Conv2d(4, 4, kernel_size=1)
        self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)

        self.blocks = torch.nn.ModuleList([
            # UNetMidBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            # UpDecoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            UpSampler(512),
            # UpDecoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            UpSampler(512),
            # UpDecoderBlock2D
            ResnetBlock(512, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            UpSampler(256),
            # UpDecoderBlock2D
            ResnetBlock(256, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
        ])

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x),
            sample,
            tile_size,
            tile_stride,
            tile_device=sample.device,
            tile_dtype=sample.dtype
        )
        return hidden_states

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        # For VAE Decoder, we do not need to apply the tiler on each layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process
        sample = sample / self.scaling_factor
        hidden_states = self.post_quant_conv(sample)
        hidden_states = self.conv_in(hidden_states)
        time_emb = None
        text_emb = None
        res_stack = None

        # 2. blocks
        for i, block in enumerate(self.blocks):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states

    def state_dict_converter(self):
        return SDVAEDecoderStateDictConverter()


class SDVAEDecoderStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # architecture
        block_types = [
            'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock',
            'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
            'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
            'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
            'ResnetBlock', 'ResnetBlock', 'ResnetBlock'
        ]

        # Rename each parameter
        local_rename_dict = {
            "post_quant_conv": "post_quant_conv",
            "decoder.conv_in": "conv_in",
            "decoder.mid_block.attentions.0.group_norm": "blocks.1.norm",
            "decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q",
            "decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k",
            "decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v",
            "decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out",
            "decoder.mid_block.resnets.0.norm1": "blocks.0.norm1",
            "decoder.mid_block.resnets.0.conv1": "blocks.0.conv1",
            "decoder.mid_block.resnets.0.norm2": "blocks.0.norm2",
            "decoder.mid_block.resnets.0.conv2": "blocks.0.conv2",
            "decoder.mid_block.resnets.1.norm1": "blocks.2.norm1",
            "decoder.mid_block.resnets.1.conv1": "blocks.2.conv1",
            "decoder.mid_block.resnets.1.norm2": "blocks.2.norm2",
            "decoder.mid_block.resnets.1.conv2": "blocks.2.conv2",
            "decoder.conv_norm_out": "conv_norm_out",
            "decoder.conv_out": "conv_out",
        }
        name_list = sorted([name for name in state_dict])
        rename_dict = {}
        block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2}
        last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
        for name in name_list:
            names = name.split(".")
            name_prefix = ".".join(names[:-1])
            if name_prefix in local_rename_dict:
                rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
            elif name.startswith("decoder.up_blocks"):
                block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
                block_type_with_id = ".".join(names[:5])
                if block_type_with_id != last_block_type_with_id[block_type]:
                    block_id[block_type] += 1
                last_block_type_with_id[block_type] = block_type_with_id
                while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
                    block_id[block_type] += 1
                block_type_with_id = ".".join(names[:5])
                names = ["blocks", str(block_id[block_type])] + names[5:]
                rename_dict[name] = ".".join(names)

        # Convert state_dict
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                state_dict_[rename_dict[name]] = param
        return state_dict_

    def from_civitai(self, state_dict):
        prefix = "first_stage_model.decoder."
        rename_dict = {
            prefix + "conv_in.bias": "conv_in.bias",
            prefix + "conv_in.weight": "conv_in.weight",
            prefix + "conv_out.bias": "conv_out.bias",
            prefix + "conv_out.weight": "conv_out.weight",
            prefix + "norm_out.bias": "conv_norm_out.bias",
            prefix + "norm_out.weight": "conv_norm_out.weight",
            "first_stage_model.post_quant_conv.bias": "post_quant_conv.bias",
            "first_stage_model.post_quant_conv.weight": "post_quant_conv.weight",
        }
        # Mid-block attention and resnet parameters.
        mid_rename_dict = {
            "attn_1.k": "blocks.1.transformer_blocks.0.to_k",
            "attn_1.norm": "blocks.1.norm",
            "attn_1.proj_out": "blocks.1.transformer_blocks.0.to_out",
            "attn_1.q": "blocks.1.transformer_blocks.0.to_q",
            "attn_1.v": "blocks.1.transformer_blocks.0.to_v",
            "block_1.conv1": "blocks.0.conv1",
            "block_1.conv2": "blocks.0.conv2",
            "block_1.norm1": "blocks.0.norm1",
            "block_1.norm2": "blocks.0.norm2",
            "block_2.conv1": "blocks.2.conv1",
            "block_2.conv2": "blocks.2.conv2",
            "block_2.norm1": "blocks.2.norm1",
            "block_2.norm2": "blocks.2.norm2",
        }
        for source_name, target_name in mid_rename_dict.items():
            for suffix in ["bias", "weight"]:
                rename_dict[f"{prefix}mid.{source_name}.{suffix}"] = f"{target_name}.{suffix}"
        # Up blocks: "up.3" in the civitai layout is the first UpDecoderBlock2D group here
        # (blocks 3-6) and "up.0" is the last one (blocks 15-17, no upsampler), so the
        # per-group mapping is generated instead of being written out entry by entry.
        first_block_id = {0: 15, 1: 11, 2: 7, 3: 3}
        for up_id, base_id in first_block_id.items():
            for resnet_id in range(3):
                source = f"{prefix}up.{up_id}.block.{resnet_id}"
                target = f"blocks.{base_id + resnet_id}"
                layers = ["conv1", "conv2", "norm1", "norm2"]
                if up_id in (0, 1) and resnet_id == 0:
                    layers.append("nin_shortcut")
                for layer in layers:
                    target_layer = "conv_shortcut" if layer == "nin_shortcut" else layer
                    for suffix in ["bias", "weight"]:
                        rename_dict[f"{source}.{layer}.{suffix}"] = f"{target}.{target_layer}.{suffix}"
            if up_id != 0:
                for suffix in ["bias", "weight"]:
                    rename_dict[f"{prefix}up.{up_id}.upsample.conv.{suffix}"] = f"blocks.{base_id + 3}.conv.{suffix}"
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if "transformer_blocks" in rename_dict[name]:
                    param = param.squeeze()
                state_dict_[rename_dict[name]] = param
        return state_dict_
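A minimal usage sketch for the decoder above, assuming the companion sd_unet.py (ResnetBlock, UpSampler), attention.py, and tiler.py modules from this upload are importable. A 4x64x64 latent decodes to a 3x512x512 image; passing tiled=True routes the call through TileWorker to bound memory on large latents.

import torch
from diffsynth.models.sd_vae_decoder import SDVAEDecoder

decoder = SDVAEDecoder().eval()
latents = torch.randn(1, 4, 64, 64)       # e.g. produced by SDVAEEncoder or a diffusion loop
with torch.no_grad():
    image = decoder(latents)              # (1, 3, 512, 512); roughly in [-1, 1] once real weights are loaded
    # For large latents: decoder(latents, tiled=True, tile_size=64, tile_stride=32)
print(image.shape)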
diffsynth/models/sd_vae_encoder.py
ADDED
@@ -0,0 +1,278 @@
1 |
+
import torch
|
2 |
+
from .sd_unet import ResnetBlock, DownSampler
|
3 |
+
from .sd_vae_decoder import VAEAttentionBlock
|
4 |
+
from .tiler import TileWorker
|
5 |
+
from einops import rearrange
|
6 |
+
|
7 |
+
|
8 |
+
class SDVAEEncoder(torch.nn.Module):
|
9 |
+
def __init__(self):
|
10 |
+
super().__init__()
|
11 |
+
self.scaling_factor = 0.18215
|
12 |
+
self.quant_conv = torch.nn.Conv2d(8, 8, kernel_size=1)
|
13 |
+
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
|
14 |
+
|
15 |
+
self.blocks = torch.nn.ModuleList([
|
16 |
+
# DownEncoderBlock2D
|
17 |
+
ResnetBlock(128, 128, eps=1e-6),
|
18 |
+
ResnetBlock(128, 128, eps=1e-6),
|
19 |
+
DownSampler(128, padding=0, extra_padding=True),
|
20 |
+
# DownEncoderBlock2D
|
21 |
+
ResnetBlock(128, 256, eps=1e-6),
|
22 |
+
ResnetBlock(256, 256, eps=1e-6),
|
23 |
+
DownSampler(256, padding=0, extra_padding=True),
|
24 |
+
# DownEncoderBlock2D
|
25 |
+
ResnetBlock(256, 512, eps=1e-6),
|
26 |
+
ResnetBlock(512, 512, eps=1e-6),
|
27 |
+
DownSampler(512, padding=0, extra_padding=True),
|
28 |
+
# DownEncoderBlock2D
|
29 |
+
ResnetBlock(512, 512, eps=1e-6),
|
30 |
+
ResnetBlock(512, 512, eps=1e-6),
|
31 |
+
# UNetMidBlock2D
|
32 |
+
ResnetBlock(512, 512, eps=1e-6),
|
33 |
+
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
34 |
+
ResnetBlock(512, 512, eps=1e-6),
|
35 |
+
])
|
36 |
+
|
37 |
+
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
|
38 |
+
self.conv_act = torch.nn.SiLU()
|
39 |
+
self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1)
|
40 |
+
|
41 |
+
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
42 |
+
hidden_states = TileWorker().tiled_forward(
|
43 |
+
lambda x: self.forward(x),
|
44 |
+
sample,
|
45 |
+
tile_size,
|
46 |
+
tile_stride,
|
47 |
+
tile_device=sample.device,
|
48 |
+
tile_dtype=sample.dtype
|
49 |
+
)
|
50 |
+
return hidden_states
|
51 |
+
|
52 |
+
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
53 |
+
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
54 |
+
if tiled:
|
55 |
+
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
56 |
+
|
57 |
+
# 1. pre-process
|
58 |
+
hidden_states = self.conv_in(sample)
|
59 |
+
time_emb = None
|
60 |
+
text_emb = None
|
61 |
+
res_stack = None
|
62 |
+
|
63 |
+
# 2. blocks
|
64 |
+
for i, block in enumerate(self.blocks):
|
65 |
+
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
66 |
+
|
67 |
+
# 3. output
|
68 |
+
hidden_states = self.conv_norm_out(hidden_states)
|
69 |
+
hidden_states = self.conv_act(hidden_states)
|
70 |
+
hidden_states = self.conv_out(hidden_states)
|
71 |
+
hidden_states = self.quant_conv(hidden_states)
|
72 |
+
hidden_states = hidden_states[:, :4]
|
73 |
+
hidden_states *= self.scaling_factor
|
74 |
+
|
75 |
+
return hidden_states
|
76 |
+
|
77 |
+
def encode_video(self, sample, batch_size=8):
|
78 |
+
B = sample.shape[0]
|
79 |
+
hidden_states = []
|
80 |
+
|
81 |
+
for i in range(0, sample.shape[2], batch_size):
|
82 |
+
|
83 |
+
            j = min(i + batch_size, sample.shape[2])
            sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")

            hidden_states_batch = self(sample_batch)
            hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)

            hidden_states.append(hidden_states_batch)

        hidden_states = torch.concat(hidden_states, dim=2)
        return hidden_states

    def state_dict_converter(self):
        return SDVAEEncoderStateDictConverter()


class SDVAEEncoderStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # architecture
        block_types = [
            'ResnetBlock', 'ResnetBlock', 'DownSampler',
            'ResnetBlock', 'ResnetBlock', 'DownSampler',
            'ResnetBlock', 'ResnetBlock', 'DownSampler',
            'ResnetBlock', 'ResnetBlock',
            'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock'
        ]

        # Rename each parameter
        local_rename_dict = {
            "quant_conv": "quant_conv",
            "encoder.conv_in": "conv_in",
            "encoder.mid_block.attentions.0.group_norm": "blocks.12.norm",
            "encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q",
            "encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k",
            "encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v",
            "encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out",
            "encoder.mid_block.resnets.0.norm1": "blocks.11.norm1",
            "encoder.mid_block.resnets.0.conv1": "blocks.11.conv1",
            "encoder.mid_block.resnets.0.norm2": "blocks.11.norm2",
            "encoder.mid_block.resnets.0.conv2": "blocks.11.conv2",
            "encoder.mid_block.resnets.1.norm1": "blocks.13.norm1",
            "encoder.mid_block.resnets.1.conv1": "blocks.13.conv1",
            "encoder.mid_block.resnets.1.norm2": "blocks.13.norm2",
            "encoder.mid_block.resnets.1.conv2": "blocks.13.conv2",
            "encoder.conv_norm_out": "conv_norm_out",
            "encoder.conv_out": "conv_out",
        }
        name_list = sorted([name for name in state_dict])
        rename_dict = {}
        block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1}
        last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
        for name in name_list:
            names = name.split(".")
            name_prefix = ".".join(names[:-1])
            if name_prefix in local_rename_dict:
                rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
            elif name.startswith("encoder.down_blocks"):
                block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
                block_type_with_id = ".".join(names[:5])
                if block_type_with_id != last_block_type_with_id[block_type]:
                    block_id[block_type] += 1
                    last_block_type_with_id[block_type] = block_type_with_id
                while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
                    block_id[block_type] += 1
                block_type_with_id = ".".join(names[:5])
                names = ["blocks", str(block_id[block_type])] + names[5:]
                rename_dict[name] = ".".join(names)

        # Convert state_dict
        state_dict_ = {}
        for name, param in state_dict.items():
            if name in rename_dict:
                state_dict_[rename_dict[name]] = param
        return state_dict_

    def from_civitai(self, state_dict):
        rename_dict = {
            "first_stage_model.encoder.conv_in.bias": "conv_in.bias",
            "first_stage_model.encoder.conv_in.weight": "conv_in.weight",
            "first_stage_model.encoder.conv_out.bias": "conv_out.bias",
            "first_stage_model.encoder.conv_out.weight": "conv_out.weight",
            "first_stage_model.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
            "first_stage_model.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
            "first_stage_model.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
            "first_stage_model.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
            "first_stage_model.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
            "first_stage_model.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
            "first_stage_model.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
            "first_stage_model.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
            "first_stage_model.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
            "first_stage_model.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
            "first_stage_model.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
            "first_stage_model.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
            "first_stage_model.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
            "first_stage_model.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
            "first_stage_model.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
            "first_stage_model.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
            "first_stage_model.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
            "first_stage_model.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
            "first_stage_model.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
            "first_stage_model.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
            "first_stage_model.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
            "first_stage_model.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
            "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
            "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
            "first_stage_model.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
            "first_stage_model.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
            "first_stage_model.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
            "first_stage_model.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
            "first_stage_model.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
            "first_stage_model.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
            "first_stage_model.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
            "first_stage_model.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
            "first_stage_model.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
            "first_stage_model.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
            "first_stage_model.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
            "first_stage_model.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
            "first_stage_model.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
            "first_stage_model.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
            "first_stage_model.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
            "first_stage_model.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
            "first_stage_model.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
            "first_stage_model.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
            "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
            "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
            "first_stage_model.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
            "first_stage_model.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
            "first_stage_model.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
            "first_stage_model.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
            "first_stage_model.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
            "first_stage_model.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
            "first_stage_model.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
            "first_stage_model.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
            "first_stage_model.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
            "first_stage_model.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
            "first_stage_model.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
            "first_stage_model.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
            "first_stage_model.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
            "first_stage_model.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
            "first_stage_model.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
            "first_stage_model.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
            "first_stage_model.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
            "first_stage_model.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
            "first_stage_model.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
            "first_stage_model.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
            "first_stage_model.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
            "first_stage_model.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
            "first_stage_model.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
            "first_stage_model.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
            "first_stage_model.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
            "first_stage_model.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
            "first_stage_model.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
            "first_stage_model.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
            "first_stage_model.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
            "first_stage_model.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
            "first_stage_model.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
            "first_stage_model.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
            "first_stage_model.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
            "first_stage_model.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
            "first_stage_model.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
            "first_stage_model.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
            "first_stage_model.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
            "first_stage_model.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
            "first_stage_model.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
            "first_stage_model.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
            "first_stage_model.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
            "first_stage_model.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
            "first_stage_model.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
            "first_stage_model.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
            "first_stage_model.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
            "first_stage_model.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
            "first_stage_model.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
            "first_stage_model.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
            "first_stage_model.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
            "first_stage_model.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
            "first_stage_model.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
            "first_stage_model.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
            "first_stage_model.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
            "first_stage_model.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
            "first_stage_model.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
            "first_stage_model.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
            "first_stage_model.encoder.norm_out.bias": "conv_norm_out.bias",
            "first_stage_model.encoder.norm_out.weight": "conv_norm_out.weight",
            "first_stage_model.quant_conv.bias": "quant_conv.bias",
            "first_stage_model.quant_conv.weight": "quant_conv.weight",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if "transformer_blocks" in rename_dict[name]:
                    param = param.squeeze()
                state_dict_[rename_dict[name]] = param
        return state_dict_
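A minimal usage sketch for the converter above (not part of the upload): it assumes a local safetensors checkpoint in the original Stable Diffusion layout and that SDVAEEncoder, defined earlier in this file, can be constructed without arguments; the file path is hypothetical.

from safetensors.torch import load_file

raw_state_dict = load_file("models/sd_v1-5.safetensors")   # hypothetical checkpoint path
converter = SDVAEEncoderStateDictConverter()
converted = converter.from_civitai(raw_state_dict)          # keys renamed to the blocks.* layout
vae_encoder = SDVAEEncoder()
vae_encoder.load_state_dict(converted)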
diffsynth/models/sdxl_ipadapter.py
ADDED
@@ -0,0 +1,121 @@
from .svd_image_encoder import SVDImageEncoder
from transformers import CLIPImageProcessor
import torch


class IpAdapterXLCLIPImageEmbedder(SVDImageEncoder):
    def __init__(self):
        super().__init__(embed_dim=1664, encoder_intermediate_size=8192, projection_dim=1280, num_encoder_layers=48, num_heads=16, head_dim=104)
        self.image_processor = CLIPImageProcessor()

    def forward(self, image):
        pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
        return super().forward(pixel_values)


class IpAdapterImageProjModel(torch.nn.Module):
    def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens


class IpAdapterModule(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, hidden_states):
        ip_k = self.to_k_ip(hidden_states)
        ip_v = self.to_v_ip(hidden_states)
        return ip_k, ip_v


class SDXLIpAdapter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        shape_list = [(2048, 640)] * 4 + [(2048, 1280)] * 50 + [(2048, 640)] * 6 + [(2048, 1280)] * 10
        self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
        self.image_proj = IpAdapterImageProjModel()
        self.set_full_adapter()

    def set_full_adapter(self):
        map_list = sum([
            [(7, i) for i in range(2)],
            [(10, i) for i in range(2)],
            [(15, i) for i in range(10)],
            [(18, i) for i in range(10)],
            [(25, i) for i in range(10)],
            [(28, i) for i in range(10)],
            [(31, i) for i in range(10)],
            [(35, i) for i in range(2)],
            [(38, i) for i in range(2)],
            [(41, i) for i in range(2)],
            [(21, i) for i in range(10)],
        ], [])
        self.call_block_id = {i: j for j, i in enumerate(map_list)}

    def set_less_adapter(self):
        map_list = sum([
            [(7, i) for i in range(2)],
            [(10, i) for i in range(2)],
            [(15, i) for i in range(10)],
            [(18, i) for i in range(10)],
            [(25, i) for i in range(10)],
            [(28, i) for i in range(10)],
            [(31, i) for i in range(10)],
            [(35, i) for i in range(2)],
            [(38, i) for i in range(2)],
            [(41, i) for i in range(2)],
            [(21, i) for i in range(10)],
        ], [])
        self.call_block_id = {i: j for j, i in enumerate(map_list) if j>=34 and j<44}

    def forward(self, hidden_states, scale=1.0):
        hidden_states = self.image_proj(hidden_states)
        hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
        ip_kv_dict = {}
        for (block_id, transformer_id) in self.call_block_id:
            ipadapter_id = self.call_block_id[(block_id, transformer_id)]
            ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
            if block_id not in ip_kv_dict:
                ip_kv_dict[block_id] = {}
            ip_kv_dict[block_id][transformer_id] = {
                "ip_k": ip_k,
                "ip_v": ip_v,
                "scale": scale
            }
        return ip_kv_dict

    def state_dict_converter(self):
        return SDXLIpAdapterStateDictConverter()


class SDXLIpAdapterStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {}
        for name in state_dict["ip_adapter"]:
            names = name.split(".")
            layer_id = str(int(names[0]) // 2)
            name_ = ".".join(["ipadapter_modules"] + [layer_id] + names[1:])
            state_dict_[name_] = state_dict["ip_adapter"][name]
        for name in state_dict["image_proj"]:
            name_ = "image_proj." + name
            state_dict_[name_] = state_dict["image_proj"][name]
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
diffsynth/models/sdxl_motion.py
ADDED
@@ -0,0 +1,103 @@
from .sd_motion import TemporalBlock
import torch



class SDXLMotionModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.motion_modules = torch.nn.ModuleList([
            TemporalBlock(8, 320//8, 320, eps=1e-6),
            TemporalBlock(8, 320//8, 320, eps=1e-6),

            TemporalBlock(8, 640//8, 640, eps=1e-6),
            TemporalBlock(8, 640//8, 640, eps=1e-6),

            TemporalBlock(8, 1280//8, 1280, eps=1e-6),
            TemporalBlock(8, 1280//8, 1280, eps=1e-6),

            TemporalBlock(8, 1280//8, 1280, eps=1e-6),
            TemporalBlock(8, 1280//8, 1280, eps=1e-6),
            TemporalBlock(8, 1280//8, 1280, eps=1e-6),

            TemporalBlock(8, 640//8, 640, eps=1e-6),
            TemporalBlock(8, 640//8, 640, eps=1e-6),
            TemporalBlock(8, 640//8, 640, eps=1e-6),

            TemporalBlock(8, 320//8, 320, eps=1e-6),
            TemporalBlock(8, 320//8, 320, eps=1e-6),
            TemporalBlock(8, 320//8, 320, eps=1e-6),
        ])
        self.call_block_id = {
            0: 0,
            2: 1,
            7: 2,
            10: 3,
            15: 4,
            18: 5,
            25: 6,
            28: 7,
            31: 8,
            35: 9,
            38: 10,
            41: 11,
            44: 12,
            46: 13,
            48: 14,
        }

    def forward(self):
        pass

    def state_dict_converter(self):
        return SDMotionModelStateDictConverter()


class SDMotionModelStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "norm": "norm",
            "proj_in": "proj_in",
            "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
            "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
            "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
            "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
            "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
            "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
            "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
            "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
            "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
            "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
            "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
            "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
            "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
            "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
            "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
            "proj_out": "proj_out",
        }
        name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
        name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
        name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
        state_dict_ = {}
        last_prefix, module_id = "", -1
        for name in name_list:
            names = name.split(".")
            prefix_index = names.index("temporal_transformer") + 1
            prefix = ".".join(names[:prefix_index])
            if prefix != last_prefix:
                last_prefix = prefix
                module_id += 1
            middle_name = ".".join(names[prefix_index:-1])
            suffix = names[-1]
            if "pos_encoder" in names:
                rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
            else:
                rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
            state_dict_[rename] = state_dict[name]
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
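A usage sketch (assumption, not part of the upload): SDXLMotionModel only hosts the TemporalBlock modules, call_block_id records which UNet block index each module follows, and the converter maps AnimateDiff-style motion weights onto motion_modules.*. The checkpoint name is hypothetical, and the parameter tensors are assumed to sit at the top level of the file.

import torch

motion_model = SDXLMotionModel()
raw = torch.load("mm_sdxl_v10_beta.ckpt", map_location="cpu")  # hypothetical path
motion_model.load_state_dict(SDMotionModelStateDictConverter().from_diffusers(raw))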
diffsynth/models/sdxl_text_encoder.py
ADDED
@@ -0,0 +1,757 @@
import torch
from .sd_text_encoder import CLIPEncoderLayer


class SDXLTextEncoder(torch.nn.Module):
    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=11, encoder_intermediate_size=3072):
        super().__init__()

        # token_embedding
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # position_embeds (This is a fixed tensor)
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # encoders
        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])

        # attn_mask
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # The text encoder is different to that in Stable Diffusion 1.x.
        # It does not include final_layer_norm.

    def attention_mask(self, length):
        mask = torch.empty(length, length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, input_ids, clip_skip=1):
        embeds = self.token_embedding(input_ids) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
            if encoder_id + clip_skip == len(self.encoders):
                break
        return embeds

    def state_dict_converter(self):
        return SDXLTextEncoderStateDictConverter()


class SDXLTextEncoder2(torch.nn.Module):
    def __init__(self, embed_dim=1280, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=32, encoder_intermediate_size=5120):
        super().__init__()

        # token_embedding
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # position_embeds (This is a fixed tensor)
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # encoders
        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=20, head_dim=64, use_quick_gelu=False) for _ in range(num_encoder_layers)])

        # attn_mask
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # final_layer_norm
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)

        # text_projection
        self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)

    def attention_mask(self, length):
        mask = torch.empty(length, length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, input_ids, clip_skip=2):
        embeds = self.token_embedding(input_ids) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
            if encoder_id + clip_skip == len(self.encoders):
                hidden_states = embeds
        embeds = self.final_layer_norm(embeds)
        pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
        pooled_embeds = self.text_projection(pooled_embeds)
        return pooled_embeds, hidden_states

    def state_dict_converter(self):
        return SDXLTextEncoder2StateDictConverter()


class SDXLTextEncoderStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "text_model.embeddings.position_embedding.weight": "position_embeds",
            "text_model.final_layer_norm.weight": "final_layer_norm.weight",
            "text_model.final_layer_norm.bias": "final_layer_norm.bias"
        }
        attn_rename_dict = {
            "self_attn.q_proj": "attn.to_q",
            "self_attn.k_proj": "attn.to_k",
            "self_attn.v_proj": "attn.to_v",
            "self_attn.out_proj": "attn.to_out",
            "layer_norm1": "layer_norm1",
            "layer_norm2": "layer_norm2",
            "mlp.fc1": "fc1",
            "mlp.fc2": "fc2",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == "text_model.embeddings.position_embedding.weight":
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
            elif name.startswith("text_model.encoder.layers."):
                param = state_dict[name]
                names = name.split(".")
                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
                state_dict_[name_] = param
        return state_dict_

    def from_civitai(self, state_dict):
        rename_dict = {
            "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "position_embeds",
            "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
        }
        # The same renaming pattern applies to every one of the 11 encoder layers.
        layer_rename_dict = {
            "layer_norm1": "layer_norm1",
            "layer_norm2": "layer_norm2",
            "mlp.fc1": "fc1",
            "mlp.fc2": "fc2",
            "self_attn.k_proj": "attn.to_k",
            "self_attn.out_proj": "attn.to_out",
            "self_attn.q_proj": "attn.to_q",
            "self_attn.v_proj": "attn.to_v",
        }
        layer_prefix = "conditioner.embedders.0.transformer.text_model.encoder.layers"
        for layer_id in range(11):
            for src, dst in layer_rename_dict.items():
                for tail in ("bias", "weight"):
                    rename_dict[f"{layer_prefix}.{layer_id}.{src}.{tail}"] = f"encoders.{layer_id}.{dst}.{tail}"
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight":
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
        return state_dict_


class SDXLTextEncoder2StateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
            "text_model.embeddings.position_embedding.weight": "position_embeds",
            "text_model.final_layer_norm.weight": "final_layer_norm.weight",
            "text_model.final_layer_norm.bias": "final_layer_norm.bias",
            "text_projection.weight": "text_projection.weight"
        }
        attn_rename_dict = {
            "self_attn.q_proj": "attn.to_q",
            "self_attn.k_proj": "attn.to_k",
            "self_attn.v_proj": "attn.to_v",
            "self_attn.out_proj": "attn.to_out",
            "layer_norm1": "layer_norm1",
            "layer_norm2": "layer_norm2",
            "mlp.fc1": "fc1",
            "mlp.fc2": "fc2",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == "text_model.embeddings.position_embedding.weight":
                    param = param.reshape((1, param.shape[0], param.shape[1]))
                state_dict_[rename_dict[name]] = param
            elif name.startswith("text_model.encoder.layers."):
                param = state_dict[name]
                names = name.split(".")
                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
                state_dict_[name_] = param
        return state_dict_

    def from_civitai(self, state_dict):
        rename_dict = {
            "conditioner.embedders.1.model.ln_final.bias": "final_layer_norm.bias",
            "conditioner.embedders.1.model.ln_final.weight": "final_layer_norm.weight",
            "conditioner.embedders.1.model.positional_embedding": "position_embeds",
            "conditioner.embedders.1.model.token_embedding.weight": "token_embedding.weight",
            "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
            "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
            "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
            "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
            "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
            "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
            "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
            "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
            "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
            "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
            "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
            "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
            "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
            "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
            "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
            "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
            "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
            "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
            "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
            "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
            "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
            "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
            "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
            "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
            "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
            "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
            "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
            "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
            "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
            "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
            "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
            "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
            "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
            "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
            "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
            "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
            "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
            "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
            "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
            "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
            "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
            "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
            "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
            "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
            "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
            "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
            "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
            "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
            "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
            "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
            "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
            "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
            "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
            "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
|
411 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
|
412 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
|
413 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
|
414 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
|
415 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
|
416 |
+
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
|
417 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
|
418 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
|
419 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
|
420 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
|
421 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
|
422 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
|
423 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
|
424 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
|
425 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
|
426 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
|
427 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
|
428 |
+
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
|
429 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
|
430 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
|
431 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
|
432 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
|
433 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
|
434 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
|
435 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
|
436 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
|
437 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
|
438 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
|
439 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
|
440 |
+
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
|
441 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
|
442 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
|
443 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
|
444 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
|
445 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
|
446 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
|
447 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
|
448 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
|
449 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
|
450 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
|
451 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
|
452 |
+
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
|
453 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
|
454 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
|
455 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
|
456 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
|
457 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
|
458 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
|
459 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
|
460 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
|
461 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
|
462 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
|
463 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
|
464 |
+
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
|
465 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
|
466 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
|
467 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
|
468 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
|
469 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
|
470 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
|
471 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
|
472 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
|
473 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
|
474 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
|
475 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
|
476 |
+
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
|
477 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
|
478 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
|
479 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
|
480 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
|
481 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
|
482 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
|
483 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
|
484 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
|
485 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
|
486 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
|
487 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
|
488 |
+
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
|
489 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
|
490 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
|
491 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
|
492 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
|
493 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
|
494 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
|
495 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
|
496 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
|
497 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
|
498 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
|
499 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
|
500 |
+
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
|
501 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
|
502 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
|
503 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
504 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
505 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
|
506 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
|
507 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
|
508 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
|
509 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
|
510 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
|
511 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
|
512 |
+
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
|
513 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
|
514 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
|
515 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
|
516 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
|
517 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
|
518 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
|
519 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
|
520 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
|
521 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
|
522 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
|
523 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
|
524 |
+
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
|
525 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
|
526 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
|
527 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
|
528 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
|
529 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
|
530 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
|
531 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
|
532 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
|
533 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
|
534 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
|
535 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
|
536 |
+
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
|
537 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
|
538 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
|
539 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
|
540 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
|
541 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
|
542 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
|
543 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
|
544 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
|
545 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
|
546 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
|
547 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
|
548 |
+
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
|
549 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
|
550 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
|
551 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
|
552 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
|
553 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
|
554 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
|
555 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
|
556 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
|
557 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
|
558 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
|
559 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
|
560 |
+
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
|
561 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
|
562 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
|
563 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
|
564 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
|
565 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
|
566 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
|
567 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
|
568 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
|
569 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
|
570 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
|
571 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
|
572 |
+
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
|
573 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
|
574 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
|
575 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
|
576 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
|
577 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
|
578 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
|
579 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
|
580 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
|
581 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
|
582 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
|
583 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
|
584 |
+
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
|
585 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
|
586 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
|
587 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
|
588 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
|
589 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
|
590 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
|
591 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
|
592 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
|
593 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
|
594 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
|
595 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
|
596 |
+
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
|
597 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
|
598 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
|
599 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
|
600 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
|
601 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
|
602 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
|
603 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
|
604 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
|
605 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
|
606 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
|
607 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
|
608 |
+
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
|
609 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
|
610 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
|
611 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
|
612 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
|
613 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
|
614 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
|
615 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
|
616 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
|
617 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
|
618 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
|
619 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
|
620 |
+
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
|
621 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
|
622 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
|
623 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
|
624 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
|
625 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
|
626 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
|
627 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
|
628 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
|
629 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
|
630 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
|
631 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
|
632 |
+
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
|
633 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
|
634 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
|
635 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
636 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
637 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
|
638 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
|
639 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
|
640 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
|
641 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
|
642 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
|
643 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
|
644 |
+
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
|
645 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
|
646 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
|
647 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
|
648 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
|
649 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
|
650 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
|
651 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
|
652 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
|
653 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
|
654 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
|
655 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
|
656 |
+
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
|
657 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
|
658 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
|
659 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
|
660 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
|
661 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
|
662 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
|
663 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
|
664 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
|
665 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
|
666 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
|
667 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
|
668 |
+
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
|
669 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
|
670 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
|
671 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
672 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
673 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
|
674 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
|
675 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
|
676 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
|
677 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
|
678 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
|
679 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
|
680 |
+
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
|
681 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
|
682 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
|
683 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
684 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
685 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
|
686 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
|
687 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
|
688 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
|
689 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
|
690 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
|
691 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
|
692 |
+
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
|
693 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
|
694 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
|
695 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
696 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
697 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
|
698 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
|
699 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
|
700 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
|
701 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
|
702 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
|
703 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
|
704 |
+
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
|
705 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
|
706 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
|
707 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
708 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
709 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
|
710 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
|
711 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
|
712 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
|
713 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
|
714 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
|
715 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
|
716 |
+
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
|
717 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
|
718 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
|
719 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
720 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
721 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
|
722 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
|
723 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
|
724 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
|
725 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
|
726 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
|
727 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
|
728 |
+
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
|
729 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
|
730 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
|
731 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
732 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
733 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
|
734 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
|
735 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
|
736 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
|
737 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
|
738 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
|
739 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
|
740 |
+
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
|
741 |
+
"conditioner.embedders.1.model.text_projection": "text_projection.weight",
|
742 |
+
}
|
743 |
+
state_dict_ = {}
|
744 |
+
for name in state_dict:
|
745 |
+
if name in rename_dict:
|
746 |
+
param = state_dict[name]
|
747 |
+
if name == "conditioner.embedders.1.model.positional_embedding":
|
748 |
+
param = param.reshape((1, param.shape[0], param.shape[1]))
|
749 |
+
elif name == "conditioner.embedders.1.model.text_projection":
|
750 |
+
param = param.T
|
751 |
+
if isinstance(rename_dict[name], str):
|
752 |
+
state_dict_[rename_dict[name]] = param
|
753 |
+
else:
|
754 |
+
length = param.shape[0] // 3
|
755 |
+
for i, rename in enumerate(rename_dict[name]):
|
756 |
+
state_dict_[rename] = param[i*length: i*length+length]
|
757 |
+
return state_dict_
|
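The q/k/v handling above deserves a note: OpenCLIP checkpoints store each attention layer's input projection as a single fused in_proj tensor, and the converter slices it into three equal chunks along dim 0. A minimal sketch of that step (the 1280 hidden size matches SDXL's second text encoder; the random tensor is purely illustrative):

import torch

# Hypothetical fused in_proj weight for one resblock: (3 * hidden_dim, hidden_dim).
hidden_dim = 1280
in_proj_weight = torch.randn(3 * hidden_dim, hidden_dim)

# Same slicing as in from_civitai: three equal chunks -> to_q, to_k, to_v.
length = in_proj_weight.shape[0] // 3
to_q = in_proj_weight[0 * length: 1 * length]
to_k = in_proj_weight[1 * length: 2 * length]
to_v = in_proj_weight[2 * length: 3 * length]
assert to_q.shape == to_k.shape == to_v.shape == (hidden_dim, hidden_dim)
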
diffsynth/models/sdxl_unet.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
diffsynth/models/sdxl_vae_decoder.py
ADDED
@@ -0,0 +1,15 @@
from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter


class SDXLVAEDecoder(SDVAEDecoder):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.13025

    def state_dict_converter(self):
        return SDXLVAEDecoderStateDictConverter()


class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
    def __init__(self):
        super().__init__()
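The only behavioural difference from the inherited SD VAE decoder is the scaling factor. A short sketch of how a factor like 0.13025 is typically applied around a VAE (illustrative shapes; the actual wiring lives in the pipelines):

import torch

scaling_factor = 0.13025
latents = torch.randn(1, 4, 128, 128)                    # hypothetical SDXL latent for a 1024x1024 image
latents_for_unet = latents * scaling_factor              # after encoding: bring latents to the scale the UNet expects
latents_for_decode = latents_for_unet / scaling_factor   # before decoding: undo the scaling
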
diffsynth/models/sdxl_vae_encoder.py
ADDED
@@ -0,0 +1,15 @@
from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder


class SDXLVAEEncoder(SDVAEEncoder):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.13025

    def state_dict_converter(self):
        return SDXLVAEEncoderStateDictConverter()


class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
    def __init__(self):
        super().__init__()
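Both SDXL VAE wrappers reuse their SD parents wholesale and only override the scaling factor, so their state-dict converters are pass-throughs. A hedged usage sketch, assuming the inherited converter exposes from_diffusers like the other converters in this upload:

from diffsynth.models.sdxl_vae_encoder import SDXLVAEEncoder

encoder = SDXLVAEEncoder()
converter = encoder.state_dict_converter()
# state_dict = converter.from_diffusers(raw_state_dict)  # assumption: inherited from SDVAEEncoderStateDictConverter
# encoder.load_state_dict(state_dict)
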
diffsynth/models/svd_image_encoder.py
ADDED
@@ -0,0 +1,504 @@
import torch
from .sd_text_encoder import CLIPEncoderLayer


class CLIPVisionEmbeddings(torch.nn.Module):
    def __init__(self, embed_dim=1280, image_size=224, patch_size=14, num_channels=3):
        super().__init__()

        # class_embeds (a learnable class token prepended to the patch tokens)
        self.class_embedding = torch.nn.Parameter(torch.randn(1, 1, embed_dim))

        # patch_embeds (non-overlapping patch_size x patch_size convolution)
        self.patch_embedding = torch.nn.Conv2d(in_channels=num_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)

        # position_embeds (one embedding per patch token plus the class token)
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, (image_size // patch_size) ** 2 + 1, embed_dim))

    def forward(self, pixel_values):
        batch_size = pixel_values.shape[0]
        patch_embeds = self.patch_embedding(pixel_values)
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.repeat(batch_size, 1, 1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + self.position_embeds
        return embeddings


class SVDImageEncoder(torch.nn.Module):
    def __init__(self, embed_dim=1280, layer_norm_eps=1e-5, num_encoder_layers=32, encoder_intermediate_size=5120, projection_dim=1024, num_heads=16, head_dim=80):
        super().__init__()
        self.embeddings = CLIPVisionEmbeddings(embed_dim=embed_dim)
        self.pre_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.encoders = torch.nn.ModuleList([
            CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=num_heads, head_dim=head_dim, use_quick_gelu=False)
            for _ in range(num_encoder_layers)])
        self.post_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.visual_projection = torch.nn.Linear(embed_dim, projection_dim, bias=False)

    def forward(self, pixel_values):
        embeds = self.embeddings(pixel_values)
        embeds = self.pre_layernorm(embeds)
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds)
        # Keep only the class-token feature, then project it into the CLIP joint space.
        embeds = self.post_layernorm(embeds[:, 0, :])
        embeds = self.visual_projection(embeds)
        return embeds

    def state_dict_converter(self):
        return SVDImageEncoderStateDictConverter()


class SVDImageEncoderStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        rename_dict = {
            "vision_model.embeddings.patch_embedding.weight": "embeddings.patch_embedding.weight",
            "vision_model.embeddings.class_embedding": "embeddings.class_embedding",
            "vision_model.embeddings.position_embedding.weight": "embeddings.position_embeds",
            "vision_model.pre_layrnorm.weight": "pre_layernorm.weight",
            "vision_model.pre_layrnorm.bias": "pre_layernorm.bias",
            "vision_model.post_layernorm.weight": "post_layernorm.weight",
            "vision_model.post_layernorm.bias": "post_layernorm.bias",
            "visual_projection.weight": "visual_projection.weight"
        }
        attn_rename_dict = {
            "self_attn.q_proj": "attn.to_q",
            "self_attn.k_proj": "attn.to_k",
            "self_attn.v_proj": "attn.to_v",
            "self_attn.out_proj": "attn.to_out",
            "layer_norm1": "layer_norm1",
            "layer_norm2": "layer_norm2",
            "mlp.fc1": "fc1",
            "mlp.fc2": "fc2",
        }
        state_dict_ = {}
        for name in state_dict:
            if name in rename_dict:
                param = state_dict[name]
                if name == "vision_model.embeddings.class_embedding":
                    param = state_dict[name].view(1, 1, -1)
                elif name == "vision_model.embeddings.position_embedding.weight":
                    param = state_dict[name].unsqueeze(0)
                state_dict_[rename_dict[name]] = param
            elif name.startswith("vision_model.encoder.layers."):
                param = state_dict[name]
                names = name.split(".")
                layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
                name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
                state_dict_[name_] = param
        return state_dict_

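For reference, the default configuration above yields the following tensor shapes: a 224x224 input is cut into 14x14 patches (16x16 = 256 patch tokens plus one class token = 257), and only the class-token feature is projected at the end. A small sketch with a random input, purely illustrative:

import torch
from diffsynth.models.svd_image_encoder import SVDImageEncoder

encoder = SVDImageEncoder()
pixel_values = torch.randn(2, 3, 224, 224)   # hypothetical batch of two conditioning frames

embeds = encoder.embeddings(pixel_values)    # (2, 257, 1280): class token + 16*16 patch tokens
image_embeds = encoder(pixel_values)         # (2, 1024): class-token feature after visual_projection
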
92 |
+
|
93 |
+
def from_civitai(self, state_dict):
|
94 |
+
rename_dict = {
|
95 |
+
"conditioner.embedders.0.open_clip.model.visual.class_embedding": "embeddings.class_embedding",
|
96 |
+
"conditioner.embedders.0.open_clip.model.visual.conv1.weight": "embeddings.patch_embedding.weight",
|
97 |
+
"conditioner.embedders.0.open_clip.model.visual.ln_post.bias": "post_layernorm.bias",
|
98 |
+
"conditioner.embedders.0.open_clip.model.visual.ln_post.weight": "post_layernorm.weight",
|
99 |
+
"conditioner.embedders.0.open_clip.model.visual.ln_pre.bias": "pre_layernorm.bias",
|
100 |
+
"conditioner.embedders.0.open_clip.model.visual.ln_pre.weight": "pre_layernorm.weight",
|
101 |
+
"conditioner.embedders.0.open_clip.model.visual.positional_embedding": "embeddings.position_embeds",
|
102 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
|
103 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
|
104 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
105 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
106 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
|
107 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
|
108 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
|
109 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
|
110 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
|
111 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
|
112 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
|
113 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
|
114 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
|
115 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
|
116 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
117 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
118 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
|
119 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
|
120 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
|
121 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
|
122 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
|
123 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
|
124 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
|
125 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
|
126 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
|
127 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
|
128 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
129 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
130 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
|
131 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
|
132 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
|
133 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
|
134 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
|
135 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
|
136 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
|
137 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
|
138 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
|
139 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
|
140 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
|
141 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
|
142 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
|
143 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
|
144 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
|
145 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
|
146 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
|
147 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
|
148 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
|
149 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
|
150 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
|
151 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
|
152 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
|
153 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
|
154 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
|
155 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
|
156 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
|
157 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
|
158 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
|
159 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
|
160 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
|
161 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
|
162 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
|
163 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
|
164 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
|
165 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
|
166 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
|
167 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
|
168 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
|
169 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
|
170 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
|
171 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
|
172 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
|
173 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
|
174 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
|
175 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
|
176 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
|
177 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
|
178 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
|
179 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
|
180 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
|
181 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
|
182 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
|
183 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
|
184 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
|
185 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
|
186 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
|
187 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
|
188 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
|
189 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
|
190 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
|
191 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
|
192 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
|
193 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
|
194 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
|
195 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
|
196 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
|
197 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
|
198 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
|
199 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
|
200 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
|
201 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
|
202 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
|
203 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
|
204 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
|
205 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
|
206 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
|
207 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
|
208 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
|
209 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
|
210 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
|
211 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
|
212 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
|
213 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
|
214 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
|
215 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
|
216 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
|
217 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
|
218 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
|
219 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
|
220 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
|
221 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
|
222 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
|
223 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
|
224 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
|
225 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
|
226 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
|
227 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
|
228 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
|
229 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
|
230 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
|
231 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
|
232 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
|
233 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
|
234 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
|
235 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
|
236 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
|
237 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
|
238 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
|
239 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
|
240 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
|
241 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
|
242 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
|
243 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
|
244 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
|
245 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
|
246 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
|
247 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
|
248 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
249 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
250 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
|
251 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
|
252 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
|
253 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
|
254 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
|
255 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
|
256 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
|
257 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
|
258 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
|
259 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
|
260 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
|
261 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
|
262 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
|
263 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
|
264 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
|
265 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
|
266 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
|
267 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
|
268 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
|
269 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
|
270 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
|
271 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
|
272 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
|
273 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
|
274 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
|
275 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
|
276 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
|
277 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
|
278 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
|
279 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
|
280 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
|
281 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
|
282 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
|
283 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
|
284 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
|
285 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
|
286 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
|
287 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
|
288 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
|
289 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
|
290 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
|
291 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
|
292 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
|
293 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
|
294 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
|
295 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
|
296 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
|
297 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
|
298 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
|
299 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
|
300 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
|
301 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
|
302 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
|
303 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
|
304 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
|
305 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
|
306 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
|
307 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
|
308 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
|
309 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
|
310 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
|
311 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
|
312 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
|
313 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
|
314 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
|
315 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
|
316 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
|
317 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
|
318 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
|
319 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
|
320 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
|
321 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
|
322 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
|
323 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
|
324 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
|
325 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
|
326 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
|
327 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
|
328 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
|
329 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
|
330 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
|
331 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
|
332 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
|
333 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
|
334 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
|
335 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
|
336 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
|
337 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
|
338 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
|
339 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
|
340 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
|
341 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
|
342 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
|
343 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
|
344 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
|
345 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
|
346 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
|
347 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
|
348 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
|
349 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
|
350 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
|
351 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
|
352 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
|
353 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
|
354 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
|
355 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
|
356 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
|
357 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
|
358 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
|
359 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
|
360 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
|
361 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
|
362 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
|
363 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
|
364 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
|
365 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
|
366 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
|
367 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
|
368 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
|
369 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
|
370 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
|
371 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
|
372 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
|
373 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
|
374 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
|
375 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
|
376 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
|
377 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
|
378 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
|
379 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
|
380 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
381 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
382 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
|
383 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
|
384 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
|
385 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
|
386 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
|
387 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
|
388 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
|
389 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
|
390 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
|
391 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
|
392 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
|
393 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
|
394 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
|
395 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
|
396 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
|
397 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
|
398 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
|
399 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
|
400 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
|
401 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
|
402 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
|
403 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
|
404 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
|
405 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
|
406 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
|
407 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
|
408 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
|
409 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
|
410 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
|
411 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
|
412 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
|
413 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
|
414 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
|
415 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
|
416 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
417 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
418 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
|
419 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
|
420 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
|
421 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
|
422 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
|
423 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
|
424 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
|
425 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
|
426 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
|
427 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
|
428 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
429 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
430 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
|
431 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
|
432 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
|
433 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
|
434 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
|
435 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
|
436 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
|
437 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
|
438 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
|
439 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
|
440 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
441 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
442 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
|
443 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
|
444 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
|
445 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
|
446 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
|
447 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
|
448 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
|
449 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
|
450 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
|
451 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
|
452 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
453 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
454 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
|
455 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
|
456 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
|
457 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
|
458 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
|
459 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
|
460 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
|
461 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
|
462 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
|
463 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
|
464 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
465 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
466 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
|
467 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
|
468 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
|
469 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
|
470 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
|
471 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
|
472 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
|
473 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
|
474 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
|
475 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
|
476 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
477 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
478 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
|
479 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
|
480 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
|
481 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
|
482 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
|
483 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
|
484 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
|
485 |
+
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
|
486 |
+
"conditioner.embedders.0.open_clip.model.visual.proj": "visual_projection.weight",
|
487 |
+
}
|
488 |
+
state_dict_ = {}
for name in state_dict:
    if name in rename_dict:
        param = state_dict[name]
        if name == "conditioner.embedders.0.open_clip.model.visual.class_embedding":
            param = param.reshape((1, 1, param.shape[0]))
        elif name == "conditioner.embedders.0.open_clip.model.visual.positional_embedding":
            param = param.reshape((1, param.shape[0], param.shape[1]))
        elif name == "conditioner.embedders.0.open_clip.model.visual.proj":
            # The OpenCLIP projection is stored transposed relative to a torch.nn.Linear weight.
            param = param.T
        if isinstance(rename_dict[name], str):
            state_dict_[rename_dict[name]] = param
        else:
            # A fused q/k/v tensor maps to three target names; split it into equal thirds.
            length = param.shape[0] // 3
            for i, rename in enumerate(rename_dict[name]):
                state_dict_[rename] = param[i*length: i*length+length]
return state_dict_
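For reference, the fused OpenCLIP attention projections handled by the list-valued entries above are split into equal thirds by the loop that builds `state_dict_`. A minimal sketch of that split; the 1280-dim width is an illustrative assumption, not something taken from this diff:

import torch

embed_dim = 1280  # hypothetical ViT width, chosen only for illustration
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)  # fused q/k/v projection weight
length = in_proj_weight.shape[0] // 3
to_q = in_proj_weight[0 * length: 1 * length]
to_k = in_proj_weight[1 * length: 2 * length]
to_v = in_proj_weight[2 * length: 3 * length]
assert to_q.shape == to_k.shape == to_v.shape == (embed_dim, embed_dim)
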
diffsynth/models/svd_unet.py
ADDED
The diff for this file is too large to render.
See raw diff
diffsynth/models/svd_vae_decoder.py
ADDED
@@ -0,0 +1,577 @@
import torch
from .attention import Attention
from .sd_unet import ResnetBlock, UpSampler
from .tiler import TileWorker
from einops import rearrange, repeat


class VAEAttentionBlock(torch.nn.Module):

    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)

        self.transformer_blocks = torch.nn.ModuleList([
            Attention(
                inner_dim,
                num_attention_heads,
                attention_head_dim,
                bias_q=True,
                bias_kv=True,
                bias_out=True
            )
            for d in range(num_layers)
        ])

    def forward(self, hidden_states, time_emb, text_emb, res_stack):
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states)

        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
        hidden_states = hidden_states + residual

        return hidden_states, time_emb, text_emb, res_stack


class TemporalResnetBlock(torch.nn.Module):

    def __init__(self, in_channels, out_channels, groups=32, eps=1e-5):
        super().__init__()
        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = torch.nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
        self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.conv2 = torch.nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
        self.nonlinearity = torch.nn.SiLU()
        self.mix_factor = torch.nn.Parameter(torch.Tensor([0.5]))

    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
        x_spatial = hidden_states
        # The (3, 1, 1) kernels convolve along the frame axis only, so this branch mixes
        # information across time while leaving the spatial resolution untouched.
        x = rearrange(hidden_states, "T C H W -> 1 C T H W")
        x = self.norm1(x)
        x = self.nonlinearity(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = self.nonlinearity(x)
        x = self.conv2(x)
        x_temporal = hidden_states + x[0].permute(1, 0, 2, 3)
        # Learned gate: blend the temporal and spatial paths with alpha = sigmoid(mix_factor).
        alpha = torch.sigmoid(self.mix_factor)
        hidden_states = alpha * x_temporal + (1 - alpha) * x_spatial
        return hidden_states, time_emb, text_emb, res_stack


class SVDVAEDecoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.18215
        self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)

        self.blocks = torch.nn.ModuleList([
            # UNetMidBlock
            ResnetBlock(512, 512, eps=1e-6),
            TemporalResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            TemporalResnetBlock(512, 512, eps=1e-6),
            # UpDecoderBlock
            ResnetBlock(512, 512, eps=1e-6),
            TemporalResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            TemporalResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            TemporalResnetBlock(512, 512, eps=1e-6),
            UpSampler(512),
            # UpDecoderBlock
            ResnetBlock(512, 512, eps=1e-6),
            TemporalResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            TemporalResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            TemporalResnetBlock(512, 512, eps=1e-6),
            UpSampler(512),
            # UpDecoderBlock
            ResnetBlock(512, 256, eps=1e-6),
            TemporalResnetBlock(256, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            TemporalResnetBlock(256, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            TemporalResnetBlock(256, 256, eps=1e-6),
            UpSampler(256),
            # UpDecoderBlock
            ResnetBlock(256, 128, eps=1e-6),
            TemporalResnetBlock(128, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
            TemporalResnetBlock(128, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
            TemporalResnetBlock(128, 128, eps=1e-6),
        ])

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
        self.time_conv_out = torch.nn.Conv3d(3, 3, kernel_size=(3, 1, 1), padding=(1, 0, 0))

    def forward(self, sample):
        # 1. pre-process
        hidden_states = rearrange(sample, "C T H W -> T C H W")
        hidden_states = hidden_states / self.scaling_factor
        hidden_states = self.conv_in(hidden_states)
        time_emb, text_emb, res_stack = None, None, None

        # 2. blocks
        for i, block in enumerate(self.blocks):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        hidden_states = rearrange(hidden_states, "T C H W -> C T H W")
        hidden_states = self.time_conv_out(hidden_states)

        return hidden_states
142 |
+
|
143 |
+
|
144 |
+
def build_mask(self, data, is_bound):
|
145 |
+
_, T, H, W = data.shape
|
146 |
+
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
|
147 |
+
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
|
148 |
+
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
|
149 |
+
border_width = (T + H + W) // 6
|
150 |
+
pad = torch.ones_like(t) * border_width
|
151 |
+
mask = torch.stack([
|
152 |
+
pad if is_bound[0] else t + 1,
|
153 |
+
pad if is_bound[1] else T - t,
|
154 |
+
pad if is_bound[2] else h + 1,
|
155 |
+
pad if is_bound[3] else H - h,
|
156 |
+
pad if is_bound[4] else w + 1,
|
157 |
+
pad if is_bound[5] else W - w
|
158 |
+
]).min(dim=0).values
|
159 |
+
mask = mask.clip(1, border_width)
|
160 |
+
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
|
161 |
+
mask = rearrange(mask, "T H W -> 1 T H W")
|
162 |
+
return mask
|
163 |
+
|
164 |
+
|
165 |
+
def decode_video(
|
166 |
+
self, sample,
|
167 |
+
batch_time=8, batch_height=128, batch_width=128,
|
168 |
+
stride_time=4, stride_height=32, stride_width=32,
|
169 |
+
progress_bar=lambda x:x
|
170 |
+
):
|
171 |
+
sample = sample.permute(1, 0, 2, 3)
|
172 |
+
data_device = sample.device
|
173 |
+
computation_device = self.conv_in.weight.device
|
174 |
+
torch_dtype = sample.dtype
|
175 |
+
_, T, H, W = sample.shape
|
176 |
+
|
177 |
+
weight = torch.zeros((1, T, H*8, W*8), dtype=torch_dtype, device=data_device)
|
178 |
+
values = torch.zeros((3, T, H*8, W*8), dtype=torch_dtype, device=data_device)
|
179 |
+
|
180 |
+
# Split tasks
|
181 |
+
tasks = []
|
182 |
+
for t in range(0, T, stride_time):
|
183 |
+
for h in range(0, H, stride_height):
|
184 |
+
for w in range(0, W, stride_width):
|
185 |
+
if (t-stride_time >= 0 and t-stride_time+batch_time >= T)\
|
186 |
+
or (h-stride_height >= 0 and h-stride_height+batch_height >= H)\
|
187 |
+
or (w-stride_width >= 0 and w-stride_width+batch_width >= W):
|
188 |
+
continue
|
189 |
+
tasks.append((t, t+batch_time, h, h+batch_height, w, w+batch_width))
|
190 |
+
|
191 |
+
# Run
|
192 |
+
for tl, tr, hl, hr, wl, wr in progress_bar(tasks):
|
193 |
+
sample_batch = sample[:, tl:tr, hl:hr, wl:wr].to(computation_device)
|
194 |
+
sample_batch = self.forward(sample_batch).to(data_device)
|
195 |
+
mask = self.build_mask(sample_batch, is_bound=(tl==0, tr>=T, hl==0, hr>=H, wl==0, wr>=W))
|
196 |
+
values[:, tl:tr, hl*8:hr*8, wl*8:wr*8] += sample_batch * mask
|
197 |
+
weight[:, tl:tr, hl*8:hr*8, wl*8:wr*8] += mask
|
198 |
+
values /= weight
|
199 |
+
return values
|
200 |
+
|
201 |
+
|
202 |
+
def state_dict_converter(self):
|
203 |
+
return SVDVAEDecoderStateDictConverter()
|
204 |
+
|
205 |
+
|
206 |
+
class SVDVAEDecoderStateDictConverter:
|
207 |
+
def __init__(self):
|
208 |
+
pass
|
209 |
+
|
210 |
+
def from_diffusers(self, state_dict):
|
211 |
+
static_rename_dict = {
|
212 |
+
"decoder.conv_in": "conv_in",
|
213 |
+
"decoder.mid_block.attentions.0.group_norm": "blocks.2.norm",
|
214 |
+
"decoder.mid_block.attentions.0.to_q": "blocks.2.transformer_blocks.0.to_q",
|
215 |
+
"decoder.mid_block.attentions.0.to_k": "blocks.2.transformer_blocks.0.to_k",
|
216 |
+
"decoder.mid_block.attentions.0.to_v": "blocks.2.transformer_blocks.0.to_v",
|
217 |
+
"decoder.mid_block.attentions.0.to_out.0": "blocks.2.transformer_blocks.0.to_out",
|
218 |
+
"decoder.up_blocks.0.upsamplers.0.conv": "blocks.11.conv",
|
219 |
+
"decoder.up_blocks.1.upsamplers.0.conv": "blocks.18.conv",
|
220 |
+
"decoder.up_blocks.2.upsamplers.0.conv": "blocks.25.conv",
|
221 |
+
"decoder.conv_norm_out": "conv_norm_out",
|
222 |
+
"decoder.conv_out": "conv_out",
|
223 |
+
"decoder.time_conv_out": "time_conv_out"
|
224 |
+
}
|
225 |
+
prefix_rename_dict = {
|
226 |
+
"decoder.mid_block.resnets.0.spatial_res_block": "blocks.0",
|
227 |
+
"decoder.mid_block.resnets.0.temporal_res_block": "blocks.1",
|
228 |
+
"decoder.mid_block.resnets.0.time_mixer": "blocks.1",
|
229 |
+
"decoder.mid_block.resnets.1.spatial_res_block": "blocks.3",
|
230 |
+
"decoder.mid_block.resnets.1.temporal_res_block": "blocks.4",
|
231 |
+
"decoder.mid_block.resnets.1.time_mixer": "blocks.4",
|
232 |
+
|
233 |
+
"decoder.up_blocks.0.resnets.0.spatial_res_block": "blocks.5",
|
234 |
+
"decoder.up_blocks.0.resnets.0.temporal_res_block": "blocks.6",
|
235 |
+
"decoder.up_blocks.0.resnets.0.time_mixer": "blocks.6",
|
236 |
+
"decoder.up_blocks.0.resnets.1.spatial_res_block": "blocks.7",
|
237 |
+
"decoder.up_blocks.0.resnets.1.temporal_res_block": "blocks.8",
|
238 |
+
"decoder.up_blocks.0.resnets.1.time_mixer": "blocks.8",
|
239 |
+
"decoder.up_blocks.0.resnets.2.spatial_res_block": "blocks.9",
|
240 |
+
"decoder.up_blocks.0.resnets.2.temporal_res_block": "blocks.10",
|
241 |
+
"decoder.up_blocks.0.resnets.2.time_mixer": "blocks.10",
|
242 |
+
|
243 |
+
"decoder.up_blocks.1.resnets.0.spatial_res_block": "blocks.12",
|
244 |
+
"decoder.up_blocks.1.resnets.0.temporal_res_block": "blocks.13",
|
245 |
+
"decoder.up_blocks.1.resnets.0.time_mixer": "blocks.13",
|
246 |
+
"decoder.up_blocks.1.resnets.1.spatial_res_block": "blocks.14",
|
247 |
+
"decoder.up_blocks.1.resnets.1.temporal_res_block": "blocks.15",
|
248 |
+
"decoder.up_blocks.1.resnets.1.time_mixer": "blocks.15",
|
249 |
+
"decoder.up_blocks.1.resnets.2.spatial_res_block": "blocks.16",
|
250 |
+
"decoder.up_blocks.1.resnets.2.temporal_res_block": "blocks.17",
|
251 |
+
"decoder.up_blocks.1.resnets.2.time_mixer": "blocks.17",
|
252 |
+
|
253 |
+
"decoder.up_blocks.2.resnets.0.spatial_res_block": "blocks.19",
|
254 |
+
"decoder.up_blocks.2.resnets.0.temporal_res_block": "blocks.20",
|
255 |
+
"decoder.up_blocks.2.resnets.0.time_mixer": "blocks.20",
|
256 |
+
"decoder.up_blocks.2.resnets.1.spatial_res_block": "blocks.21",
|
257 |
+
"decoder.up_blocks.2.resnets.1.temporal_res_block": "blocks.22",
|
258 |
+
"decoder.up_blocks.2.resnets.1.time_mixer": "blocks.22",
|
259 |
+
"decoder.up_blocks.2.resnets.2.spatial_res_block": "blocks.23",
|
260 |
+
"decoder.up_blocks.2.resnets.2.temporal_res_block": "blocks.24",
|
261 |
+
"decoder.up_blocks.2.resnets.2.time_mixer": "blocks.24",
|
262 |
+
|
263 |
+
"decoder.up_blocks.3.resnets.0.spatial_res_block": "blocks.26",
|
264 |
+
"decoder.up_blocks.3.resnets.0.temporal_res_block": "blocks.27",
|
265 |
+
"decoder.up_blocks.3.resnets.0.time_mixer": "blocks.27",
|
266 |
+
"decoder.up_blocks.3.resnets.1.spatial_res_block": "blocks.28",
|
267 |
+
"decoder.up_blocks.3.resnets.1.temporal_res_block": "blocks.29",
|
268 |
+
"decoder.up_blocks.3.resnets.1.time_mixer": "blocks.29",
|
269 |
+
"decoder.up_blocks.3.resnets.2.spatial_res_block": "blocks.30",
|
270 |
+
"decoder.up_blocks.3.resnets.2.temporal_res_block": "blocks.31",
|
271 |
+
"decoder.up_blocks.3.resnets.2.time_mixer": "blocks.31",
|
272 |
+
}
|
273 |
+
suffix_rename_dict = {
|
274 |
+
"norm1.weight": "norm1.weight",
|
275 |
+
"conv1.weight": "conv1.weight",
|
276 |
+
"norm2.weight": "norm2.weight",
|
277 |
+
"conv2.weight": "conv2.weight",
|
278 |
+
"conv_shortcut.weight": "conv_shortcut.weight",
|
279 |
+
"norm1.bias": "norm1.bias",
|
280 |
+
"conv1.bias": "conv1.bias",
|
281 |
+
"norm2.bias": "norm2.bias",
|
282 |
+
"conv2.bias": "conv2.bias",
|
283 |
+
"conv_shortcut.bias": "conv_shortcut.bias",
|
284 |
+
"mix_factor": "mix_factor",
|
285 |
+
}
|
286 |
+
|
287 |
+
state_dict_ = {}
|
288 |
+
for name in static_rename_dict:
|
289 |
+
state_dict_[static_rename_dict[name] + ".weight"] = state_dict[name + ".weight"]
|
290 |
+
state_dict_[static_rename_dict[name] + ".bias"] = state_dict[name + ".bias"]
|
291 |
+
for prefix_name in prefix_rename_dict:
|
292 |
+
for suffix_name in suffix_rename_dict:
|
293 |
+
name = prefix_name + "." + suffix_name
|
294 |
+
name_ = prefix_rename_dict[prefix_name] + "." + suffix_rename_dict[suffix_name]
|
295 |
+
if name in state_dict:
|
296 |
+
state_dict_[name_] = state_dict[name]
|
297 |
+
|
298 |
+
return state_dict_
|
299 |
+
|
300 |
+
|
301 |
+
def from_civitai(self, state_dict):
|
302 |
+
rename_dict = {
|
303 |
+
"first_stage_model.decoder.conv_in.bias": "conv_in.bias",
|
304 |
+
"first_stage_model.decoder.conv_in.weight": "conv_in.weight",
|
305 |
+
"first_stage_model.decoder.conv_out.bias": "conv_out.bias",
|
306 |
+
"first_stage_model.decoder.conv_out.time_mix_conv.bias": "time_conv_out.bias",
|
307 |
+
"first_stage_model.decoder.conv_out.time_mix_conv.weight": "time_conv_out.weight",
|
308 |
+
"first_stage_model.decoder.conv_out.weight": "conv_out.weight",
|
309 |
+
"first_stage_model.decoder.mid.attn_1.k.bias": "blocks.2.transformer_blocks.0.to_k.bias",
|
310 |
+
"first_stage_model.decoder.mid.attn_1.k.weight": "blocks.2.transformer_blocks.0.to_k.weight",
|
311 |
+
"first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.2.norm.bias",
|
312 |
+
"first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.2.norm.weight",
|
313 |
+
"first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.2.transformer_blocks.0.to_out.bias",
|
314 |
+
"first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.2.transformer_blocks.0.to_out.weight",
|
315 |
+
"first_stage_model.decoder.mid.attn_1.q.bias": "blocks.2.transformer_blocks.0.to_q.bias",
|
316 |
+
"first_stage_model.decoder.mid.attn_1.q.weight": "blocks.2.transformer_blocks.0.to_q.weight",
|
317 |
+
"first_stage_model.decoder.mid.attn_1.v.bias": "blocks.2.transformer_blocks.0.to_v.bias",
|
318 |
+
"first_stage_model.decoder.mid.attn_1.v.weight": "blocks.2.transformer_blocks.0.to_v.weight",
|
319 |
+
"first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
|
320 |
+
"first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
|
321 |
+
"first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
|
322 |
+
"first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
|
323 |
+
"first_stage_model.decoder.mid.block_1.mix_factor": "blocks.1.mix_factor",
|
324 |
+
"first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
|
325 |
+
"first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
|
326 |
+
"first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
|
327 |
+
"first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
|
328 |
+
"first_stage_model.decoder.mid.block_1.time_stack.in_layers.0.bias": "blocks.1.norm1.bias",
|
329 |
+
"first_stage_model.decoder.mid.block_1.time_stack.in_layers.0.weight": "blocks.1.norm1.weight",
|
330 |
+
"first_stage_model.decoder.mid.block_1.time_stack.in_layers.2.bias": "blocks.1.conv1.bias",
|
331 |
+
"first_stage_model.decoder.mid.block_1.time_stack.in_layers.2.weight": "blocks.1.conv1.weight",
|
332 |
+
"first_stage_model.decoder.mid.block_1.time_stack.out_layers.0.bias": "blocks.1.norm2.bias",
|
333 |
+
"first_stage_model.decoder.mid.block_1.time_stack.out_layers.0.weight": "blocks.1.norm2.weight",
|
334 |
+
"first_stage_model.decoder.mid.block_1.time_stack.out_layers.3.bias": "blocks.1.conv2.bias",
|
335 |
+
"first_stage_model.decoder.mid.block_1.time_stack.out_layers.3.weight": "blocks.1.conv2.weight",
|
336 |
+
"first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.3.conv1.bias",
|
337 |
+
"first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.3.conv1.weight",
|
338 |
+
"first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.3.conv2.bias",
|
339 |
+
"first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.3.conv2.weight",
|
340 |
+
"first_stage_model.decoder.mid.block_2.mix_factor": "blocks.4.mix_factor",
|
341 |
+
"first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.3.norm1.bias",
|
342 |
+
"first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.3.norm1.weight",
|
343 |
+
"first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.3.norm2.bias",
|
344 |
+
"first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.3.norm2.weight",
|
345 |
+
"first_stage_model.decoder.mid.block_2.time_stack.in_layers.0.bias": "blocks.4.norm1.bias",
|
346 |
+
"first_stage_model.decoder.mid.block_2.time_stack.in_layers.0.weight": "blocks.4.norm1.weight",
|
347 |
+
"first_stage_model.decoder.mid.block_2.time_stack.in_layers.2.bias": "blocks.4.conv1.bias",
|
348 |
+
"first_stage_model.decoder.mid.block_2.time_stack.in_layers.2.weight": "blocks.4.conv1.weight",
|
349 |
+
"first_stage_model.decoder.mid.block_2.time_stack.out_layers.0.bias": "blocks.4.norm2.bias",
|
350 |
+
"first_stage_model.decoder.mid.block_2.time_stack.out_layers.0.weight": "blocks.4.norm2.weight",
|
351 |
+
"first_stage_model.decoder.mid.block_2.time_stack.out_layers.3.bias": "blocks.4.conv2.bias",
|
352 |
+
"first_stage_model.decoder.mid.block_2.time_stack.out_layers.3.weight": "blocks.4.conv2.weight",
|
353 |
+
"first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias",
|
354 |
+
"first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight",
|
355 |
+
"first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.26.conv1.bias",
|
356 |
+
"first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.26.conv1.weight",
|
357 |
+
"first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.26.conv2.bias",
|
358 |
+
"first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.26.conv2.weight",
|
359 |
+
"first_stage_model.decoder.up.0.block.0.mix_factor": "blocks.27.mix_factor",
|
360 |
+
"first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.26.conv_shortcut.bias",
|
361 |
+
"first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.26.conv_shortcut.weight",
|
362 |
+
"first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.26.norm1.bias",
|
363 |
+
"first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.26.norm1.weight",
|
364 |
+
"first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.26.norm2.bias",
|
365 |
+
"first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.26.norm2.weight",
|
366 |
+
"first_stage_model.decoder.up.0.block.0.time_stack.in_layers.0.bias": "blocks.27.norm1.bias",
|
367 |
+
"first_stage_model.decoder.up.0.block.0.time_stack.in_layers.0.weight": "blocks.27.norm1.weight",
|
368 |
+
"first_stage_model.decoder.up.0.block.0.time_stack.in_layers.2.bias": "blocks.27.conv1.bias",
|
369 |
+
"first_stage_model.decoder.up.0.block.0.time_stack.in_layers.2.weight": "blocks.27.conv1.weight",
|
370 |
+
"first_stage_model.decoder.up.0.block.0.time_stack.out_layers.0.bias": "blocks.27.norm2.bias",
|
371 |
+
"first_stage_model.decoder.up.0.block.0.time_stack.out_layers.0.weight": "blocks.27.norm2.weight",
|
372 |
+
"first_stage_model.decoder.up.0.block.0.time_stack.out_layers.3.bias": "blocks.27.conv2.bias",
|
373 |
+
"first_stage_model.decoder.up.0.block.0.time_stack.out_layers.3.weight": "blocks.27.conv2.weight",
|
374 |
+
"first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.28.conv1.bias",
|
375 |
+
"first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.28.conv1.weight",
|
376 |
+
"first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.28.conv2.bias",
|
377 |
+
"first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.28.conv2.weight",
|
378 |
+
"first_stage_model.decoder.up.0.block.1.mix_factor": "blocks.29.mix_factor",
|
379 |
+
"first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.28.norm1.bias",
|
380 |
+
"first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.28.norm1.weight",
|
381 |
+
"first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.28.norm2.bias",
|
382 |
+
"first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.28.norm2.weight",
|
383 |
+
"first_stage_model.decoder.up.0.block.1.time_stack.in_layers.0.bias": "blocks.29.norm1.bias",
|
384 |
+
"first_stage_model.decoder.up.0.block.1.time_stack.in_layers.0.weight": "blocks.29.norm1.weight",
|
385 |
+
"first_stage_model.decoder.up.0.block.1.time_stack.in_layers.2.bias": "blocks.29.conv1.bias",
|
386 |
+
"first_stage_model.decoder.up.0.block.1.time_stack.in_layers.2.weight": "blocks.29.conv1.weight",
|
387 |
+
"first_stage_model.decoder.up.0.block.1.time_stack.out_layers.0.bias": "blocks.29.norm2.bias",
|
388 |
+
"first_stage_model.decoder.up.0.block.1.time_stack.out_layers.0.weight": "blocks.29.norm2.weight",
|
389 |
+
"first_stage_model.decoder.up.0.block.1.time_stack.out_layers.3.bias": "blocks.29.conv2.bias",
|
390 |
+
"first_stage_model.decoder.up.0.block.1.time_stack.out_layers.3.weight": "blocks.29.conv2.weight",
|
391 |
+
"first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.30.conv1.bias",
|
392 |
+
"first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.30.conv1.weight",
|
393 |
+
"first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.30.conv2.bias",
|
394 |
+
"first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.30.conv2.weight",
|
395 |
+
"first_stage_model.decoder.up.0.block.2.mix_factor": "blocks.31.mix_factor",
|
396 |
+
"first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.30.norm1.bias",
|
397 |
+
"first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.30.norm1.weight",
|
398 |
+
"first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.30.norm2.bias",
|
399 |
+
"first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.30.norm2.weight",
|
400 |
+
"first_stage_model.decoder.up.0.block.2.time_stack.in_layers.0.bias": "blocks.31.norm1.bias",
|
401 |
+
"first_stage_model.decoder.up.0.block.2.time_stack.in_layers.0.weight": "blocks.31.norm1.weight",
|
402 |
+
"first_stage_model.decoder.up.0.block.2.time_stack.in_layers.2.bias": "blocks.31.conv1.bias",
|
403 |
+
"first_stage_model.decoder.up.0.block.2.time_stack.in_layers.2.weight": "blocks.31.conv1.weight",
|
404 |
+
"first_stage_model.decoder.up.0.block.2.time_stack.out_layers.0.bias": "blocks.31.norm2.bias",
|
405 |
+
"first_stage_model.decoder.up.0.block.2.time_stack.out_layers.0.weight": "blocks.31.norm2.weight",
|
406 |
+
"first_stage_model.decoder.up.0.block.2.time_stack.out_layers.3.bias": "blocks.31.conv2.bias",
|
407 |
+
"first_stage_model.decoder.up.0.block.2.time_stack.out_layers.3.weight": "blocks.31.conv2.weight",
|
408 |
+
"first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.19.conv1.bias",
|
409 |
+
"first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.19.conv1.weight",
|
410 |
+
"first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.19.conv2.bias",
|
411 |
+
"first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.19.conv2.weight",
|
412 |
+
"first_stage_model.decoder.up.1.block.0.mix_factor": "blocks.20.mix_factor",
|
413 |
+
"first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.19.conv_shortcut.bias",
|
414 |
+
"first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.19.conv_shortcut.weight",
|
415 |
+
"first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.19.norm1.bias",
|
416 |
+
"first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.19.norm1.weight",
|
417 |
+
"first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.19.norm2.bias",
|
418 |
+
"first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.19.norm2.weight",
|
419 |
+
"first_stage_model.decoder.up.1.block.0.time_stack.in_layers.0.bias": "blocks.20.norm1.bias",
|
420 |
+
"first_stage_model.decoder.up.1.block.0.time_stack.in_layers.0.weight": "blocks.20.norm1.weight",
|
421 |
+
"first_stage_model.decoder.up.1.block.0.time_stack.in_layers.2.bias": "blocks.20.conv1.bias",
|
422 |
+
"first_stage_model.decoder.up.1.block.0.time_stack.in_layers.2.weight": "blocks.20.conv1.weight",
|
423 |
+
"first_stage_model.decoder.up.1.block.0.time_stack.out_layers.0.bias": "blocks.20.norm2.bias",
|
424 |
+
"first_stage_model.decoder.up.1.block.0.time_stack.out_layers.0.weight": "blocks.20.norm2.weight",
|
425 |
+
"first_stage_model.decoder.up.1.block.0.time_stack.out_layers.3.bias": "blocks.20.conv2.bias",
|
426 |
+
"first_stage_model.decoder.up.1.block.0.time_stack.out_layers.3.weight": "blocks.20.conv2.weight",
|
427 |
+
"first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.21.conv1.bias",
|
428 |
+
"first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.21.conv1.weight",
|
429 |
+
"first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.21.conv2.bias",
|
430 |
+
"first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.21.conv2.weight",
|
431 |
+
"first_stage_model.decoder.up.1.block.1.mix_factor": "blocks.22.mix_factor",
|
432 |
+
"first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.21.norm1.bias",
|
433 |
+
"first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.21.norm1.weight",
|
434 |
+
"first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.21.norm2.bias",
|
435 |
+
"first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.21.norm2.weight",
|
436 |
+
"first_stage_model.decoder.up.1.block.1.time_stack.in_layers.0.bias": "blocks.22.norm1.bias",
|
437 |
+
"first_stage_model.decoder.up.1.block.1.time_stack.in_layers.0.weight": "blocks.22.norm1.weight",
|
438 |
+
"first_stage_model.decoder.up.1.block.1.time_stack.in_layers.2.bias": "blocks.22.conv1.bias",
|
439 |
+
"first_stage_model.decoder.up.1.block.1.time_stack.in_layers.2.weight": "blocks.22.conv1.weight",
|
440 |
+
"first_stage_model.decoder.up.1.block.1.time_stack.out_layers.0.bias": "blocks.22.norm2.bias",
|
441 |
+
"first_stage_model.decoder.up.1.block.1.time_stack.out_layers.0.weight": "blocks.22.norm2.weight",
|
442 |
+
"first_stage_model.decoder.up.1.block.1.time_stack.out_layers.3.bias": "blocks.22.conv2.bias",
|
443 |
+
"first_stage_model.decoder.up.1.block.1.time_stack.out_layers.3.weight": "blocks.22.conv2.weight",
|
444 |
+
"first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.23.conv1.bias",
|
445 |
+
"first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.23.conv1.weight",
|
446 |
+
"first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.23.conv2.bias",
|
447 |
+
"first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.23.conv2.weight",
|
448 |
+
"first_stage_model.decoder.up.1.block.2.mix_factor": "blocks.24.mix_factor",
|
449 |
+
"first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.23.norm1.bias",
|
450 |
+
"first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.23.norm1.weight",
|
451 |
+
"first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.23.norm2.bias",
|
452 |
+
"first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.23.norm2.weight",
|
453 |
+
"first_stage_model.decoder.up.1.block.2.time_stack.in_layers.0.bias": "blocks.24.norm1.bias",
|
454 |
+
"first_stage_model.decoder.up.1.block.2.time_stack.in_layers.0.weight": "blocks.24.norm1.weight",
|
455 |
+
"first_stage_model.decoder.up.1.block.2.time_stack.in_layers.2.bias": "blocks.24.conv1.bias",
|
456 |
+
"first_stage_model.decoder.up.1.block.2.time_stack.in_layers.2.weight": "blocks.24.conv1.weight",
|
457 |
+
"first_stage_model.decoder.up.1.block.2.time_stack.out_layers.0.bias": "blocks.24.norm2.bias",
|
458 |
+
"first_stage_model.decoder.up.1.block.2.time_stack.out_layers.0.weight": "blocks.24.norm2.weight",
|
459 |
+
"first_stage_model.decoder.up.1.block.2.time_stack.out_layers.3.bias": "blocks.24.conv2.bias",
|
460 |
+
"first_stage_model.decoder.up.1.block.2.time_stack.out_layers.3.weight": "blocks.24.conv2.weight",
|
461 |
+
"first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.25.conv.bias",
|
462 |
+
"first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.25.conv.weight",
|
463 |
+
"first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.12.conv1.bias",
|
464 |
+
"first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.12.conv1.weight",
|
465 |
+
"first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.12.conv2.bias",
|
466 |
+
"first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.12.conv2.weight",
|
467 |
+
"first_stage_model.decoder.up.2.block.0.mix_factor": "blocks.13.mix_factor",
|
468 |
+
"first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.12.norm1.bias",
|
469 |
+
"first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.12.norm1.weight",
|
470 |
+
"first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.12.norm2.bias",
|
471 |
+
"first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.12.norm2.weight",
|
472 |
+
"first_stage_model.decoder.up.2.block.0.time_stack.in_layers.0.bias": "blocks.13.norm1.bias",
|
473 |
+
"first_stage_model.decoder.up.2.block.0.time_stack.in_layers.0.weight": "blocks.13.norm1.weight",
|
474 |
+
"first_stage_model.decoder.up.2.block.0.time_stack.in_layers.2.bias": "blocks.13.conv1.bias",
|
475 |
+
"first_stage_model.decoder.up.2.block.0.time_stack.in_layers.2.weight": "blocks.13.conv1.weight",
|
476 |
+
"first_stage_model.decoder.up.2.block.0.time_stack.out_layers.0.bias": "blocks.13.norm2.bias",
|
477 |
+
"first_stage_model.decoder.up.2.block.0.time_stack.out_layers.0.weight": "blocks.13.norm2.weight",
|
478 |
+
"first_stage_model.decoder.up.2.block.0.time_stack.out_layers.3.bias": "blocks.13.conv2.bias",
|
479 |
+
"first_stage_model.decoder.up.2.block.0.time_stack.out_layers.3.weight": "blocks.13.conv2.weight",
|
480 |
+
"first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.14.conv1.bias",
|
481 |
+
"first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.14.conv1.weight",
|
482 |
+
"first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.14.conv2.bias",
|
483 |
+
"first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.14.conv2.weight",
|
484 |
+
"first_stage_model.decoder.up.2.block.1.mix_factor": "blocks.15.mix_factor",
|
485 |
+
"first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.14.norm1.bias",
|
486 |
+
"first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.14.norm1.weight",
|
487 |
+
"first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.14.norm2.bias",
|
488 |
+
"first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.14.norm2.weight",
|
489 |
+
"first_stage_model.decoder.up.2.block.1.time_stack.in_layers.0.bias": "blocks.15.norm1.bias",
|
490 |
+
"first_stage_model.decoder.up.2.block.1.time_stack.in_layers.0.weight": "blocks.15.norm1.weight",
|
491 |
+
"first_stage_model.decoder.up.2.block.1.time_stack.in_layers.2.bias": "blocks.15.conv1.bias",
|
492 |
+
"first_stage_model.decoder.up.2.block.1.time_stack.in_layers.2.weight": "blocks.15.conv1.weight",
|
493 |
+
"first_stage_model.decoder.up.2.block.1.time_stack.out_layers.0.bias": "blocks.15.norm2.bias",
|
494 |
+
"first_stage_model.decoder.up.2.block.1.time_stack.out_layers.0.weight": "blocks.15.norm2.weight",
|
495 |
+
"first_stage_model.decoder.up.2.block.1.time_stack.out_layers.3.bias": "blocks.15.conv2.bias",
|
496 |
+
"first_stage_model.decoder.up.2.block.1.time_stack.out_layers.3.weight": "blocks.15.conv2.weight",
|
497 |
+
"first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.16.conv1.bias",
|
498 |
+
"first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.16.conv1.weight",
|
499 |
+
"first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.16.conv2.bias",
|
500 |
+
"first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.16.conv2.weight",
|
501 |
+
"first_stage_model.decoder.up.2.block.2.mix_factor": "blocks.17.mix_factor",
|
502 |
+
"first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.16.norm1.bias",
|
503 |
+
"first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.16.norm1.weight",
|
504 |
+
"first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.16.norm2.bias",
|
505 |
+
"first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.16.norm2.weight",
|
506 |
+
"first_stage_model.decoder.up.2.block.2.time_stack.in_layers.0.bias": "blocks.17.norm1.bias",
|
507 |
+
"first_stage_model.decoder.up.2.block.2.time_stack.in_layers.0.weight": "blocks.17.norm1.weight",
|
508 |
+
"first_stage_model.decoder.up.2.block.2.time_stack.in_layers.2.bias": "blocks.17.conv1.bias",
|
509 |
+
"first_stage_model.decoder.up.2.block.2.time_stack.in_layers.2.weight": "blocks.17.conv1.weight",
|
510 |
+
"first_stage_model.decoder.up.2.block.2.time_stack.out_layers.0.bias": "blocks.17.norm2.bias",
|
511 |
+
"first_stage_model.decoder.up.2.block.2.time_stack.out_layers.0.weight": "blocks.17.norm2.weight",
|
512 |
+
"first_stage_model.decoder.up.2.block.2.time_stack.out_layers.3.bias": "blocks.17.conv2.bias",
|
513 |
+
"first_stage_model.decoder.up.2.block.2.time_stack.out_layers.3.weight": "blocks.17.conv2.weight",
|
514 |
+
"first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.18.conv.bias",
|
515 |
+
"first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.18.conv.weight",
|
516 |
+
"first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.5.conv1.bias",
|
517 |
+
"first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.5.conv1.weight",
|
518 |
+
"first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.5.conv2.bias",
|
519 |
+
"first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.5.conv2.weight",
|
520 |
+
"first_stage_model.decoder.up.3.block.0.mix_factor": "blocks.6.mix_factor",
|
521 |
+
"first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.5.norm1.bias",
|
522 |
+
"first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.5.norm1.weight",
|
523 |
+
"first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.5.norm2.bias",
|
524 |
+
"first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.5.norm2.weight",
|
525 |
+
"first_stage_model.decoder.up.3.block.0.time_stack.in_layers.0.bias": "blocks.6.norm1.bias",
|
526 |
+
"first_stage_model.decoder.up.3.block.0.time_stack.in_layers.0.weight": "blocks.6.norm1.weight",
|
527 |
+
"first_stage_model.decoder.up.3.block.0.time_stack.in_layers.2.bias": "blocks.6.conv1.bias",
|
528 |
+
"first_stage_model.decoder.up.3.block.0.time_stack.in_layers.2.weight": "blocks.6.conv1.weight",
|
529 |
+
"first_stage_model.decoder.up.3.block.0.time_stack.out_layers.0.bias": "blocks.6.norm2.bias",
|
530 |
+
"first_stage_model.decoder.up.3.block.0.time_stack.out_layers.0.weight": "blocks.6.norm2.weight",
|
531 |
+
"first_stage_model.decoder.up.3.block.0.time_stack.out_layers.3.bias": "blocks.6.conv2.bias",
|
532 |
+
"first_stage_model.decoder.up.3.block.0.time_stack.out_layers.3.weight": "blocks.6.conv2.weight",
|
533 |
+
"first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.7.conv1.bias",
|
534 |
+
"first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.7.conv1.weight",
|
535 |
+
"first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.7.conv2.bias",
|
536 |
+
"first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.7.conv2.weight",
|
537 |
+
"first_stage_model.decoder.up.3.block.1.mix_factor": "blocks.8.mix_factor",
|
538 |
+
"first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.7.norm1.bias",
|
539 |
+
"first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.7.norm1.weight",
|
540 |
+
"first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.7.norm2.bias",
|
541 |
+
"first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.7.norm2.weight",
|
542 |
+
"first_stage_model.decoder.up.3.block.1.time_stack.in_layers.0.bias": "blocks.8.norm1.bias",
|
543 |
+
"first_stage_model.decoder.up.3.block.1.time_stack.in_layers.0.weight": "blocks.8.norm1.weight",
|
544 |
+
"first_stage_model.decoder.up.3.block.1.time_stack.in_layers.2.bias": "blocks.8.conv1.bias",
|
545 |
+
"first_stage_model.decoder.up.3.block.1.time_stack.in_layers.2.weight": "blocks.8.conv1.weight",
|
546 |
+
"first_stage_model.decoder.up.3.block.1.time_stack.out_layers.0.bias": "blocks.8.norm2.bias",
|
547 |
+
"first_stage_model.decoder.up.3.block.1.time_stack.out_layers.0.weight": "blocks.8.norm2.weight",
|
548 |
+
"first_stage_model.decoder.up.3.block.1.time_stack.out_layers.3.bias": "blocks.8.conv2.bias",
|
549 |
+
"first_stage_model.decoder.up.3.block.1.time_stack.out_layers.3.weight": "blocks.8.conv2.weight",
|
550 |
+
"first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.9.conv1.bias",
|
551 |
+
"first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.9.conv1.weight",
|
552 |
+
"first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.9.conv2.bias",
|
553 |
+
"first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.9.conv2.weight",
|
554 |
+
"first_stage_model.decoder.up.3.block.2.mix_factor": "blocks.10.mix_factor",
|
555 |
+
"first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.9.norm1.bias",
|
556 |
+
"first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.9.norm1.weight",
|
557 |
+
"first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.9.norm2.bias",
|
558 |
+
"first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.9.norm2.weight",
|
559 |
+
"first_stage_model.decoder.up.3.block.2.time_stack.in_layers.0.bias": "blocks.10.norm1.bias",
|
560 |
+
"first_stage_model.decoder.up.3.block.2.time_stack.in_layers.0.weight": "blocks.10.norm1.weight",
|
561 |
+
"first_stage_model.decoder.up.3.block.2.time_stack.in_layers.2.bias": "blocks.10.conv1.bias",
|
562 |
+
"first_stage_model.decoder.up.3.block.2.time_stack.in_layers.2.weight": "blocks.10.conv1.weight",
|
563 |
+
"first_stage_model.decoder.up.3.block.2.time_stack.out_layers.0.bias": "blocks.10.norm2.bias",
|
564 |
+
"first_stage_model.decoder.up.3.block.2.time_stack.out_layers.0.weight": "blocks.10.norm2.weight",
|
565 |
+
"first_stage_model.decoder.up.3.block.2.time_stack.out_layers.3.bias": "blocks.10.conv2.bias",
|
566 |
+
"first_stage_model.decoder.up.3.block.2.time_stack.out_layers.3.weight": "blocks.10.conv2.weight",
|
567 |
+
"first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.11.conv.bias",
|
568 |
+
"first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.11.conv.weight",
|
569 |
+
}
|
570 |
+
state_dict_ = {}
|
571 |
+
for name in state_dict:
|
572 |
+
if name in rename_dict:
|
573 |
+
param = state_dict[name]
|
574 |
+
if "blocks.2.transformer_blocks.0" in rename_dict[name]:
|
575 |
+
param = param.squeeze()
|
576 |
+
state_dict_[rename_dict[name]] = param
|
577 |
+
return state_dict_
|
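For reference, a minimal usage sketch of the tiled decoder above (a sketch under assumptions: the import path mirrors the file layout in this diff, the weights are left random instead of being loaded through the converter, and the frame count and resolution are arbitrary). decode_video expects latents shaped (num_frames, 4, height // 8, width // 8); it splits them into overlapping time/height/width windows, decodes each window with forward, and feather-blends the overlaps with build_mask before dividing by the accumulated weight.

import torch
from diffsynth.models.svd_vae_decoder import SVDVAEDecoder

decoder = SVDVAEDecoder().eval()
latents = torch.randn(25, 4, 32, 32)  # (num_frames, 4, H // 8, W // 8)
with torch.no_grad():
    # Tiles are decoded on the decoder's device and accumulated on the latents' device.
    video = decoder.decode_video(latents)
print(video.shape)  # torch.Size([3, 25, 256, 256]): channels, frames, height, width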
diffsynth/models/svd_vae_encoder.py
ADDED
@@ -0,0 +1,138 @@
1 |
+
from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
|
2 |
+
|
3 |
+
|
4 |
+
class SVDVAEEncoder(SDVAEEncoder):
|
5 |
+
def __init__(self):
|
6 |
+
super().__init__()
|
7 |
+
self.scaling_factor = 0.13025
|
8 |
+
|
9 |
+
def state_dict_converter(self):
|
10 |
+
return SVDVAEEncoderStateDictConverter()
|
11 |
+
|
12 |
+
|
13 |
+
class SVDVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
|
14 |
+
def __init__(self):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
def from_diffusers(self, state_dict):
|
18 |
+
return super().from_diffusers(state_dict)
|
19 |
+
|
20 |
+
def from_civitai(self, state_dict):
|
21 |
+
rename_dict = {
|
22 |
+
"conditioner.embedders.3.encoder.encoder.conv_in.bias": "conv_in.bias",
|
23 |
+
"conditioner.embedders.3.encoder.encoder.conv_in.weight": "conv_in.weight",
|
24 |
+
"conditioner.embedders.3.encoder.encoder.conv_out.bias": "conv_out.bias",
|
25 |
+
"conditioner.embedders.3.encoder.encoder.conv_out.weight": "conv_out.weight",
|
26 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
|
27 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
|
28 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
|
29 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
|
30 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
|
31 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
|
32 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
|
33 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
|
34 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
|
35 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
|
36 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
|
37 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
|
38 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
|
39 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
|
40 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
|
41 |
+
"conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
|
42 |
+
"conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
|
43 |
+
"conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
|
44 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
|
45 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
|
46 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
|
47 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
|
48 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
|
49 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
|
50 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
|
51 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
|
52 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
|
53 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
|
54 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
|
55 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
|
56 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
|
57 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
|
58 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
|
59 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
|
60 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
|
61 |
+
"conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
|
62 |
+
"conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
|
63 |
+
"conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
|
64 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
|
65 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
|
66 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
|
67 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
|
68 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
|
69 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
|
70 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
|
71 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
|
72 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
|
73 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
|
74 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
|
75 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
|
76 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
|
77 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
|
78 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
|
79 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
|
80 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
|
81 |
+
"conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
|
82 |
+
"conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
|
83 |
+
"conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
|
84 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
|
85 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
|
86 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
|
87 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
|
88 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
|
89 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
|
90 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
|
91 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
|
92 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
|
93 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
|
94 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
|
95 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
|
96 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
|
97 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
|
98 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
|
99 |
+
"conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
|
100 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
|
101 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
|
102 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
|
103 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
|
104 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
|
105 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
|
106 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
|
107 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
|
108 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
|
109 |
+
"conditioner.embedders.3.encoder.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
|
110 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
|
111 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
|
112 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
|
113 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
|
114 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
|
115 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
|
116 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
|
117 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
|
118 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
|
119 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
|
120 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
|
121 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
|
122 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
|
123 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
|
124 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
|
125 |
+
"conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
|
126 |
+
"conditioner.embedders.3.encoder.encoder.norm_out.bias": "conv_norm_out.bias",
|
127 |
+
"conditioner.embedders.3.encoder.encoder.norm_out.weight": "conv_norm_out.weight",
|
128 |
+
"conditioner.embedders.3.encoder.quant_conv.bias": "quant_conv.bias",
|
129 |
+
"conditioner.embedders.3.encoder.quant_conv.weight": "quant_conv.weight",
|
130 |
+
}
|
131 |
+
state_dict_ = {}
|
132 |
+
for name in state_dict:
|
133 |
+
if name in rename_dict:
|
134 |
+
param = state_dict[name]
|
135 |
+
if "transformer_blocks" in rename_dict[name]:
|
136 |
+
param = param.squeeze()
|
137 |
+
state_dict_[rename_dict[name]] = param
|
138 |
+
return state_dict_
|
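As a rough illustration of how these state-dict converters are meant to be used (a sketch: the checkpoint filename below is hypothetical, safetensors is assumed as the loader, and strict=False is only a precaution because the converter maps exactly the keys listed in rename_dict above). from_civitai renames the original checkpoint keys to this module layout, and the attention projection weights are squeezed because the source checkpoints store them as 1x1 convolutions while the Attention block here expects linear layers.

from safetensors.torch import load_file  # any loader that returns a flat {key: tensor} dict works
from diffsynth.models.svd_vae_encoder import SVDVAEEncoder

raw_state_dict = load_file("svd_xt.safetensors")  # hypothetical checkpoint path
encoder = SVDVAEEncoder()
converted = encoder.state_dict_converter().from_civitai(raw_state_dict)
encoder.load_state_dict(converted, strict=False)  # only keys present in rename_dict are converted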
diffsynth/models/tiler.py
ADDED
@@ -0,0 +1,106 @@
1 |
+
import torch
|
2 |
+
from einops import rearrange, repeat
|
3 |
+
|
4 |
+
|
5 |
+
class TileWorker:
|
6 |
+
def __init__(self):
|
7 |
+
pass
|
8 |
+
|
9 |
+
|
10 |
+
def mask(self, height, width, border_width):
|
11 |
+
# Create a mask with shape (height, width).
|
12 |
+
# The centre area is filled with 1, and the border line is filled with values in range (0, 1].
|
13 |
+
x = torch.arange(height).repeat(width, 1).T
|
14 |
+
y = torch.arange(width).repeat(height, 1)
|
15 |
+
mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
|
16 |
+
mask = (mask / border_width).clip(0, 1)
|
17 |
+
return mask
|
18 |
+
|
19 |
+
|
20 |
+
def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
|
21 |
+
# Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
|
22 |
+
batch_size, channel, _, _ = model_input.shape
|
23 |
+
model_input = model_input.to(device=tile_device, dtype=tile_dtype)
|
24 |
+
unfold_operator = torch.nn.Unfold(
|
25 |
+
kernel_size=(tile_size, tile_size),
|
26 |
+
stride=(tile_stride, tile_stride)
|
27 |
+
)
|
28 |
+
model_input = unfold_operator(model_input)
|
29 |
+
model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
|
30 |
+
|
31 |
+
return model_input
|
32 |
+
|
33 |
+
|
34 |
+
def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
|
35 |
+
# Call y=forward_fn(x) for each tile
|
36 |
+
tile_num = model_input.shape[-1]
|
37 |
+
model_output_stack = []
|
38 |
+
|
39 |
+
for tile_id in range(0, tile_num, tile_batch_size):
|
40 |
+
|
41 |
+
# process input
|
42 |
+
tile_id_ = min(tile_id + tile_batch_size, tile_num)
|
43 |
+
x = model_input[:, :, :, :, tile_id: tile_id_]
|
44 |
+
x = x.to(device=inference_device, dtype=inference_dtype)
|
45 |
+
x = rearrange(x, "b c h w n -> (n b) c h w")
|
46 |
+
|
47 |
+
# process output
|
48 |
+
y = forward_fn(x)
|
49 |
+
y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
|
50 |
+
y = y.to(device=tile_device, dtype=tile_dtype)
|
51 |
+
model_output_stack.append(y)
|
52 |
+
|
53 |
+
model_output = torch.concat(model_output_stack, dim=-1)
|
54 |
+
return model_output
|
55 |
+
|
56 |
+
|
57 |
+
def io_scale(self, model_output, tile_size):
|
58 |
+
# Determine the size modification that happened in forward_fn
|
59 |
+
# We only consider the same scale on height and width.
|
60 |
+
io_scale = model_output.shape[2] / tile_size
|
61 |
+
return io_scale
|
62 |
+
|
63 |
+
|
64 |
+
def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
|
65 |
+
# The inverse of tile: fold the tiles back and blend overlaps using the mask
|
66 |
+
mask = self.mask(tile_size, tile_size, border_width)
|
67 |
+
mask = mask.to(device=tile_device, dtype=tile_dtype)
|
68 |
+
mask = rearrange(mask, "h w -> 1 1 h w 1")
|
69 |
+
model_output = model_output * mask
|
70 |
+
|
71 |
+
fold_operator = torch.nn.Fold(
|
72 |
+
output_size=(height, width),
|
73 |
+
kernel_size=(tile_size, tile_size),
|
74 |
+
stride=(tile_stride, tile_stride)
|
75 |
+
)
|
76 |
+
mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
|
77 |
+
model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
|
78 |
+
model_output = fold_operator(model_output) / fold_operator(mask)
|
79 |
+
|
80 |
+
return model_output
|
81 |
+
|
82 |
+
|
83 |
+
def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
|
84 |
+
# Prepare
|
85 |
+
inference_device, inference_dtype = model_input.device, model_input.dtype
|
86 |
+
height, width = model_input.shape[2], model_input.shape[3]
|
87 |
+
border_width = int(tile_stride*0.5) if border_width is None else border_width
|
88 |
+
|
89 |
+
# tile
|
90 |
+
model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
|
91 |
+
|
92 |
+
# inference
|
93 |
+
model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
|
94 |
+
|
95 |
+
# resize
|
96 |
+
io_scale = self.io_scale(model_output, tile_size)
|
97 |
+
height, width = int(height*io_scale), int(width*io_scale)
|
98 |
+
tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
|
99 |
+
border_width = int(border_width*io_scale)
|
100 |
+
|
101 |
+
# untile
|
102 |
+
model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
|
103 |
+
|
104 |
+
# Done!
|
105 |
+
model_output = model_output.to(device=inference_device, dtype=inference_dtype)
|
106 |
+
return model_output
|
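A minimal sketch of how TileWorker.tiled_forward wraps a spatial module (the 3x3 convolution is just a stand-in for a heavier forward such as a VAE decoder; all names below follow the code above): the input is unfolded into overlapping tiles, forward_fn runs per batch of tiles, and overlaps are feather-blended with the mask before being folded back to full resolution.

import torch
from diffsynth.models.tiler import TileWorker

conv = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)  # stand-in forward_fn; keeps spatial size, so io_scale == 1
x = torch.randn(1, 4, 256, 256)

y = TileWorker().tiled_forward(
    forward_fn=lambda tile: conv(tile),
    model_input=x,
    tile_size=64,    # each tile is 64x64
    tile_stride=32,  # 50% overlap; border_width defaults to half the stride
)
print(y.shape)  # torch.Size([1, 4, 256, 256])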
diffsynth/pipelines/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
from .stable_diffusion import SDImagePipeline
|
2 |
+
from .stable_diffusion_xl import SDXLImagePipeline
|
3 |
+
from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner
|
4 |
+
from .stable_diffusion_xl_video import SDXLVideoPipeline
|
5 |
+
from .stable_video_diffusion import SVDVideoPipeline
|
6 |
+
from .hunyuan_dit import HunyuanDiTImagePipeline
|
diffsynth/pipelines/dancer.py
ADDED
@@ -0,0 +1,174 @@
import torch
from ..models import SDUNet, SDMotionModel, SDXLUNet, SDXLMotionModel
from ..models.sd_unet import PushBlock, PopBlock
from ..controlnets import MultiControlNetManager


def lets_dance(
    unet: SDUNet,
    motion_modules: SDMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample = None,
    timestep = None,
    encoder_hidden_states = None,
    ipadapter_kwargs_list = {},
    controlnet_frames = None,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
    cross_frame_attention = False,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    device = "cuda",
    vram_limit_level = 0,
):
    # 1. ControlNet
    # This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
    # I leave it here because I intend to do something interesting on the ControlNets.
    controlnet_insert_block_id = 30
    if controlnet is not None and controlnet_frames is not None:
        res_stacks = []
        # process controlnet frames with batch
        for batch_id in range(0, sample.shape[0], controlnet_batch_size):
            batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
            res_stack = controlnet(
                sample[batch_id: batch_id_],
                timestep,
                encoder_hidden_states[batch_id: batch_id_],
                controlnet_frames[:, batch_id: batch_id_],
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
            )
            if vram_limit_level >= 1:
                res_stack = [res.cpu() for res in res_stack]
            res_stacks.append(res_stack)
        # concat the residual
        additional_res_stack = []
        for i in range(len(res_stacks[0])):
            res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
            additional_res_stack.append(res)
    else:
        additional_res_stack = None

    # 2. time
    time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
    time_emb = unet.time_embedding(time_emb)

    # 3. pre-process
    height, width = sample.shape[2], sample.shape[3]
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states
    res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states]

    # 4. blocks
    for block_id, block in enumerate(unet.blocks):
        # 4.1 UNet
        if isinstance(block, PushBlock):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
            if vram_limit_level>=1:
                res_stack[-1] = res_stack[-1].cpu()
        elif isinstance(block, PopBlock):
            if vram_limit_level>=1:
                res_stack[-1] = res_stack[-1].to(device)
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
        else:
            hidden_states_input = hidden_states
            hidden_states_output = []
            for batch_id in range(0, sample.shape[0], unet_batch_size):
                batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
                hidden_states, _, _, _ = block(
                    hidden_states_input[batch_id: batch_id_],
                    time_emb,
                    text_emb[batch_id: batch_id_],
                    res_stack,
                    cross_frame_attention=cross_frame_attention,
                    ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
                )
                hidden_states_output.append(hidden_states)
            hidden_states = torch.concat(hidden_states_output, dim=0)
        # 4.2 AnimateDiff
        if motion_modules is not None:
            if block_id in motion_modules.call_block_id:
                motion_module_id = motion_modules.call_block_id[block_id]
                hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
                    hidden_states, time_emb, text_emb, res_stack,
                    batch_size=1
                )
        # 4.3 ControlNet
        if block_id == controlnet_insert_block_id and additional_res_stack is not None:
            hidden_states += additional_res_stack.pop().to(device)
            if vram_limit_level>=1:
                res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)]
            else:
                res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]

    # 5. output
    hidden_states = unet.conv_norm_out(hidden_states)
    hidden_states = unet.conv_act(hidden_states)
    hidden_states = unet.conv_out(hidden_states)

    return hidden_states


def lets_dance_xl(
    unet: SDXLUNet,
    motion_modules: SDXLMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample = None,
    add_time_id = None,
    add_text_embeds = None,
    timestep = None,
    encoder_hidden_states = None,
    ipadapter_kwargs_list = {},
    controlnet_frames = None,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
    cross_frame_attention = False,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    device = "cuda",
    vram_limit_level = 0,
):
    # 2. time
    t_emb = unet.time_proj(timestep[None]).to(sample.dtype)
    t_emb = unet.time_embedding(t_emb)

    time_embeds = unet.add_time_proj(add_time_id)
    time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
    add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
    add_embeds = add_embeds.to(sample.dtype)
    add_embeds = unet.add_time_embedding(add_embeds)

    time_emb = t_emb + add_embeds

    # 3. pre-process
    height, width = sample.shape[2], sample.shape[3]
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states
    res_stack = [hidden_states]

    # 4. blocks
    for block_id, block in enumerate(unet.blocks):
        hidden_states, time_emb, text_emb, res_stack = block(
            hidden_states, time_emb, text_emb, res_stack,
            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
            ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {})
        )
        # 4.2 AnimateDiff
        if motion_modules is not None:
            if block_id in motion_modules.call_block_id:
                motion_module_id = motion_modules.call_block_id[block_id]
                hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
                    hidden_states, time_emb, text_emb, res_stack,
                    batch_size=1
                )

    # 5. output
    hidden_states = unet.conv_norm_out(hidden_states)
    hidden_states = unet.conv_act(hidden_states)
    hidden_states = unet.conv_out(hidden_states)

    return hidden_states
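The image and video pipelines below call these helpers once per denoising step. A condensed sketch of the call pattern follows; `unet`, `latents`, `timestep` and `prompt_emb` are assumed to be prepared by a pipeline such as SDImagePipeline.

# Sketch only, under the assumptions stated above.
noise_pred = lets_dance(
    unet,
    motion_modules=None,             # or an SDMotionModel for AnimateDiff
    controlnet=None,                 # or a MultiControlNetManager
    sample=latents,                  # e.g. shape (num_frames, 4, H//8, W//8)
    timestep=timestep,
    encoder_hidden_states=prompt_emb,
    device="cuda",
    vram_limit_level=0,              # >=1 offloads residuals to CPU between blocks
)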
diffsynth/pipelines/hunyuan_dit.py
ADDED
@@ -0,0 +1,298 @@
from ..models.hunyuan_dit import HunyuanDiT
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models import ModelManager
from ..prompts import HunyuanDiTPrompter
from ..schedulers import EnhancedDDIMScheduler
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np


class ImageSizeManager:
    def __init__(self):
        pass


    def _to_tuple(self, x):
        if isinstance(x, int):
            return x, x
        else:
            return x


    def get_fill_resize_and_crop(self, src, tgt):
        th, tw = self._to_tuple(tgt)
        h, w = self._to_tuple(src)

        tr = th / tw  # base resolution
        r = h / w  # target resolution

        # resize
        if r > tr:
            resize_height = th
            resize_width = int(round(th / h * w))
        else:
            resize_width = tw
            resize_height = int(round(tw / w * h))  # resize the target resolution down according to the base resolution

        crop_top = int(round((th - resize_height) / 2.0))
        crop_left = int(round((tw - resize_width) / 2.0))

        return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)


    def get_meshgrid(self, start, *args):
        if len(args) == 0:
            # start is grid_size
            num = self._to_tuple(start)
            start = (0, 0)
            stop = num
        elif len(args) == 1:
            # start is start, args[0] is stop, step is 1
            start = self._to_tuple(start)
            stop = self._to_tuple(args[0])
            num = (stop[0] - start[0], stop[1] - start[1])
        elif len(args) == 2:
            # start is start, args[0] is stop, args[1] is num
            start = self._to_tuple(start)  # top-left corner, e.g. 12,0
            stop = self._to_tuple(args[0])  # bottom-right corner, e.g. 20,32
            num = self._to_tuple(args[1])  # target size, e.g. 32,124
        else:
            raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

        grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)  # e.g. 32 samples between 12 and 20, 124 samples between 0 and 32
        grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
        grid = np.meshgrid(grid_w, grid_h)  # here w goes first
        grid = np.stack(grid, axis=0)  # [2, W, H]
        return grid


    def get_2d_rotary_pos_embed(self, embed_dim, start, *args, use_real=True):
        grid = self.get_meshgrid(start, *args)  # [2, H, w]
        grid = grid.reshape([2, 1, *grid.shape[1:]])  # a sampling grid whose resolution matches the target resolution
        pos_embed = self.get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
        return pos_embed


    def get_2d_rotary_pos_embed_from_grid(self, embed_dim, grid, use_real=False):
        assert embed_dim % 4 == 0

        # use half of dimensions to encode grid_h
        emb_h = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real)  # (H*W, D/4)
        emb_w = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real)  # (H*W, D/4)

        if use_real:
            cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D/2)
            sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D/2)
            return cos, sin
        else:
            emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
            return emb


    def get_1d_rotary_pos_embed(self, dim: int, pos, theta: float = 10000.0, use_real=False):
        if isinstance(pos, int):
            pos = np.arange(pos)
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
        t = torch.from_numpy(pos).to(freqs.device)  # type: ignore # [S]
        freqs = torch.outer(t, freqs).float()  # type: ignore # [S, D/2]
        if use_real:
            freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
            freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
            return freqs_cos, freqs_sin
        else:
            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64 # [S, D/2]
            return freqs_cis


    def calc_rope(self, height, width):
        patch_size = 2
        head_size = 88
        th = height // 8 // patch_size
        tw = width // 8 // patch_size
        base_size = 512 // 8 // patch_size
        start, stop = self.get_fill_resize_and_crop((th, tw), base_size)
        sub_args = [start, stop, (th, tw)]
        rope = self.get_2d_rotary_pos_embed(head_size, *sub_args)
        return rope


class HunyuanDiTImagePipeline(torch.nn.Module):

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__()
        self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
        self.prompter = HunyuanDiTPrompter()
        self.device = device
        self.torch_dtype = torch_dtype
        self.image_size_manager = ImageSizeManager()
        # models
        self.text_encoder: HunyuanDiTCLIPTextEncoder = None
        self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
        self.dit: HunyuanDiT = None
        self.vae_decoder: SDXLVAEDecoder = None
        self.vae_encoder: SDXLVAEEncoder = None


    def fetch_main_models(self, model_manager: ModelManager):
        self.text_encoder = model_manager.hunyuan_dit_clip_text_encoder
        self.text_encoder_t5 = model_manager.hunyuan_dit_t5_text_encoder
        self.dit = model_manager.hunyuan_dit
        self.vae_decoder = model_manager.vae_decoder
        self.vae_encoder = model_manager.vae_encoder


    def fetch_prompter(self, model_manager: ModelManager):
        self.prompter.load_from_model_manager(model_manager)


    @staticmethod
    def from_model_manager(model_manager: ModelManager):
        pipe = HunyuanDiTImagePipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipe.fetch_main_models(model_manager)
        pipe.fetch_prompter(model_manager)
        return pipe


    def preprocess_image(self, image):
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image


    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        image = image.cpu().permute(1, 2, 0).numpy()
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
        return image


    def prepare_extra_input(self, height=1024, width=1024, tiled=False, tile_size=64, tile_stride=32, batch_size=1):
        if tiled:
            height, width = tile_size * 16, tile_size * 16
        image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
        freqs_cis_img = self.image_size_manager.calc_rope(height, width)
        image_meta_size = torch.stack([image_meta_size] * batch_size)
        return {
            "size_emb": image_meta_size,
            "freq_cis_img": (freqs_cis_img[0].to(dtype=self.torch_dtype, device=self.device), freqs_cis_img[1].to(dtype=self.torch_dtype, device=self.device)),
            "tiled": tiled,
            "tile_size": tile_size,
            "tile_stride": tile_stride
        }


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        clip_skip_2=1,
        input_image=None,
        reference_images=[],
        reference_strengths=[0.4],
        denoising_strength=1.0,
        height=1024,
        width=1024,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
        if input_image is not None:
            image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
            latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = noise.clone()

        # Prepare reference latents
        reference_latents = []
        for reference_image in reference_images:
            reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
            reference_latents.append(self.vae_encoder(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype))

        # Encode prompts
        prompt_emb_posi, attention_mask_posi, prompt_emb_t5_posi, attention_mask_t5_posi = self.prompter.encode_prompt(
            self.text_encoder,
            self.text_encoder_t5,
            prompt,
            clip_skip=clip_skip,
            clip_skip_2=clip_skip_2,
            positive=True,
            device=self.device
        )
        if cfg_scale != 1.0:
            prompt_emb_nega, attention_mask_nega, prompt_emb_t5_nega, attention_mask_t5_nega = self.prompter.encode_prompt(
                self.text_encoder,
                self.text_encoder_t5,
                negative_prompt,
                clip_skip=clip_skip,
                clip_skip_2=clip_skip_2,
                positive=False,
                device=self.device
            )

        # Prepare positional id
        extra_input = self.prepare_extra_input(height, width, tiled, tile_size)

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)

            # In-context reference
            for reference_latents_, reference_strength in zip(reference_latents, reference_strengths):
                if progress_id < num_inference_steps * reference_strength:
                    noisy_reference_latents = self.scheduler.add_noise(reference_latents_, noise, self.scheduler.timesteps[progress_id])
                    self.dit(
                        noisy_reference_latents,
                        prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
                        timestep,
                        **extra_input,
                        to_cache=True
                    )
            # Positive side
            noise_pred_posi = self.dit(
                latents,
                prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
                timestep,
                **extra_input,
            )
            if cfg_scale != 1.0:
                # Negative side
                noise_pred_nega = self.dit(
                    latents,
                    prompt_emb_nega, prompt_emb_t5_nega, attention_mask_nega, attention_mask_t5_nega,
                    timestep,
                    **extra_input
                )
                # Classifier-free guidance
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi

            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

        return image
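A minimal usage sketch for the pipeline above. The checkpoint paths are placeholders, and `load_models` is assumed to accept a list of local weight files, as in the other pipelines in this commit.

import torch
from diffsynth import ModelManager, HunyuanDiTImagePipeline

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models([
    "models/HunyuanDiT/clip_text_encoder.bin",   # placeholder paths
    "models/HunyuanDiT/mt5_text_encoder.bin",
    "models/HunyuanDiT/hunyuan_dit.bin",
    "models/HunyuanDiT/sdxl_vae.bin",
])
pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
image = pipe(prompt="a cute dog in the garden", num_inference_steps=20)
image.save("dog.png")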
diffsynth/pipelines/stable_diffusion.py
ADDED
@@ -0,0 +1,167 @@
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompts import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
from .dancer import lets_dance
from typing import List
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np


class SDImagePipeline(torch.nn.Module):

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__()
        self.scheduler = EnhancedDDIMScheduler()
        self.prompter = SDPrompter()
        self.device = device
        self.torch_dtype = torch_dtype
        # models
        self.text_encoder: SDTextEncoder = None
        self.unet: SDUNet = None
        self.vae_decoder: SDVAEDecoder = None
        self.vae_encoder: SDVAEEncoder = None
        self.controlnet: MultiControlNetManager = None
        self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
        self.ipadapter: SDIpAdapter = None


    def fetch_main_models(self, model_manager: ModelManager):
        self.text_encoder = model_manager.text_encoder
        self.unet = model_manager.unet
        self.vae_decoder = model_manager.vae_decoder
        self.vae_encoder = model_manager.vae_encoder


    def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
        controlnet_units = []
        for config in controlnet_config_units:
            controlnet_unit = ControlNetUnit(
                Annotator(config.processor_id),
                model_manager.get_model_with_model_path(config.model_path),
                config.scale
            )
            controlnet_units.append(controlnet_unit)
        self.controlnet = MultiControlNetManager(controlnet_units)


    def fetch_ipadapter(self, model_manager: ModelManager):
        if "ipadapter" in model_manager.model:
            self.ipadapter = model_manager.ipadapter
        if "ipadapter_image_encoder" in model_manager.model:
            self.ipadapter_image_encoder = model_manager.ipadapter_image_encoder


    def fetch_prompter(self, model_manager: ModelManager):
        self.prompter.load_from_model_manager(model_manager)


    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
        pipe = SDImagePipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipe.fetch_main_models(model_manager)
        pipe.fetch_prompter(model_manager)
        pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
        pipe.fetch_ipadapter(model_manager)
        return pipe


    def preprocess_image(self, image):
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image


    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        image = image.cpu().permute(1, 2, 0).numpy()
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
        return image


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        input_image=None,
        ipadapter_images=None,
        ipadapter_scale=1.0,
        controlnet_image=None,
        denoising_strength=1.0,
        height=512,
        width=512,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        if input_image is not None:
            image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
            latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)

        # Encode prompts
        prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True)
        prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False)

        # IP-Adapter
        if ipadapter_images is not None:
            ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
            ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
            ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
        else:
            ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}

        # Prepare ControlNets
        if controlnet_image is not None:
            controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
            controlnet_image = controlnet_image.unsqueeze(1)

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = torch.IntTensor((timestep,))[0].to(self.device)

            # Classifier-free guidance
            noise_pred_posi = lets_dance(
                self.unet, motion_modules=None, controlnet=self.controlnet,
                sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
                ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
                device=self.device, vram_limit_level=0
            )
            noise_pred_nega = lets_dance(
                self.unet, motion_modules=None, controlnet=self.controlnet,
                sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
                ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
                device=self.device, vram_limit_level=0
            )
            noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

            # DDIM
            latents = self.scheduler.step(noise_pred, timestep, latents)

            # UI
            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

        return image
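A short usage sketch for SDImagePipeline; the checkpoint path is a placeholder, and `load_models` is assumed to behave as in the runner further below.

import torch
from diffsynth import ModelManager, SDImagePipeline

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models(["models/stable_diffusion/v1-5-pruned-emaonly.safetensors"])  # placeholder path
pipe = SDImagePipeline.from_model_manager(model_manager)

torch.manual_seed(0)
image = pipe(
    prompt="a photo of a cat, highly detailed",
    negative_prompt="lowres, blurry",
    height=512, width=512, num_inference_steps=20, cfg_scale=7.5,
)
image.save("cat.png")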
diffsynth/pipelines/stable_diffusion_video.py
ADDED
@@ -0,0 +1,356 @@
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDMotionModel
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompts import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
from ..data import VideoData, save_frames, save_video
from .dancer import lets_dance
from ..processors.sequencial_processor import SequencialProcessor
from typing import List
import torch, os, json
from tqdm import tqdm
from PIL import Image
import numpy as np


def lets_dance_with_long_video(
    unet: SDUNet,
    motion_modules: SDMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample = None,
    timestep = None,
    encoder_hidden_states = None,
    controlnet_frames = None,
    animatediff_batch_size = 16,
    animatediff_stride = 8,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
    cross_frame_attention = False,
    device = "cuda",
    vram_limit_level = 0,
):
    num_frames = sample.shape[0]
    hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)]

    for batch_id in range(0, num_frames, animatediff_stride):
        batch_id_ = min(batch_id + animatediff_batch_size, num_frames)

        # process this batch
        hidden_states_batch = lets_dance(
            unet, motion_modules, controlnet,
            sample[batch_id: batch_id_].to(device),
            timestep,
            encoder_hidden_states[batch_id: batch_id_].to(device),
            controlnet_frames=controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
            unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
            cross_frame_attention=cross_frame_attention,
            device=device, vram_limit_level=vram_limit_level
        ).cpu()

        # update hidden_states
        for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
            bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1 + 1e-2) / 2), 1e-2)
            hidden_states, num = hidden_states_output[i]
            hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
            hidden_states_output[i] = (hidden_states, num + bias)

        if batch_id_ == num_frames:
            break

    # output
    hidden_states = torch.stack([h for h, _ in hidden_states_output])
    return hidden_states


class SDVideoPipeline(torch.nn.Module):

    def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
        super().__init__()
        self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
        self.prompter = SDPrompter()
        self.device = device
        self.torch_dtype = torch_dtype
        # models
        self.text_encoder: SDTextEncoder = None
        self.unet: SDUNet = None
        self.vae_decoder: SDVAEDecoder = None
        self.vae_encoder: SDVAEEncoder = None
        self.controlnet: MultiControlNetManager = None
        self.motion_modules: SDMotionModel = None


    def fetch_main_models(self, model_manager: ModelManager):
        self.text_encoder = model_manager.text_encoder
        self.unet = model_manager.unet
        self.vae_decoder = model_manager.vae_decoder
        self.vae_encoder = model_manager.vae_encoder


    def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
        controlnet_units = []
        for config in controlnet_config_units:
            controlnet_unit = ControlNetUnit(
                Annotator(config.processor_id),
                model_manager.get_model_with_model_path(config.model_path),
                config.scale
            )
            controlnet_units.append(controlnet_unit)
        self.controlnet = MultiControlNetManager(controlnet_units)


    def fetch_motion_modules(self, model_manager: ModelManager):
        if "motion_modules" in model_manager.model:
            self.motion_modules = model_manager.motion_modules


    def fetch_prompter(self, model_manager: ModelManager):
        self.prompter.load_from_model_manager(model_manager)


    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
        pipe = SDVideoPipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
            use_animatediff="motion_modules" in model_manager.model
        )
        pipe.fetch_main_models(model_manager)
        pipe.fetch_motion_modules(model_manager)
        pipe.fetch_prompter(model_manager)
        pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
        return pipe


    def preprocess_image(self, image):
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image


    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        image = image.cpu().permute(1, 2, 0).numpy()
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
        return image


    def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
        images = [
            self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            for frame_id in range(latents.shape[0])
        ]
        return images


    def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
        latents = []
        for image in processed_images:
            image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
            latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
            latents.append(latent)
        latents = torch.concat(latents, dim=0)
        return latents


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        num_frames=None,
        input_frames=None,
        controlnet_frames=None,
        denoising_strength=1.0,
        height=512,
        width=512,
        num_inference_steps=20,
        animatediff_batch_size = 16,
        animatediff_stride = 8,
        unet_batch_size = 1,
        controlnet_batch_size = 1,
        cross_frame_attention = False,
        smoother=None,
        smoother_progress_ids=[],
        vram_limit_level=0,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        if self.motion_modules is None:
            noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
        else:
            noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
        if input_frames is None or denoising_strength == 1.0:
            latents = noise
        else:
            latents = self.encode_images(input_frames)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])

        # Encode prompts
        prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True).cpu()
        prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False).cpu()
        prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1)
        prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1)

        # Prepare ControlNets
        if controlnet_frames is not None:
            if isinstance(controlnet_frames[0], list):
                controlnet_frames_ = []
                for processor_id in range(len(controlnet_frames)):
                    controlnet_frames_.append(
                        torch.stack([
                            self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
                            for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
                        ], dim=1)
                    )
                controlnet_frames = torch.concat(controlnet_frames_, dim=0)
            else:
                controlnet_frames = torch.stack([
                    self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
                    for controlnet_frame in progress_bar_cmd(controlnet_frames)
                ], dim=1)

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = torch.IntTensor((timestep,))[0].to(self.device)

            # Classifier-free guidance
            noise_pred_posi = lets_dance_with_long_video(
                self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
                sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
                animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
                unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
                cross_frame_attention=cross_frame_attention,
                device=self.device, vram_limit_level=vram_limit_level
            )
            noise_pred_nega = lets_dance_with_long_video(
                self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
                sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
                animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
                unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
                cross_frame_attention=cross_frame_attention,
                device=self.device, vram_limit_level=vram_limit_level
            )
            noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

            # DDIM and smoother
            if smoother is not None and progress_id in smoother_progress_ids:
                rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
                rendered_frames = self.decode_images(rendered_frames)
                rendered_frames = smoother(rendered_frames, original_frames=input_frames)
                target_latents = self.encode_images(rendered_frames)
                noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
            latents = self.scheduler.step(noise_pred, timestep, latents)

            # UI
            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        output_frames = self.decode_images(latents)

        # Post-process
        if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
            output_frames = smoother(output_frames, original_frames=input_frames)

        return output_frames


class SDVideoPipelineRunner:
    def __init__(self, in_streamlit=False):
        self.in_streamlit = in_streamlit


    def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
        # Load models
        model_manager = ModelManager(torch_dtype=torch.float16, device=device)
        model_manager.load_textual_inversions(textual_inversion_folder)
        model_manager.load_models(model_list, lora_alphas=lora_alphas)
        pipe = SDVideoPipeline.from_model_manager(
            model_manager,
            [
                ControlNetConfigUnit(
                    processor_id=unit["processor_id"],
                    model_path=unit["model_path"],
                    scale=unit["scale"]
                ) for unit in controlnet_units
            ]
        )
        return model_manager, pipe


    def load_smoother(self, model_manager, smoother_configs):
        smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
        return smoother


    def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
        torch.manual_seed(seed)
        if self.in_streamlit:
            import streamlit as st
            progress_bar_st = st.progress(0.0)
            output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
            progress_bar_st.progress(1.0)
        else:
            output_video = pipe(**pipeline_inputs, smoother=smoother)
        model_manager.to("cpu")
        return output_video


    def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
        video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
        if start_frame_id is None:
            start_frame_id = 0
        if end_frame_id is None:
            end_frame_id = len(video)
        frames = [video[i] for i in range(start_frame_id, end_frame_id)]
        return frames


    def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
        pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
        pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
        pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
        if len(data["controlnet_frames"]) > 0:
            pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
        return pipeline_inputs


    def save_output(self, video, output_folder, fps, config):
        os.makedirs(output_folder, exist_ok=True)
        save_frames(video, os.path.join(output_folder, "frames"))
        save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
        config["pipeline"]["pipeline_inputs"]["input_frames"] = []
        config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
        with open(os.path.join(output_folder, "config.json"), 'w') as file:
            json.dump(config, file, indent=4)


    def run(self, config):
        if self.in_streamlit:
            import streamlit as st
        if self.in_streamlit: st.markdown("Loading videos ...")
        config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
        if self.in_streamlit: st.markdown("Loading videos ... done!")
        if self.in_streamlit: st.markdown("Loading models ...")
        model_manager, pipe = self.load_pipeline(**config["models"])
        if self.in_streamlit: st.markdown("Loading models ... done!")
        if "smoother_configs" in config:
            if self.in_streamlit: st.markdown("Loading smoother ...")
            smoother = self.load_smoother(model_manager, config["smoother_configs"])
            if self.in_streamlit: st.markdown("Loading smoother ... done!")
        else:
            smoother = None
        if self.in_streamlit: st.markdown("Synthesizing videos ...")
        output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
        if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
        if self.in_streamlit: st.markdown("Saving videos ...")
        self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
        if self.in_streamlit: st.markdown("Saving videos ... done!")
        if self.in_streamlit: st.markdown("Finished!")
        video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb')
        if self.in_streamlit: st.video(video_file.read())
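The runner is driven by a nested config dict whose structure follows directly from run(), load_pipeline(), load_video() and save_output() above. A minimal sketch with placeholder paths and values:

# Shape of the config consumed by SDVideoPipelineRunner.run(); all values below are placeholders.
config = {
    "models": {
        "model_list": ["models/stable_diffusion/v1-5.safetensors", "models/AnimateDiff/mm_sd_v15_v2.ckpt"],
        "textual_inversion_folder": "models/textual_inversion",
        "device": "cuda",
        "lora_alphas": [],
        "controlnet_units": [],  # e.g. [{"processor_id": "tile", "model_path": "...", "scale": 0.5}]
    },
    "data": {
        "input_frames": {"video_file": "input.mp4", "image_folder": None, "height": 512, "width": 512,
                         "start_frame_id": 0, "end_frame_id": 16},
        "controlnet_frames": [],
        "output_folder": "output",
        "fps": 25,
    },
    "pipeline": {
        "seed": 0,
        "pipeline_inputs": {"prompt": "best quality", "negative_prompt": "worst quality",
                            "num_inference_steps": 20, "cfg_scale": 7.5},
    },
}
SDVideoPipelineRunner(in_streamlit=False).run(config)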
diffsynth/pipelines/stable_diffusion_xl.py
ADDED
@@ -0,0 +1,175 @@
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
# TODO: SDXL ControlNet
from ..prompts import SDXLPrompter
from ..schedulers import EnhancedDDIMScheduler
from .dancer import lets_dance_xl
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np


class SDXLImagePipeline(torch.nn.Module):

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__()
        self.scheduler = EnhancedDDIMScheduler()
        self.prompter = SDXLPrompter()
        self.device = device
        self.torch_dtype = torch_dtype
        # models
        self.text_encoder: SDXLTextEncoder = None
        self.text_encoder_2: SDXLTextEncoder2 = None
        self.unet: SDXLUNet = None
        self.vae_decoder: SDXLVAEDecoder = None
        self.vae_encoder: SDXLVAEEncoder = None
        self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
        self.ipadapter: SDXLIpAdapter = None
        # TODO: SDXL ControlNet

    def fetch_main_models(self, model_manager: ModelManager):
        self.text_encoder = model_manager.text_encoder
        self.text_encoder_2 = model_manager.text_encoder_2
        self.unet = model_manager.unet
        self.vae_decoder = model_manager.vae_decoder
        self.vae_encoder = model_manager.vae_encoder


    def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
        # TODO: SDXL ControlNet
        pass


    def fetch_ipadapter(self, model_manager: ModelManager):
        if "ipadapter_xl" in model_manager.model:
            self.ipadapter = model_manager.ipadapter_xl
        if "ipadapter_xl_image_encoder" in model_manager.model:
            self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder


    def fetch_prompter(self, model_manager: ModelManager):
        self.prompter.load_from_model_manager(model_manager)


    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
        pipe = SDXLImagePipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
        )
        pipe.fetch_main_models(model_manager)
        pipe.fetch_prompter(model_manager)
        pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
        pipe.fetch_ipadapter(model_manager)
        return pipe


    def preprocess_image(self, image):
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image


    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        image = image.cpu().permute(1, 2, 0).numpy()
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
        return image


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        clip_skip_2=2,
        input_image=None,
        ipadapter_images=None,
        ipadapter_scale=1.0,
        controlnet_image=None,
        denoising_strength=1.0,
        height=1024,
        width=1024,
        num_inference_steps=20,
        tiled=False,
        tile_size=64,
        tile_stride=32,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        if input_image is not None:
            image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
            latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
            noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)

        # Encode prompts
        add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
            self.text_encoder,
            self.text_encoder_2,
            prompt,
            clip_skip=clip_skip, clip_skip_2=clip_skip_2,
            device=self.device,
            positive=True,
        )
        if cfg_scale != 1.0:
            add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
                self.text_encoder,
                self.text_encoder_2,
                negative_prompt,
                clip_skip=clip_skip, clip_skip_2=clip_skip_2,
                device=self.device,
                positive=False,
            )

        # Prepare positional id
        add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)

        # IP-Adapter
        if ipadapter_images is not None:
            ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
            ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
            ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
        else:
            ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = torch.IntTensor((timestep,))[0].to(self.device)

            # Classifier-free guidance
            noise_pred_posi = lets_dance_xl(
                self.unet,
                sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi,
                add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
                ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
            )
            if cfg_scale != 1.0:
                noise_pred_nega = lets_dance_xl(
                    self.unet,
                    sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega,
                    add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
                    ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
                )
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi

            latents = self.scheduler.step(noise_pred, timestep, latents)

            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

        return image
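A brief usage sketch for SDXLImagePipeline, analogous to the SD example above; the checkpoint path is a placeholder.

import torch
from diffsynth import ModelManager, SDXLImagePipeline

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models(["models/stable_diffusion_xl/sd_xl_base_1.0.safetensors"])  # placeholder path
pipe = SDXLImagePipeline.from_model_manager(model_manager)
image = pipe(prompt="an astronaut riding a horse, cinematic lighting", height=1024, width=1024, num_inference_steps=30)
image.save("astronaut.png")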
diffsynth/pipelines/stable_diffusion_xl_video.py
ADDED
@@ -0,0 +1,190 @@
+from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLMotionModel
+from .dancer import lets_dance_xl
+# TODO: SDXL ControlNet
+from ..prompts import SDXLPrompter
+from ..schedulers import EnhancedDDIMScheduler
+import torch
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+
+
+class SDXLVideoPipeline(torch.nn.Module):
+
+    def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
+        super().__init__()
+        self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
+        self.prompter = SDXLPrompter()
+        self.device = device
+        self.torch_dtype = torch_dtype
+        # models
+        self.text_encoder: SDXLTextEncoder = None
+        self.text_encoder_2: SDXLTextEncoder2 = None
+        self.unet: SDXLUNet = None
+        self.vae_decoder: SDXLVAEDecoder = None
+        self.vae_encoder: SDXLVAEEncoder = None
+        # TODO: SDXL ControlNet
+        self.motion_modules: SDXLMotionModel = None
+
+
+    def fetch_main_models(self, model_manager: ModelManager):
+        self.text_encoder = model_manager.text_encoder
+        self.text_encoder_2 = model_manager.text_encoder_2
+        self.unet = model_manager.unet
+        self.vae_decoder = model_manager.vae_decoder
+        self.vae_encoder = model_manager.vae_encoder
+
+
+    def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
+        # TODO: SDXL ControlNet
+        pass
+
+
+    def fetch_motion_modules(self, model_manager: ModelManager):
+        if "motion_modules_xl" in model_manager.model:
+            self.motion_modules = model_manager.motion_modules_xl
+
+
+    def fetch_prompter(self, model_manager: ModelManager):
+        self.prompter.load_from_model_manager(model_manager)
+
+
+    @staticmethod
+    def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
+        pipe = SDXLVideoPipeline(
+            device=model_manager.device,
+            torch_dtype=model_manager.torch_dtype,
+            use_animatediff="motion_modules_xl" in model_manager.model
+        )
+        pipe.fetch_main_models(model_manager)
+        pipe.fetch_motion_modules(model_manager)
+        pipe.fetch_prompter(model_manager)
+        pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
+        return pipe
+
+
+    def preprocess_image(self, image):
+        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
+        return image
+
+
+    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
+        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
+        image = image.cpu().permute(1, 2, 0).numpy()
+        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
+        return image
+
+
+    def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
+        images = [
+            self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+            for frame_id in range(latents.shape[0])
+        ]
+        return images
+
+
+    def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
+        latents = []
+        for image in processed_images:
+            image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+            latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
+            latents.append(latent)
+        latents = torch.concat(latents, dim=0)
+        return latents
+
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt,
+        negative_prompt="",
+        cfg_scale=7.5,
+        clip_skip=1,
+        clip_skip_2=2,
+        num_frames=None,
+        input_frames=None,
+        controlnet_frames=None,
+        denoising_strength=1.0,
+        height=512,
+        width=512,
+        num_inference_steps=20,
+        animatediff_batch_size = 16,
+        animatediff_stride = 8,
+        unet_batch_size = 1,
+        controlnet_batch_size = 1,
+        cross_frame_attention = False,
+        smoother=None,
+        smoother_progress_ids=[],
+        vram_limit_level=0,
+        progress_bar_cmd=tqdm,
+        progress_bar_st=None,
+    ):
+        # Prepare scheduler
+        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+        # Prepare latent tensors
+        if self.motion_modules is None:
+            noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
+        else:
+            noise = torch.randn((num_frames, 4, height//8, width//8), device="cuda", dtype=self.torch_dtype)
+        if input_frames is None or denoising_strength == 1.0:
+            latents = noise
+        else:
+            latents = self.encode_images(input_frames)
+            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+
+        # Encode prompts
+        add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
+            self.text_encoder,
+            self.text_encoder_2,
+            prompt,
+            clip_skip=clip_skip, clip_skip_2=clip_skip_2,
+            device=self.device,
+            positive=True,
+        )
+        if cfg_scale != 1.0:
+            add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
+                self.text_encoder,
+                self.text_encoder_2,
+                negative_prompt,
+                clip_skip=clip_skip, clip_skip_2=clip_skip_2,
+                device=self.device,
+                positive=False,
+            )
+
+        # Prepare positional id
+        add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
+
+        # Denoise
+        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+            timestep = torch.IntTensor((timestep,))[0].to(self.device)
+
+            # Classifier-free guidance
+            noise_pred_posi = lets_dance_xl(
+                self.unet, motion_modules=self.motion_modules, controlnet=None,
+                sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
+                timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
+                cross_frame_attention=cross_frame_attention,
+                device=self.device, vram_limit_level=vram_limit_level
+            )
+            if cfg_scale != 1.0:
+                noise_pred_nega = lets_dance_xl(
+                    self.unet, motion_modules=self.motion_modules, controlnet=None,
+                    sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
+                    timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
+                    cross_frame_attention=cross_frame_attention,
+                    device=self.device, vram_limit_level=vram_limit_level
+                )
+                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+            else:
+                noise_pred = noise_pred_posi
+
+            latents = self.scheduler.step(noise_pred, timestep, latents)
+
+            if progress_bar_st is not None:
+                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+        # Decode image
+        image = self.decode_images(latents.to(torch.float32))
+
+        return image
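A usage sketch for the pipeline above. It assumes `model_manager` is a ModelManager that already holds the SDXL text encoders, UNet, VAE and (optionally) `motion_modules_xl`; how the weights are loaded is outside this file, the top-level re-export is assumed from diffsynth/__init__.py, and the prompt and settings are only examples:

from diffsynth import ModelManager, SDXLVideoPipeline  # top-level re-exports assumed

model_manager = ...  # placeholder: a ModelManager with SDXL (and AnimateDiff-XL) weights already loaded
pipe = SDXLVideoPipeline.from_model_manager(model_manager)
frames = pipe(
    prompt="a corgi running on the beach, best quality",
    negative_prompt="worst quality, blurry",
    cfg_scale=7.5,
    num_frames=16,
    height=1024, width=1024,
    num_inference_steps=20,
)
frames[0].save("frame_0.png")  # decode_images returns a list of PIL images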
diffsynth/pipelines/stable_video_diffusion.py
ADDED
@@ -0,0 +1,307 @@
+from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder
+from ..schedulers import ContinuousODEScheduler
+import torch
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+from einops import rearrange, repeat
+
+
+
+class SVDVideoPipeline(torch.nn.Module):
+
+    def __init__(self, device="cuda", torch_dtype=torch.float16):
+        super().__init__()
+        self.scheduler = ContinuousODEScheduler()
+        self.device = device
+        self.torch_dtype = torch_dtype
+        # models
+        self.image_encoder: SVDImageEncoder = None
+        self.unet: SVDUNet = None
+        self.vae_encoder: SVDVAEEncoder = None
+        self.vae_decoder: SVDVAEDecoder = None
+
+
+    def fetch_main_models(self, model_manager: ModelManager):
+        self.image_encoder = model_manager.image_encoder
+        self.unet = model_manager.unet
+        self.vae_encoder = model_manager.vae_encoder
+        self.vae_decoder = model_manager.vae_decoder
+
+
+    @staticmethod
+    def from_model_manager(model_manager: ModelManager, **kwargs):
+        pipe = SVDVideoPipeline(device=model_manager.device, torch_dtype=model_manager.torch_dtype)
+        pipe.fetch_main_models(model_manager)
+        return pipe
+
+
+    def preprocess_image(self, image):
+        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
+        return image
+
+
+    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
+        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
+        image = image.cpu().permute(1, 2, 0).numpy()
+        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
+        return image
+
+
+    def encode_image_with_clip(self, image):
+        image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+        image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224))
+        image = (image + 1.0) / 2.0
+        mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+        std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+        image = (image - mean) / std
+        image_emb = self.image_encoder(image)
+        return image_emb
+
+
+    def encode_image_with_vae(self, image, noise_aug_strength):
+        image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+        noise = torch.randn(image.shape, device="cpu", dtype=self.torch_dtype).to(self.device)
+        image = image + noise_aug_strength * noise
+        image_emb = self.vae_encoder(image) / self.vae_encoder.scaling_factor
+        return image_emb
+
+
+    def encode_video_with_vae(self, video):
+        video = torch.concat([self.preprocess_image(frame) for frame in video], dim=0)
+        video = rearrange(video, "T C H W -> 1 C T H W")
+        video = video.to(device=self.device, dtype=self.torch_dtype)
+        latents = self.vae_encoder.encode_video(video)
+        latents = rearrange(latents[0], "C T H W -> T C H W")
+        return latents
+
+
+    def tensor2video(self, frames):
+        frames = rearrange(frames, "C T H W -> T H W C")
+        frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
+        frames = [Image.fromarray(frame) for frame in frames]
+        return frames
+
+
+    def calculate_noise_pred(
+        self,
+        latents,
+        timestep,
+        add_time_id,
+        cfg_scales,
+        image_emb_vae_posi, image_emb_clip_posi,
+        image_emb_vae_nega, image_emb_clip_nega
+    ):
+        # Positive side
+        noise_pred_posi = self.unet(
+            torch.cat([latents, image_emb_vae_posi], dim=1),
+            timestep, image_emb_clip_posi, add_time_id
+        )
+        # Negative side
+        noise_pred_nega = self.unet(
+            torch.cat([latents, image_emb_vae_nega], dim=1),
+            timestep, image_emb_clip_nega, add_time_id
+        )
+
+        # Classifier-free guidance
+        noise_pred = noise_pred_nega + cfg_scales * (noise_pred_posi - noise_pred_nega)
+
+        return noise_pred
+
+
+    def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
+        if post_normalize:
+            mean, std = latents.mean(), latents.std()
+            latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
+        latents = latents * contrast_enhance_scale
+        return latents
+
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        input_image=None,
+        input_video=None,
+        mask_frames=[],
+        mask_frame_ids=[],
+        min_cfg_scale=1.0,
+        max_cfg_scale=3.0,
+        denoising_strength=1.0,
+        num_frames=25,
+        height=576,
+        width=1024,
+        fps=7,
+        motion_bucket_id=127,
+        noise_aug_strength=0.02,
+        num_inference_steps=20,
+        post_normalize=True,
+        contrast_enhance_scale=1.2,
+        progress_bar_cmd=tqdm,
+        progress_bar_st=None,
+    ):
+        # Prepare scheduler
+        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
+
+        # Prepare latent tensors
+        noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).to(self.device)
+        if denoising_strength == 1.0:
+            latents = noise.clone()
+        else:
+            latents = self.encode_video_with_vae(input_video)
+            latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
+
+        # Prepare mask frames
+        if len(mask_frames) > 0:
+            mask_latents = self.encode_video_with_vae(mask_frames)
+
+        # Encode image
+        image_emb_clip_posi = self.encode_image_with_clip(input_image)
+        image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi)
+        image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength), "B C H W -> (B T) C H W", T=num_frames)
+        image_emb_vae_nega = torch.zeros_like(image_emb_vae_posi)
+
+        # Prepare classifier-free guidance
+        cfg_scales = torch.linspace(min_cfg_scale, max_cfg_scale, num_frames)
+        cfg_scales = cfg_scales.reshape(num_frames, 1, 1, 1).to(device=self.device, dtype=self.torch_dtype)
+
+        # Prepare positional id
+        add_time_id = torch.tensor([[fps-1, motion_bucket_id, noise_aug_strength]], device=self.device)
+
+        # Denoise
+        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+
+            # Mask frames
+            for frame_id, mask_frame_id in enumerate(mask_frame_ids):
+                latents[mask_frame_id] = self.scheduler.add_noise(mask_latents[frame_id], noise[mask_frame_id], timestep)
+
+            # Fetch model output
+            noise_pred = self.calculate_noise_pred(
+                latents, timestep, add_time_id, cfg_scales,
+                image_emb_vae_posi, image_emb_clip_posi, image_emb_vae_nega, image_emb_clip_nega
+            )
+
+            # Forward Euler
+            latents = self.scheduler.step(noise_pred, timestep, latents)
+
+            # Update progress bar
+            if progress_bar_st is not None:
+                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+        # Decode image
+        latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
+        video = self.vae_decoder.decode_video(latents, progress_bar=progress_bar_cmd)
+        video = self.tensor2video(video)
+
+        return video
+
+
+
+class SVDCLIPImageProcessor:
+    def __init__(self):
+        pass
+
+    def resize_with_antialiasing(self, input, size, interpolation="bicubic", align_corners=True):
+        h, w = input.shape[-2:]
+        factors = (h / size[0], w / size[1])
+
+        # First, we have to determine sigma
+        # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
+        sigmas = (
+            max((factors[0] - 1.0) / 2.0, 0.001),
+            max((factors[1] - 1.0) / 2.0, 0.001),
+        )
+
+        # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
+        # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
+        # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
+        ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
+
+        # Make sure it is odd
+        if (ks[0] % 2) == 0:
+            ks = ks[0] + 1, ks[1]
+
+        if (ks[1] % 2) == 0:
+            ks = ks[0], ks[1] + 1
+
+        input = self._gaussian_blur2d(input, ks, sigmas)
+
+        output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
+        return output
+
+
+    def _compute_padding(self, kernel_size):
+        """Compute padding tuple."""
+        # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
+        # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
+        if len(kernel_size) < 2:
+            raise AssertionError(kernel_size)
+        computed = [k - 1 for k in kernel_size]
+
+        # for even kernels we need to do asymmetric padding :(
+        out_padding = 2 * len(kernel_size) * [0]
+
+        for i in range(len(kernel_size)):
+            computed_tmp = computed[-(i + 1)]
+
+            pad_front = computed_tmp // 2
+            pad_rear = computed_tmp - pad_front
+
+            out_padding[2 * i + 0] = pad_front
+            out_padding[2 * i + 1] = pad_rear
+
+        return out_padding
+
+
+    def _filter2d(self, input, kernel):
+        # prepare kernel
+        b, c, h, w = input.shape
+        tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
+
+        tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
+
+        height, width = tmp_kernel.shape[-2:]
+
+        padding_shape: list[int] = self._compute_padding([height, width])
+        input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
+
+        # kernel and input tensor reshape to align element-wise or batch-wise params
+        tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
+        input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
+
+        # convolve the tensor with the kernel.
+        output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
+
+        out = output.view(b, c, h, w)
+        return out
+
+
+    def _gaussian(self, window_size: int, sigma):
+        if isinstance(sigma, float):
+            sigma = torch.tensor([[sigma]])
+
+        batch_size = sigma.shape[0]
+
+        x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
+
+        if window_size % 2 == 0:
+            x = x + 0.5
+
+        gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
+
+        return gauss / gauss.sum(-1, keepdim=True)
+
+
+    def _gaussian_blur2d(self, input, kernel_size, sigma):
+        if isinstance(sigma, tuple):
+            sigma = torch.tensor([sigma], dtype=input.dtype)
+        else:
+            sigma = sigma.to(dtype=input.dtype)
+
+        ky, kx = int(kernel_size[0]), int(kernel_size[1])
+        bs = sigma.shape[0]
+        kernel_x = self._gaussian(kx, sigma[:, 1].view(bs, 1))
+        kernel_y = self._gaussian(ky, sigma[:, 0].view(bs, 1))
+        out_x = self._filter2d(input, kernel_x[..., None, :])
+        out = self._filter2d(out_x, kernel_y[..., None])
+
+        return out
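One design choice worth noting in SVDVideoPipeline.__call__: the guidance scale is not a single number but a per-frame ramp from min_cfg_scale to max_cfg_scale, reshaped so it broadcasts over the frame-stacked latents. A small self-contained illustration of that broadcasting (random tensors stand in for the UNet outputs):

import torch

# Default settings: 25 frames, guidance ramping from 1.0 to 3.0.
num_frames, min_cfg_scale, max_cfg_scale = 25, 1.0, 3.0
cfg_scales = torch.linspace(min_cfg_scale, max_cfg_scale, num_frames).reshape(num_frames, 1, 1, 1)

# With height=576, width=1024 the latents are (25, 4, 72, 128); the (25, 1, 1, 1)
# scale broadcasts so early frames get weaker guidance than later ones.
noise_pred_posi = torch.randn(num_frames, 4, 72, 128)
noise_pred_nega = torch.randn(num_frames, 4, 72, 128)
noise_pred = noise_pred_nega + cfg_scales * (noise_pred_posi - noise_pred_nega)
print(noise_pred.shape)  # torch.Size([25, 4, 72, 128])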
diffsynth/processors/FastBlend.py
ADDED
@@ -0,0 +1,142 @@
+from PIL import Image
+import cupy as cp
+import numpy as np
+from tqdm import tqdm
+from ..extensions.FastBlend.patch_match import PyramidPatchMatcher
+from ..extensions.FastBlend.runners.fast import TableManager
+from .base import VideoProcessor
+
+
+class FastBlendSmoother(VideoProcessor):
+    def __init__(
+        self,
+        inference_mode="fast", batch_size=8, window_size=60,
+        minimum_patch_size=5, threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0, initialize="identity", tracking_window_size=0
+    ):
+        self.inference_mode = inference_mode
+        self.batch_size = batch_size
+        self.window_size = window_size
+        self.ebsynth_config = {
+            "minimum_patch_size": minimum_patch_size,
+            "threads_per_block": threads_per_block,
+            "num_iter": num_iter,
+            "gpu_id": gpu_id,
+            "guide_weight": guide_weight,
+            "initialize": initialize,
+            "tracking_window_size": tracking_window_size
+        }
+
+    @staticmethod
+    def from_model_manager(model_manager, **kwargs):
+        # TODO: fetch GPU ID from model_manager
+        return FastBlendSmoother(**kwargs)
+
+    def inference_fast(self, frames_guide, frames_style):
+        table_manager = TableManager()
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frames_style[0].shape[0],
+            image_width=frames_style[0].shape[1],
+            channel=3,
+            **self.ebsynth_config
+        )
+        # left part
+        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, self.batch_size, desc="Fast Mode Step 1/4")
+        table_l = table_manager.remapping_table_to_blending_table(table_l)
+        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 2/4")
+        # right part
+        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, self.batch_size, desc="Fast Mode Step 3/4")
+        table_r = table_manager.remapping_table_to_blending_table(table_r)
+        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 4/4")[::-1]
+        # merge
+        frames = []
+        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
+            weight_m = -1
+            weight = weight_l + weight_m + weight_r
+            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
+            frames.append(frame)
+        frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
+        frames = [Image.fromarray(frame) for frame in frames]
+        return frames
+
+    def inference_balanced(self, frames_guide, frames_style):
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frames_style[0].shape[0],
+            image_width=frames_style[0].shape[1],
+            channel=3,
+            **self.ebsynth_config
+        )
+        output_frames = []
+        # tasks
+        n = len(frames_style)
+        tasks = []
+        for target in range(n):
+            for source in range(target - self.window_size, target + self.window_size + 1):
+                if source >= 0 and source < n and source != target:
+                    tasks.append((source, target))
+        # run
+        frames = [(None, 1) for i in range(n)]
+        for batch_id in tqdm(range(0, len(tasks), self.batch_size), desc="Balanced Mode"):
+            tasks_batch = tasks[batch_id: min(batch_id+self.batch_size, len(tasks))]
+            source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
+            target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
+            source_style = np.stack([frames_style[source] for source, target in tasks_batch])
+            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+            for (source, target), result in zip(tasks_batch, target_style):
+                frame, weight = frames[target]
+                if frame is None:
+                    frame = frames_style[target]
+                frames[target] = (
+                    frame * (weight / (weight + 1)) + result / (weight + 1),
+                    weight + 1
+                )
+                if weight + 1 == min(n, target + self.window_size + 1) - max(0, target - self.window_size):
+                    frame = frame.clip(0, 255).astype("uint8")
+                    output_frames.append(Image.fromarray(frame))
+                    frames[target] = (None, 1)
+        return output_frames
+
+    def inference_accurate(self, frames_guide, frames_style):
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frames_style[0].shape[0],
+            image_width=frames_style[0].shape[1],
+            channel=3,
+            use_mean_target_style=True,
+            **self.ebsynth_config
+        )
+        output_frames = []
+        # run
+        n = len(frames_style)
+        for target in tqdm(range(n), desc="Accurate Mode"):
+            l, r = max(target - self.window_size, 0), min(target + self.window_size + 1, n)
+            remapped_frames = []
+            for i in range(l, r, self.batch_size):
+                j = min(i + self.batch_size, r)
+                source_guide = np.stack([frames_guide[source] for source in range(i, j)])
+                target_guide = np.stack([frames_guide[target]] * (j - i))
+                source_style = np.stack([frames_style[source] for source in range(i, j)])
+                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+                remapped_frames.append(target_style)
+            frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
+            frame = frame.clip(0, 255).astype("uint8")
+            output_frames.append(Image.fromarray(frame))
+        return output_frames
+
+    def release_vram(self):
+        mempool = cp.get_default_memory_pool()
+        pinned_mempool = cp.get_default_pinned_memory_pool()
+        mempool.free_all_blocks()
+        pinned_mempool.free_all_blocks()
+
+    def __call__(self, rendered_frames, original_frames=None, **kwargs):
+        rendered_frames = [np.array(frame) for frame in rendered_frames]
+        original_frames = [np.array(frame) for frame in original_frames]
+        if self.inference_mode == "fast":
+            output_frames = self.inference_fast(original_frames, rendered_frames)
+        elif self.inference_mode == "balanced":
+            output_frames = self.inference_balanced(original_frames, rendered_frames)
+        elif self.inference_mode == "accurate":
+            output_frames = self.inference_accurate(original_frames, rendered_frames)
+        else:
+            raise ValueError("inference_mode must be fast, balanced or accurate")
+        self.release_vram()
+        return output_frames
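A usage sketch for FastBlendSmoother above. It assumes two equal-length lists of PIL frames with identical resolution (the original video and the frames rendered by a diffusion pipeline), a CUDA device with CuPy installed, and placeholder file paths; the import path is inferred from this file's location in the repo:

from PIL import Image
from diffsynth.processors.FastBlend import FastBlendSmoother  # import path assumed

# Placeholder paths: 16 guide frames from the source video and 16 stylized frames.
original_frames = [Image.open(f"original/{i:04d}.png") for i in range(16)]
rendered_frames = [Image.open(f"rendered/{i:04d}.png") for i in range(16)]

smoother = FastBlendSmoother(inference_mode="balanced", batch_size=8, window_size=30)
smoothed_frames = smoother(rendered_frames, original_frames=original_frames)
smoothed_frames[0].save("smoothed_0000.png")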