vilarin committed
Commit 2ba49a8
1 Parent(s): fec3be6

Upload 63 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. diffsynth/__init__.py +6 -0
  2. diffsynth/controlnets/__init__.py +2 -0
  3. diffsynth/controlnets/controlnet_unit.py +53 -0
  4. diffsynth/controlnets/processors.py +51 -0
  5. diffsynth/data/__init__.py +1 -0
  6. diffsynth/data/video.py +148 -0
  7. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  8. diffsynth/extensions/FastBlend/__init__.py +63 -0
  9. diffsynth/extensions/FastBlend/api.py +397 -0
  10. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  11. diffsynth/extensions/FastBlend/data.py +146 -0
  12. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  13. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  14. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  15. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  16. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  17. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  18. diffsynth/extensions/RIFE/__init__.py +241 -0
  19. diffsynth/models/__init__.py +482 -0
  20. diffsynth/models/attention.py +89 -0
  21. diffsynth/models/hunyuan_dit.py +451 -0
  22. diffsynth/models/hunyuan_dit_text_encoder.py +161 -0
  23. diffsynth/models/sd_controlnet.py +587 -0
  24. diffsynth/models/sd_ipadapter.py +56 -0
  25. diffsynth/models/sd_lora.py +60 -0
  26. diffsynth/models/sd_motion.py +198 -0
  27. diffsynth/models/sd_text_encoder.py +320 -0
  28. diffsynth/models/sd_unet.py +0 -0
  29. diffsynth/models/sd_vae_decoder.py +332 -0
  30. diffsynth/models/sd_vae_encoder.py +278 -0
  31. diffsynth/models/sdxl_ipadapter.py +121 -0
  32. diffsynth/models/sdxl_motion.py +103 -0
  33. diffsynth/models/sdxl_text_encoder.py +757 -0
  34. diffsynth/models/sdxl_unet.py +0 -0
  35. diffsynth/models/sdxl_vae_decoder.py +15 -0
  36. diffsynth/models/sdxl_vae_encoder.py +15 -0
  37. diffsynth/models/svd_image_encoder.py +504 -0
  38. diffsynth/models/svd_unet.py +0 -0
  39. diffsynth/models/svd_vae_decoder.py +577 -0
  40. diffsynth/models/svd_vae_encoder.py +138 -0
  41. diffsynth/models/tiler.py +106 -0
  42. diffsynth/pipelines/__init__.py +6 -0
  43. diffsynth/pipelines/dancer.py +174 -0
  44. diffsynth/pipelines/hunyuan_dit.py +298 -0
  45. diffsynth/pipelines/stable_diffusion.py +167 -0
  46. diffsynth/pipelines/stable_diffusion_video.py +356 -0
  47. diffsynth/pipelines/stable_diffusion_xl.py +175 -0
  48. diffsynth/pipelines/stable_diffusion_xl_video.py +190 -0
  49. diffsynth/pipelines/stable_video_diffusion.py +307 -0
  50. diffsynth/processors/FastBlend.py +142 -0
diffsynth/__init__.py ADDED
@@ -0,0 +1,6 @@
from .data import *
from .models import *
from .prompts import *
from .schedulers import *
from .pipelines import *
from .controlnets import *
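Usage sketch (not part of this commit): the package root simply re-exports the submodule APIs, so downstream code can import everything from `diffsynth` directly. `VideoData` and `save_video` come from `diffsynth.data` (see diffsynth/data/__init__.py below); the file path is a placeholder.

from diffsynth import VideoData, save_video

video = VideoData(video_file="input.mp4")  # "input.mp4" is a placeholder path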
diffsynth/controlnets/__init__.py ADDED
@@ -0,0 +1,2 @@
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
from .processors import Annotator
diffsynth/controlnets/controlnet_unit.py ADDED
@@ -0,0 +1,53 @@
import torch
import numpy as np
from .processors import Processor_id


class ControlNetConfigUnit:
    def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
        self.processor_id = processor_id
        self.model_path = model_path
        self.scale = scale


class ControlNetUnit:
    def __init__(self, processor, model, scale=1.0):
        self.processor = processor
        self.model = model
        self.scale = scale


class MultiControlNetManager:
    def __init__(self, controlnet_units=[]):
        self.processors = [unit.processor for unit in controlnet_units]
        self.models = [unit.model for unit in controlnet_units]
        self.scales = [unit.scale for unit in controlnet_units]

    def process_image(self, image, processor_id=None):
        if processor_id is None:
            processed_image = [processor(image) for processor in self.processors]
        else:
            processed_image = [self.processors[processor_id](image)]
        processed_image = torch.concat([
            torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
            for image_ in processed_image
        ], dim=0)
        return processed_image

    def __call__(
        self,
        sample, timestep, encoder_hidden_states, conditionings,
        tiled=False, tile_size=64, tile_stride=32
    ):
        res_stack = None
        for conditioning, model, scale in zip(conditionings, self.models, self.scales):
            res_stack_ = model(
                sample, timestep, encoder_hidden_states, conditioning,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
            )
            res_stack_ = [res * scale for res in res_stack_]
            if res_stack is None:
                res_stack = res_stack_
            else:
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
        return res_stack
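Hypothetical usage sketch (not from this commit): wiring two ControlNet units into a MultiControlNetManager. `load_controlnet_model` and `pil_image` are placeholders — actual ControlNet weight loading goes through the library's model manager, which is not shown in this diff. The "depth" annotator additionally assumes the annotator weights are available and a CUDA device is present.

# Sketch only: `load_controlnet_model` is a stand-in for the real model loading.
from diffsynth.controlnets import ControlNetUnit, MultiControlNetManager, Annotator

units = [
    ControlNetUnit(Annotator("canny"), load_controlnet_model("canny.safetensors"), scale=0.8),
    ControlNetUnit(Annotator("depth"), load_controlnet_model("depth.safetensors"), scale=0.5),
]
manager = MultiControlNetManager(units)

# process_image runs every annotator on a PIL image and stacks the results
# into a (num_units, C, H, W) float tensor in [0, 1].
conditionings = manager.process_image(pil_image)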
diffsynth/controlnets/processors.py ADDED
@@ -0,0 +1,51 @@
from typing_extensions import Literal, TypeAlias
import warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from controlnet_aux.processor import (
        CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
    )


Processor_id: TypeAlias = Literal[
    "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
]

class Annotator:
    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
        if processor_id == "canny":
            self.processor = CannyDetector()
        elif processor_id == "depth":
            self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
        elif processor_id == "softedge":
            self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
        elif processor_id == "lineart":
            self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
        elif processor_id == "lineart_anime":
            self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
        elif processor_id == "openpose":
            self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
        elif processor_id == "tile":
            self.processor = None
        else:
            raise ValueError(f"Unsupported processor_id: {processor_id}")

        self.processor_id = processor_id
        self.detect_resolution = detect_resolution

    def __call__(self, image):
        width, height = image.size
        if self.processor_id == "openpose":
            kwargs = {
                "include_body": True,
                "include_hand": True,
                "include_face": True
            }
        else:
            kwargs = {}
        if self.processor is not None:
            detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
            image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
        image = image.resize((width, height))
        return image
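Minimal usage sketch, assuming a local image file "frame.png" (placeholder name). The "canny" annotator needs no pretrained weights; the other detectors are loaded from `model_path` and moved to "cuda" in `__init__`, so they additionally require downloaded weights and a GPU.

from PIL import Image
from diffsynth.controlnets import Annotator

annotator = Annotator("canny")  # CannyDetector, no weights required
control_image = annotator(Image.open("frame.png").convert("RGB"))
control_image.save("frame_canny.png")  # same size as the input frame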
diffsynth/data/__init__.py ADDED
@@ -0,0 +1 @@
from .video import VideoData, save_video, save_frames
diffsynth/data/video.py ADDED
@@ -0,0 +1,148 @@
import imageio, os
import numpy as np
from PIL import Image
from tqdm import tqdm


class LowMemoryVideo:
    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        return self.reader.count_frames()

    def __getitem__(self, item):
        return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")

    def __del__(self):
        self.reader.close()


def split_file_name(file_name):
    result = []
    number = -1
    for i in file_name:
        if ord(i)>=ord("0") and ord(i)<=ord("9"):
            if number == -1:
                number = 0
            number = number*10 + ord(i) - ord("0")
        else:
            if number != -1:
                result.append(number)
                number = -1
            result.append(i)
    if number != -1:
        result.append(number)
    result = tuple(result)
    return result


def search_for_images(folder):
    file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
    file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
    file_list = [i[1] for i in sorted(file_list)]
    file_list = [os.path.join(folder, i) for i in file_list]
    return file_list


class LowMemoryImageFolder:
    def __init__(self, folder, file_list=None):
        if file_list is None:
            self.file_list = search_for_images(folder)
        else:
            self.file_list = [os.path.join(folder, file_name) for file_name in file_list]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        return Image.open(self.file_list[item]).convert("RGB")

    def __del__(self):
        pass


def crop_and_resize(image, height, width):
    image = np.array(image)
    image_height, image_width, _ = image.shape
    if image_height / image_width < height / width:
        croped_width = int(image_height / height * width)
        left = (image_width - croped_width) // 2
        image = image[:, left: left+croped_width]
        image = Image.fromarray(image).resize((width, height))
    else:
        croped_height = int(image_width / width * height)
        left = (image_height - croped_height) // 2
        image = image[left: left+croped_height, :]
        image = Image.fromarray(image).resize((width, height))
    return image


class VideoData:
    def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None
        self.set_shape(height, width)

    def raw_data(self):
        frames = []
        for i in range(self.__len__()):
            frames.append(self.__getitem__(i))
        return frames

    def set_length(self, length):
        self.length = length

    def set_shape(self, height, width):
        self.height = height
        self.width = width

    def __len__(self):
        if self.length is None:
            return len(self.data)
        else:
            return self.length

    def shape(self):
        if self.height is not None and self.width is not None:
            return self.height, self.width
        else:
            height, width, _ = self.__getitem__(0).shape
            return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        width, height = frame.size
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = crop_and_resize(frame, self.height, self.width)
        return frame

    def __del__(self):
        pass

    def save_images(self, folder):
        os.makedirs(folder, exist_ok=True)
        for i in tqdm(range(self.__len__()), desc="Saving images"):
            frame = self.__getitem__(i)
            frame.save(os.path.join(folder, f"{i}.png"))


def save_video(frames, save_path, fps, quality=9):
    writer = imageio.get_writer(save_path, fps=fps, quality=quality)
    for frame in tqdm(frames, desc="Saving video"):
        frame = np.array(frame)
        writer.append_data(frame)
    writer.close()

def save_frames(frames, save_path):
    os.makedirs(save_path, exist_ok=True)
    for i, frame in enumerate(tqdm(frames, desc="Saving images")):
        frame.save(os.path.join(save_path, f"{i}.png"))
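Minimal usage sketch, assuming an "input.mp4" file exists (placeholder path): read a video center-cropped and resized to 512x512, then write the frames back out as a new video. The fps value is an arbitrary example.

from diffsynth import VideoData, save_video

video = VideoData(video_file="input.mp4", height=512, width=512)
frames = [video[i] for i in range(len(video))]  # list of PIL images
save_video(frames, "output.mp4", fps=30)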
diffsynth/extensions/ESRGAN/__init__.py ADDED
@@ -0,0 +1,118 @@
import torch
from einops import repeat
from PIL import Image
import numpy as np


class ResidualDenseBlock(torch.nn.Module):

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(torch.nn.Module):

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return out * 0.2 + x


class RRDBNet(torch.nn.Module):

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
        self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        feat = x
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample
        feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
        feat = self.lrelu(self.conv_up1(feat))
        feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
        feat = self.lrelu(self.conv_up2(feat))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out


class ESRGAN(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @staticmethod
    def from_pretrained(model_path):
        model = RRDBNet()
        state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
        model.load_state_dict(state_dict)
        model.eval()
        return ESRGAN(model)

    def process_image(self, image):
        image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
        return image

    def process_images(self, images):
        images = [self.process_image(image) for image in images]
        images = torch.stack(images)
        return images

    def decode_images(self, images):
        images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
        images = [Image.fromarray(image) for image in images]
        return images

    @torch.no_grad()
    def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
        # Preprocess
        input_tensor = self.process_images(images)

        # Interpolate
        output_tensor = []
        for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
            batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
            batch_input_tensor = input_tensor[batch_id: batch_id_]
            batch_input_tensor = batch_input_tensor.to(
                device=self.model.conv_first.weight.device,
                dtype=self.model.conv_first.weight.dtype)
            batch_output_tensor = self.model(batch_input_tensor)
            output_tensor.append(batch_output_tensor.cpu())

        # Output
        output_tensor = torch.concat(output_tensor, dim=0)

        # To images
        output_images = self.decode_images(output_tensor)
        return output_images
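Minimal usage sketch, assuming an ESRGAN-style checkpoint whose state dict is stored under a "params_ema" key (as `from_pretrained` expects); the checkpoint path and image file are placeholders. The network upsamples 4x via two pixel-replication steps followed by convolutions.

from PIL import Image
from diffsynth.extensions.ESRGAN import ESRGAN

esrgan = ESRGAN.from_pretrained("models/ESRGAN/model.pth").to("cuda")
frames = [Image.open("frame.png").convert("RGB")]
upscaled = esrgan.upscale(frames, batch_size=4)
upscaled[0].save("frame_upscaled.png")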
diffsynth/extensions/FastBlend/__init__.py ADDED
@@ -0,0 +1,63 @@
from .runners.fast import TableManager, PyramidPatchMatcher
from PIL import Image
import numpy as np
import cupy as cp


class FastBlendSmoother:
    def __init__(self):
        self.batch_size = 8
        self.window_size = 64
        self.ebsynth_config = {
            "minimum_patch_size": 5,
            "threads_per_block": 8,
            "num_iter": 5,
            "gpu_id": 0,
            "guide_weight": 10.0,
            "initialize": "identity",
            "tracking_window_size": 0,
        }

    @staticmethod
    def from_model_manager(model_manager):
        # TODO: fetch GPU ID from model_manager
        return FastBlendSmoother()

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
        frames_guide = [np.array(frame) for frame in frames_guide]
        frames_style = [np.array(frame) for frame in frames_style]
        table_manager = TableManager()
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # left part
        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
        table_l = table_manager.remapping_table_to_blending_table(table_l)
        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
        # right part
        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
        table_r = table_manager.remapping_table_to_blending_table(table_r)
        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
        # merge
        frames = []
        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
            weight_m = -1
            weight = weight_l + weight_m + weight_r
            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
            frames.append(frame)
        frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
        return frames

    def __call__(self, rendered_frames, original_frames=None, **kwargs):
        frames = self.run(
            original_frames, rendered_frames,
            self.batch_size, self.window_size, self.ebsynth_config
        )
        mempool = cp.get_default_memory_pool()
        pinned_mempool = cp.get_default_pinned_memory_pool()
        mempool.free_all_blocks()
        pinned_mempool.free_all_blocks()
        return frames
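Minimal usage sketch, assuming two equal-length lists of equally sized PIL images and a CUDA GPU with CuPy installed. `guide_frames` are the original video frames and `styled_frames` the stylized render; both names are placeholders.

from diffsynth.extensions.FastBlend import FastBlendSmoother

smoother = FastBlendSmoother()
smoother.window_size = 32  # blend each frame with 32 neighbors on each side
blended = smoother(styled_frames, original_frames=guide_frames)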
diffsynth/extensions/FastBlend/api.py ADDED
@@ -0,0 +1,397 @@
from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
from .data import VideoData, get_video_fps, save_video, search_for_images
import os
import gradio as gr


def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
    frames_guide = VideoData(video_guide, video_guide_folder)
    frames_style = VideoData(video_style, video_style_folder)
    message = ""
    if len(frames_guide) < len(frames_style):
        message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
        frames_style.set_length(len(frames_guide))
    elif len(frames_guide) > len(frames_style):
        message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
        frames_guide.set_length(len(frames_style))
    height_guide, width_guide = frames_guide.shape()
    height_style, width_style = frames_style.shape()
    if height_guide != height_style or width_guide != width_style:
        message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
        frames_style.set_shape(height_guide, width_guide)
    return frames_guide, frames_style, message


def smooth_video(
    video_guide,
    video_guide_folder,
    video_style,
    video_style_folder,
    mode,
    window_size,
    batch_size,
    tracking_window_size,
    output_path,
    fps,
    minimum_patch_size,
    num_iter,
    guide_weight,
    initialize,
    progress = None,
):
    # input
    frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
    if len(message) > 0:
        print(message)
    # output
    if output_path == "":
        if video_style is None:
            output_path = os.path.join(video_style_folder, "output")
        else:
            output_path = os.path.join(os.path.split(video_style)[0], "output")
        os.makedirs(output_path, exist_ok=True)
        print("No valid output_path. Your video will be saved here:", output_path)
    elif not os.path.exists(output_path):
        os.makedirs(output_path, exist_ok=True)
        print("Your video will be saved here:", output_path)
    frames_path = os.path.join(output_path, "frames")
    video_path = os.path.join(output_path, "video.mp4")
    os.makedirs(frames_path, exist_ok=True)
    # process
    if mode == "Fast" or mode == "Balanced":
        tracking_window_size = 0
    ebsynth_config = {
        "minimum_patch_size": minimum_patch_size,
        "threads_per_block": 8,
        "num_iter": num_iter,
        "gpu_id": 0,
        "guide_weight": guide_weight,
        "initialize": initialize,
        "tracking_window_size": tracking_window_size,
    }
    if mode == "Fast":
        FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    elif mode == "Balanced":
        BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    elif mode == "Accurate":
        AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    # output
    try:
        fps = int(fps)
    except:
        fps = get_video_fps(video_style) if video_style is not None else 30
    print("Fps:", fps)
    print("Saving video...")
    video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
    print("Success!")
    print("Your frames are here:", frames_path)
    print("Your video is here:", video_path)
    return output_path, fps, video_path


class KeyFrameMatcher:
    def __init__(self):
        pass

    def extract_number_from_filename(self, file_name):
        result = []
        number = -1
        for i in file_name:
            if ord(i)>=ord("0") and ord(i)<=ord("9"):
                if number == -1:
                    number = 0
                number = number*10 + ord(i) - ord("0")
            else:
                if number != -1:
                    result.append(number)
                    number = -1
        if number != -1:
            result.append(number)
        result = tuple(result)
        return result

    def extract_number_from_filenames(self, file_names):
        numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
        min_length = min(len(i) for i in numbers)
        for i in range(min_length-1, -1, -1):
            if len(set(number[i] for number in numbers))==len(file_names):
                return [number[i] for number in numbers]
        return list(range(len(file_names)))

    def match_using_filename(self, file_names_a, file_names_b):
        file_names_b_set = set(file_names_b)
        matched_file_name = []
        for file_name in file_names_a:
            if file_name not in file_names_b_set:
                matched_file_name.append(None)
            else:
                matched_file_name.append(file_name)
        return matched_file_name

    def match_using_numbers(self, file_names_a, file_names_b):
        numbers_a = self.extract_number_from_filenames(file_names_a)
        numbers_b = self.extract_number_from_filenames(file_names_b)
        numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
        matched_file_name = []
        for number in numbers_a:
            if number in numbers_b_dict:
                matched_file_name.append(numbers_b_dict[number])
            else:
                matched_file_name.append(None)
        return matched_file_name

    def match_filenames(self, file_names_a, file_names_b):
        matched_file_name = self.match_using_filename(file_names_a, file_names_b)
        if sum([i is not None for i in matched_file_name]) > 0:
            return matched_file_name
        matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
        return matched_file_name


def detect_frames(frames_path, keyframes_path):
    if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
        return "Please input the directory of guide video and rendered frames"
    elif not os.path.exists(frames_path):
        return "Please input the directory of guide video"
    elif not os.path.exists(keyframes_path):
        return "Please input the directory of rendered frames"
    frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
    keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
    if len(frames)==0:
        return f"No images detected in {frames_path}"
    if len(keyframes)==0:
        return f"No images detected in {keyframes_path}"
    matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
    max_filename_length = max([len(i) for i in frames])
    if sum([i is not None for i in matched_keyframes])==0:
        message = ""
        for frame, matched_keyframe in zip(frames, matched_keyframes):
            message += frame + " " * (max_filename_length - len(frame) + 1)
            message += "--> No matched keyframes\n"
    else:
        message = ""
        for frame, matched_keyframe in zip(frames, matched_keyframes):
            message += frame + " " * (max_filename_length - len(frame) + 1)
            if matched_keyframe is None:
                message += "--> [to be rendered]\n"
            else:
                message += f"--> {matched_keyframe}\n"
    return message


def check_input_for_interpolating(frames_path, keyframes_path):
    # search for images
    frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
    keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
    # match frames
    matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
    file_list = [file_name for file_name in matched_keyframes if file_name is not None]
    index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
    frames_guide = VideoData(None, frames_path)
    frames_style = VideoData(None, keyframes_path, file_list=file_list)
    # match shape
    message = ""
    height_guide, width_guide = frames_guide.shape()
    height_style, width_style = frames_style.shape()
    if height_guide != height_style or width_guide != width_style:
        message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
        frames_style.set_shape(height_guide, width_guide)
    return frames_guide, frames_style, index_style, message


def interpolate_video(
    frames_path,
    keyframes_path,
    output_path,
    fps,
    batch_size,
    tracking_window_size,
    minimum_patch_size,
    num_iter,
    guide_weight,
    initialize,
    progress = None,
):
    # input
    frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
    if len(message) > 0:
        print(message)
    # output
    if output_path == "":
        output_path = os.path.join(keyframes_path, "output")
        os.makedirs(output_path, exist_ok=True)
        print("No valid output_path. Your video will be saved here:", output_path)
    elif not os.path.exists(output_path):
        os.makedirs(output_path, exist_ok=True)
        print("Your video will be saved here:", output_path)
    output_frames_path = os.path.join(output_path, "frames")
    output_video_path = os.path.join(output_path, "video.mp4")
    os.makedirs(output_frames_path, exist_ok=True)
    # process
    ebsynth_config = {
        "minimum_patch_size": minimum_patch_size,
        "threads_per_block": 8,
        "num_iter": num_iter,
        "gpu_id": 0,
        "guide_weight": guide_weight,
        "initialize": initialize,
        "tracking_window_size": tracking_window_size
    }
    if len(index_style)==1:
        InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
    else:
        InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
    try:
        fps = int(fps)
    except:
        fps = 30
    print("Fps:", fps)
    print("Saving video...")
    video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
    print("Success!")
    print("Your frames are here:", output_frames_path)
    print("Your video is here:", video_path)
    return output_path, fps, video_path


def on_ui_tabs():
    with gr.Blocks(analytics_enabled=False) as ui_component:
        with gr.Tab("Blend"):
            gr.Markdown("""
# Blend

Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution.
            """)
            with gr.Row():
                with gr.Column():
                    with gr.Tab("Guide video"):
                        video_guide = gr.Video(label="Guide video")
                    with gr.Tab("Guide video (images format)"):
                        video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
                with gr.Column():
                    with gr.Tab("Style video"):
                        video_style = gr.Video(label="Style video")
                    with gr.Tab("Style video (images format)"):
                        video_style_folder = gr.Textbox(label="Style video (images format)", value="")
                with gr.Column():
                    output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
                    fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
                    video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
                    btn = gr.Button(value="Blend")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Settings")
                    mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
                    window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
                    batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
                    tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
                    gr.Markdown("## Advanced Settings")
                    minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
                    num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
                    guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
                    initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
                with gr.Column():
                    gr.Markdown("""
# Reference

* Output directory: the directory to save the video.
* Inference mode

|Mode|Time|Memory|Quality|Frame by frame output|Description|
|-|-|-|-|-|-|
|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
|Balanced|■■|■|■■|Yes|Blend the frames naively.|
|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is the best.|

* Sliding window size: our algorithm will blend the frames in a sliding windows. If the size is n, each frame will be blended with the last n frames and the next n frames. A large sliding window can make the video fluent but sometimes smoggy.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
    * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
    * Number of iterations: the number of iterations of patch matching. (Default: 5)
    * Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
    * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
                    """)
            btn.click(
                smooth_video,
                inputs=[
                    video_guide,
                    video_guide_folder,
                    video_style,
                    video_style_folder,
                    mode,
                    window_size,
                    batch_size,
                    tracking_window_size,
                    output_path,
                    fps,
                    minimum_patch_size,
                    num_iter,
                    guide_weight,
                    initialize
                ],
                outputs=[output_path, fps, video_output]
            )
        with gr.Tab("Interpolate"):
            gr.Markdown("""
# Interpolate

Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see the example. The algorithm is experimental and is only tested for 512*512 resolution.
            """)
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        with gr.Column():
                            video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
                        with gr.Column():
                            rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
                    with gr.Row():
                        detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
                    video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
                    rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
                with gr.Column():
                    output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
                    fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
                    video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
                    btn_ = gr.Button(value="Interpolate")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Settings")
                    batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
                    tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
                    gr.Markdown("## Advanced Settings")
                    minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
                    num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
                    guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
                    initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
                with gr.Column():
                    gr.Markdown("""
# Reference

* Output directory: the directory to save the video.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
    * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
    * Number of iterations: the number of iterations of patch matching. (Default: 5)
    * Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
    * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
                    """)
            btn_.click(
                interpolate_video,
                inputs=[
                    video_guide_folder_,
                    rendered_keyframes_,
                    output_path_,
                    fps_,
                    batch_size_,
                    tracking_window_size_,
                    minimum_patch_size_,
                    num_iter_,
                    guide_weight_,
                    initialize_,
                ],
                outputs=[output_path_, fps_, video_output_]
            )

    return [(ui_component, "FastBlend", "FastBlend_ui")]
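Hypothetical standalone launch sketch (not part of this commit): `on_ui_tabs` follows the sd-webui tab convention and returns a list of `(Blocks, title, elem_id)` tuples, so outside the webui the Gradio app can be launched directly. This assumes gradio, CuPy, and the FastBlend runners are installed.

from diffsynth.extensions.FastBlend.api import on_ui_tabs

ui_component, title, elem_id = on_ui_tabs()[0]
ui_component.launch()  # serves the "Blend" and "Interpolate" tabs locally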
diffsynth/extensions/FastBlend/cupy_kernels.py ADDED
@@ -0,0 +1,119 @@
import cupy as cp

remapping_kernel = cp.RawKernel(r'''
extern "C" __global__
void remap(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_style,
    const int* nnf,
    float* target_style
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    if (x >= height or y >= width) return;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
    const int min_px = x < r ? -x : -r;
    const int max_px = x + r > height - 1 ? height - 1 - x : r;
    const int min_py = y < r ? -y : -r;
    const int max_py = y + r > width - 1 ? width - 1 - y : r;
    int num = 0;
    for (int px = min_px; px <= max_px; px++){
        for (int py = min_py; py <= max_py; py++){
            const int nid = (x + px) * width + y + py;
            const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
            const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
            if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
            const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
            num++;
            for (int c = 0; c < channel; c++){
                target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
            }
        }
    }
    for (int c = 0; c < channel; c++){
        target_style[z + pid * channel + c] /= num;
    }
}
''', 'remap')


patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source,
    const int* nnf,
    const float* target,
    float* error
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    if (x >= height or y >= width) return;
    const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
    const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
    float e = 0;
    for (int px = -r; px <= r; px++){
        for (int py = -r; py <= r; py++){
            const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
            const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
            for (int c = 0; c < channel; c++){
                const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
                e += diff * diff;
            }
        }
    }
    error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'patch_error')


pairwise_patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void pairwise_patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_a,
    const int* nnf_a,
    const float* source_b,
    const int* nnf_b,
    float* error
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    if (x >= height or y >= width) return;
    const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
    const int x_a = nnf_a[z_nnf + 0];
    const int y_a = nnf_a[z_nnf + 1];
    const int x_b = nnf_b[z_nnf + 0];
    const int y_b = nnf_b[z_nnf + 1];
    float e = 0;
    for (int px = -r; px <= r; px++){
        for (int py = -r; py <= r; py++){
            const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
            const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
            for (int c = 0; c < channel; c++){
                const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
                e += diff * diff;
            }
        }
    }
    error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'pairwise_patch_error')
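Rough launch sketch for remapping_kernel, mirroring how PatchMatcher.apply_nnf_to_image in patch_match.py calls it: the 2D grid covers the unpadded image, the z dimension indexes the batch, and all arrays live on the GPU in padded NHWC layout. The sizes below are arbitrary example values, not values used anywhere in this commit.

import cupy as cp
from diffsynth.extensions.FastBlend.cupy_kernels import remapping_kernel

height, width, channel, patch_size, pad_size, threads = 64, 64, 3, 5, 6, 8
grid = ((height + threads - 1) // threads, (width + threads - 1) // threads, 1)
block = (threads, threads)

source = cp.zeros((1, height + 2 * pad_size, width + 2 * pad_size, channel), dtype=cp.float32)
nnf = cp.zeros((1, height, width, 2), dtype=cp.int32)  # identity-like NNF placeholder
target = cp.zeros_like(source)
remapping_kernel(grid, block, (height, width, channel, patch_size, pad_size, source, nnf, target))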
diffsynth/extensions/FastBlend/data.py ADDED
@@ -0,0 +1,146 @@
import imageio, os
import numpy as np
from PIL import Image


def read_video(file_name):
    reader = imageio.get_reader(file_name)
    video = []
    for frame in reader:
        frame = np.array(frame)
        video.append(frame)
    reader.close()
    return video


def get_video_fps(file_name):
    reader = imageio.get_reader(file_name)
    fps = reader.get_meta_data()["fps"]
    reader.close()
    return fps


def save_video(frames_path, video_path, num_frames, fps):
    writer = imageio.get_writer(video_path, fps=fps, quality=9)
    for i in range(num_frames):
        frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
        writer.append_data(frame)
    writer.close()
    return video_path


class LowMemoryVideo:
    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        return self.reader.count_frames()

    def __getitem__(self, item):
        return np.array(self.reader.get_data(item))

    def __del__(self):
        self.reader.close()


def split_file_name(file_name):
    result = []
    number = -1
    for i in file_name:
        if ord(i)>=ord("0") and ord(i)<=ord("9"):
            if number == -1:
                number = 0
            number = number*10 + ord(i) - ord("0")
        else:
            if number != -1:
                result.append(number)
                number = -1
            result.append(i)
    if number != -1:
        result.append(number)
    result = tuple(result)
    return result


def search_for_images(folder):
    file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
    file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
    file_list = [i[1] for i in sorted(file_list)]
    file_list = [os.path.join(folder, i) for i in file_list]
    return file_list


def read_images(folder):
    file_list = search_for_images(folder)
    frames = [np.array(Image.open(i)) for i in file_list]
    return frames


class LowMemoryImageFolder:
    def __init__(self, folder, file_list=None):
        if file_list is None:
            self.file_list = search_for_images(folder)
        else:
            self.file_list = [os.path.join(folder, file_name) for file_name in file_list]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        return np.array(Image.open(self.file_list[item]))

    def __del__(self):
        pass


class VideoData:
    def __init__(self, video_file, image_folder, **kwargs):
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None
        self.height = None
        self.width = None

    def raw_data(self):
        frames = []
        for i in range(self.__len__()):
            frames.append(self.__getitem__(i))
        return frames

    def set_length(self, length):
        self.length = length

    def set_shape(self, height, width):
        self.height = height
        self.width = width

    def __len__(self):
        if self.length is None:
            return len(self.data)
        else:
            return self.length

    def shape(self):
        if self.height is not None and self.width is not None:
            return self.height, self.width
        else:
            height, width, _ = self.__getitem__(0).shape
            return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        height, width, _ = frame.shape
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = Image.fromarray(frame).resize((self.width, self.height))
                frame = np.array(frame)
        return frame

    def __del__(self):
        pass
diffsynth/extensions/FastBlend/patch_match.py ADDED
@@ -0,0 +1,298 @@
from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
import numpy as np
import cupy as cp
import cv2


class PatchMatcher:
    def __init__(
        self, height, width, channel, minimum_patch_size,
        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
        random_search_steps=3, random_search_range=4,
        use_mean_target_style=False, use_pairwise_patch_error=False,
        tracking_window_size=0
    ):
        self.height = height
        self.width = width
        self.channel = channel
        self.minimum_patch_size = minimum_patch_size
        self.threads_per_block = threads_per_block
        self.num_iter = num_iter
        self.gpu_id = gpu_id
        self.guide_weight = guide_weight
        self.random_search_steps = random_search_steps
        self.random_search_range = random_search_range
        self.use_mean_target_style = use_mean_target_style
        self.use_pairwise_patch_error = use_pairwise_patch_error
        self.tracking_window_size = tracking_window_size

        self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
        self.pad_size = self.patch_size_list[0] // 2
        self.grid = (
            (height + threads_per_block - 1) // threads_per_block,
            (width + threads_per_block - 1) // threads_per_block
        )
        self.block = (threads_per_block, threads_per_block)

    def pad_image(self, image):
        return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))

    def unpad_image(self, image):
        return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]

    def apply_nnf_to_image(self, nnf, source):
        batch_size = source.shape[0]
        target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
        remapping_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
        )
        return target

    def get_patch_error(self, source, nnf, target):
        batch_size = source.shape[0]
        error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
        patch_error_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
        )
        return error

    def get_pairwise_patch_error(self, source, nnf):
        batch_size = source.shape[0]//2
        error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
        source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
        source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
        pairwise_patch_error_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
        )
        error = error.repeat(2, axis=0)
        return error

    def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
        error_guide = self.get_patch_error(source_guide, nnf, target_guide)
        if self.use_mean_target_style:
            target_style = self.apply_nnf_to_image(nnf, source_style)
            target_style = target_style.mean(axis=0, keepdims=True)
            target_style = target_style.repeat(source_guide.shape[0], axis=0)
        if self.use_pairwise_patch_error:
            error_style = self.get_pairwise_patch_error(source_style, nnf)
        else:
            error_style = self.get_patch_error(source_style, nnf, target_style)
        error = error_guide * self.guide_weight + error_style
        return error

    def clamp_bound(self, nnf):
        nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
        nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
        return nnf

    def random_step(self, nnf, r):
        batch_size = nnf.shape[0]
        step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
        upd_nnf = self.clamp_bound(nnf + step)
        return upd_nnf

    def neighboor_step(self, nnf, d):
        if d==0:
            upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
            upd_nnf[:, :, :, 0] += 1
        elif d==1:
            upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
            upd_nnf[:, :, :, 1] += 1
        elif d==2:
            upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
            upd_nnf[:, :, :, 0] -= 1
        elif d==3:
            upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
            upd_nnf[:, :, :, 1] -= 1
        upd_nnf = self.clamp_bound(upd_nnf)
        return upd_nnf

    def shift_nnf(self, nnf, d):
        if d>0:
            d = min(nnf.shape[0], d)
            upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
        else:
            d = max(-nnf.shape[0], d)
            upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
        return upd_nnf

    def track_step(self, nnf, d):
        if self.use_pairwise_patch_error:
            upd_nnf = cp.zeros_like(nnf)
            upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
            upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
        else:
            upd_nnf = self.shift_nnf(nnf, d)
        return upd_nnf

    def C(self, n, m):
        # not used
        c = 1
        for i in range(1, n+1):
            c *= i
        for i in range(1, m+1):
            c //= i
        for i in range(1, n-m+1):
            c //= i
        return c

    def bezier_step(self, nnf, r):
        # not used
        n = r * 2 - 1
        upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
        for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
            if d>0:
                ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
            elif d<0:
                ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
            upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
        upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
        return upd_nnf

    def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
        upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
        upd_idx = (upd_err < err)
        nnf[upd_idx] = upd_nnf[upd_idx]
        err[upd_idx] = upd_err[upd_idx]
        return nnf, err

    def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for d in cp.random.permutation(4):
            upd_nnf = self.neighboor_step(nnf, d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for i in range(self.random_search_steps):
            upd_nnf = self.random_step(nnf, self.random_search_range)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for d in range(1, self.tracking_window_size + 1):
            upd_nnf = self.track_step(nnf, d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
            upd_nnf = self.track_step(nnf, -d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
        nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
        nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
        nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
        return nnf, err

    def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
        with cp.cuda.Device(self.gpu_id):
            source_guide = self.pad_image(source_guide)
            target_guide = self.pad_image(target_guide)
            source_style = self.pad_image(source_style)
            for it in range(self.num_iter):
                self.patch_size = self.patch_size_list[it]
                target_style = self.apply_nnf_to_image(nnf, source_style)
                err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
                nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
            target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
        return nnf, target_style


class PyramidPatchMatcher:
    def __init__(
        self, image_height, image_width, channel, minimum_patch_size,
        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
        use_mean_target_style=False, use_pairwise_patch_error=False,
        tracking_window_size=0,
        initialize="identity"
    ):
        maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
        self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
        self.pyramid_heights = []
        self.pyramid_widths = []
        self.patch_matchers = []
        self.minimum_patch_size = minimum_patch_size
        self.num_iter = num_iter
        self.gpu_id = gpu_id
        self.initialize = initialize
        for level in range(self.pyramid_level):
            height = image_height//(2**(self.pyramid_level - 1 - level))
            width = image_width//(2**(self.pyramid_level - 1 - level))
            self.pyramid_heights.append(height)
            self.pyramid_widths.append(width)
            self.patch_matchers.append(PatchMatcher(
                height, width, channel, minimum_patch_size=minimum_patch_size,
                threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
                use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
                tracking_window_size=tracking_window_size
            ))

    def resample_image(self, images, level):
        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
        images = images.get()
        images_resample = []
        for image in images:
            image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
            images_resample.append(image_resample)
        images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
        return images_resample

    def initialize_nnf(self, batch_size):
        if self.initialize == "random":
            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
            nnf = cp.stack([
                cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
                cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
            ], axis=3)
        elif self.initialize == "identity":
            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
253
+ nnf = cp.stack([
254
+ cp.repeat(cp.arange(height), width).reshape(height, width),
255
+ cp.tile(cp.arange(width), height).reshape(height, width)
256
+ ], axis=2)
257
+ nnf = cp.stack([nnf] * batch_size)
258
+ else:
259
+ raise NotImplementedError()
260
+ return nnf
261
+
262
+ def update_nnf(self, nnf, level):
263
+ # upscale
264
+ nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
265
+ nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
266
+ nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
267
+ # check if scale is 2
268
+ height, width = self.pyramid_heights[level], self.pyramid_widths[level]
269
+ if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
270
+ nnf = nnf.get().astype(np.float32)
271
+ nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
272
+ nnf = cp.array(np.stack(nnf), dtype=cp.int32)
273
+ nnf = self.patch_matchers[level].clamp_bound(nnf)
274
+ return nnf
275
+
276
+ def apply_nnf_to_image(self, nnf, image):
277
+ with cp.cuda.Device(self.gpu_id):
278
+ image = self.patch_matchers[-1].pad_image(image)
279
+ image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
280
+ return image
281
+
282
+ def estimate_nnf(self, source_guide, target_guide, source_style):
283
+ with cp.cuda.Device(self.gpu_id):
284
+ if not isinstance(source_guide, cp.ndarray):
285
+ source_guide = cp.array(source_guide, dtype=cp.float32)
286
+ if not isinstance(target_guide, cp.ndarray):
287
+ target_guide = cp.array(target_guide, dtype=cp.float32)
288
+ if not isinstance(source_style, cp.ndarray):
289
+ source_style = cp.array(source_style, dtype=cp.float32)
290
+ for level in range(self.pyramid_level):
291
+ nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
292
+ source_guide_ = self.resample_image(source_guide, level)
293
+ target_guide_ = self.resample_image(target_guide, level)
294
+ source_style_ = self.resample_image(source_style, level)
295
+ nnf, target_style = self.patch_matchers[level].estimate_nnf(
296
+ source_guide_, target_guide_, source_style_, nnf
297
+ )
298
+ return nnf.get(), target_style.get()
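The classes above are the core of FastBlend's GPU patch matching: PatchMatcher runs propagation, random search and temporal tracking at one resolution, and PyramidPatchMatcher repeats that coarse-to-fine. A minimal usage sketch follows, assuming cupy and a CUDA device are available; the image paths and the minimum_patch_size value are hypothetical, and estimate_nnf expects batched float arrays of shape (batch, height, width, channel).

# Hedged sketch: remap a stylized frame onto the next guide frame.
# Assumes cupy + CUDA; the three PNG paths are placeholder names.
import numpy as np
from PIL import Image
from diffsynth.extensions.FastBlend.patch_match import PyramidPatchMatcher

source_guide = np.array(Image.open("frame_0000.png"), dtype=np.float32)
target_guide = np.array(Image.open("frame_0001.png"), dtype=np.float32)
source_style = np.array(Image.open("styled_0000.png"), dtype=np.float32)

matcher = PyramidPatchMatcher(
    image_height=source_guide.shape[0],
    image_width=source_guide.shape[1],
    channel=3,
    minimum_patch_size=5,  # illustrative value, not a project default
)
# estimate_nnf takes batched arrays of shape (batch, height, width, channel)
nnf, target_style = matcher.estimate_nnf(
    source_guide[np.newaxis], target_guide[np.newaxis], source_style[np.newaxis]
)
Image.fromarray(target_style[0].clip(0, 255).astype("uint8")).save("styled_0001.png")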
diffsynth/extensions/FastBlend/runners/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .accurate import AccurateModeRunner
+ from .fast import FastModeRunner
+ from .balanced import BalancedModeRunner
+ from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
diffsynth/extensions/FastBlend/runners/accurate.py ADDED
@@ -0,0 +1,35 @@
+ from ..patch_match import PyramidPatchMatcher
+ import os
+ import numpy as np
+ from PIL import Image
+ from tqdm import tqdm
+
+
+ class AccurateModeRunner:
+     def __init__(self):
+         pass
+
+     def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
+         patch_match_engine = PyramidPatchMatcher(
+             image_height=frames_style[0].shape[0],
+             image_width=frames_style[0].shape[1],
+             channel=3,
+             use_mean_target_style=True,
+             **ebsynth_config
+         )
+         # run
+         n = len(frames_style)
+         for target in tqdm(range(n), desc=desc):
+             l, r = max(target - window_size, 0), min(target + window_size + 1, n)
+             remapped_frames = []
+             for i in range(l, r, batch_size):
+                 j = min(i + batch_size, r)
+                 source_guide = np.stack([frames_guide[source] for source in range(i, j)])
+                 target_guide = np.stack([frames_guide[target]] * (j - i))
+                 source_style = np.stack([frames_style[source] for source in range(i, j)])
+                 _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+                 remapped_frames.append(target_style)
+             frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
+             frame = frame.clip(0, 255).astype("uint8")
+             if save_path is not None:
+                 Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
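A sketch of how AccurateModeRunner might be invoked, assuming frames_guide and frames_style are equal-length lists of HxWx3 uint8 numpy frames; the ebsynth_config values, frame count, and output directory are illustrative, not defaults taken from this repository.

# Hedged sketch with dummy frames; real use would load video frames instead.
import os
import numpy as np
from diffsynth.extensions.FastBlend.runners.accurate import AccurateModeRunner

frames_guide = [np.zeros((512, 512, 3), dtype=np.uint8) for _ in range(16)]  # placeholder frames
frames_style = [np.zeros((512, 512, 3), dtype=np.uint8) for _ in range(16)]  # placeholder frames
ebsynth_config = {  # illustrative values
    "minimum_patch_size": 5,
    "threads_per_block": 8,
    "num_iter": 5,
    "gpu_id": 0,
    "guide_weight": 10.0,
}
os.makedirs("output", exist_ok=True)
AccurateModeRunner().run(
    frames_guide, frames_style,
    batch_size=8, window_size=7,
    ebsynth_config=ebsynth_config,
    save_path="output",
)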
diffsynth/extensions/FastBlend/runners/balanced.py ADDED
@@ -0,0 +1,46 @@
+ from ..patch_match import PyramidPatchMatcher
+ import os
+ import numpy as np
+ from PIL import Image
+ from tqdm import tqdm
+
+
+ class BalancedModeRunner:
+     def __init__(self):
+         pass
+
+     def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
+         patch_match_engine = PyramidPatchMatcher(
+             image_height=frames_style[0].shape[0],
+             image_width=frames_style[0].shape[1],
+             channel=3,
+             **ebsynth_config
+         )
+         # tasks
+         n = len(frames_style)
+         tasks = []
+         for target in range(n):
+             for source in range(target - window_size, target + window_size + 1):
+                 if source >= 0 and source < n and source != target:
+                     tasks.append((source, target))
+         # run
+         frames = [(None, 1) for i in range(n)]
+         for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
+             tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
+             source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
+             target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
+             source_style = np.stack([frames_style[source] for source, target in tasks_batch])
+             _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+             for (source, target), result in zip(tasks_batch, target_style):
+                 frame, weight = frames[target]
+                 if frame is None:
+                     frame = frames_style[target]
+                 frames[target] = (
+                     frame * (weight / (weight + 1)) + result / (weight + 1),
+                     weight + 1
+                 )
+                 if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
+                     frame = frame.clip(0, 255).astype("uint8")
+                     if save_path is not None:
+                         Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
+                     frames[target] = (None, 1)
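The update in the inner loop above is an incremental mean: each remapped result is folded in with weight 1/(weight+1), so once every source in the window has contributed, the stored frame equals the plain average of the style frame and its remapped neighbours. A small numeric check of that identity, independent of this repository's code:

# Incremental mean as used by BalancedModeRunner, checked against np.mean.
import numpy as np

results = [np.full((2, 2), v, dtype=np.float32) for v in (10.0, 20.0, 30.0, 40.0)]
frame, weight = results[0], 1            # the style frame itself starts with weight 1
for result in results[1:]:
    frame = frame * (weight / (weight + 1)) + result / (weight + 1)
    weight += 1
assert np.allclose(frame, np.mean(results, axis=0))   # both give 25.0 everywhere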
diffsynth/extensions/FastBlend/runners/fast.py ADDED
@@ -0,0 +1,141 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import functools, os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class TableManager:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def task_list(self, n):
13
+ tasks = []
14
+ max_level = 1
15
+ while (1<<max_level)<=n:
16
+ max_level += 1
17
+ for i in range(n):
18
+ j = i
19
+ for level in range(max_level):
20
+ if i&(1<<level):
21
+ continue
22
+ j |= 1<<level
23
+ if j>=n:
24
+ break
25
+ meta_data = {
26
+ "source": i,
27
+ "target": j,
28
+ "level": level + 1
29
+ }
30
+ tasks.append(meta_data)
31
+ tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
32
+ return tasks
33
+
34
+ def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
35
+ n = len(frames_guide)
36
+ tasks = self.task_list(n)
37
+ remapping_table = [[(frames_style[i], 1)] for i in range(n)]
38
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
39
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
40
+ source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
41
+ target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
42
+ source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
43
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
44
+ for task, result in zip(tasks_batch, target_style):
45
+ target, level = task["target"], task["level"]
46
+ if len(remapping_table[target])==level:
47
+ remapping_table[target].append((result, 1))
48
+ else:
49
+ frame, weight = remapping_table[target][level]
50
+ remapping_table[target][level] = (
51
+ frame * (weight / (weight + 1)) + result / (weight + 1),
52
+ weight + 1
53
+ )
54
+ return remapping_table
55
+
56
+ def remapping_table_to_blending_table(self, table):
57
+ for i in range(len(table)):
58
+ for j in range(1, len(table[i])):
59
+ frame_1, weight_1 = table[i][j-1]
60
+ frame_2, weight_2 = table[i][j]
61
+ frame = (frame_1 + frame_2) / 2
62
+ weight = weight_1 + weight_2
63
+ table[i][j] = (frame, weight)
64
+ return table
65
+
66
+ def tree_query(self, leftbound, rightbound):
67
+ node_list = []
68
+ node_index = rightbound
69
+ while node_index>=leftbound:
70
+ node_level = 0
71
+ while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
72
+ node_level += 1
73
+ node_list.append((node_index, node_level))
74
+ node_index -= 1<<node_level
75
+ return node_list
76
+
77
+ def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
78
+ n = len(blending_table)
79
+ tasks = []
80
+ frames_result = []
81
+ for target in range(n):
82
+ node_list = self.tree_query(max(target-window_size, 0), target)
83
+ for source, level in node_list:
84
+ if source!=target:
85
+ meta_data = {
86
+ "source": source,
87
+ "target": target,
88
+ "level": level
89
+ }
90
+ tasks.append(meta_data)
91
+ else:
92
+ frames_result.append(blending_table[target][level])
93
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
94
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
95
+ source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
96
+ target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
97
+ source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
98
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
99
+ for task, frame_2 in zip(tasks_batch, target_style):
100
+ source, target, level = task["source"], task["target"], task["level"]
101
+ frame_1, weight_1 = frames_result[target]
102
+ weight_2 = blending_table[source][level][1]
103
+ weight = weight_1 + weight_2
104
+ frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
105
+ frames_result[target] = (frame, weight)
106
+ return frames_result
107
+
108
+
109
+ class FastModeRunner:
110
+ def __init__(self):
111
+ pass
112
+
113
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
114
+ frames_guide = frames_guide.raw_data()
115
+ frames_style = frames_style.raw_data()
116
+ table_manager = TableManager()
117
+ patch_match_engine = PyramidPatchMatcher(
118
+ image_height=frames_style[0].shape[0],
119
+ image_width=frames_style[0].shape[1],
120
+ channel=3,
121
+ **ebsynth_config
122
+ )
123
+ # left part
124
+ table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
125
+ table_l = table_manager.remapping_table_to_blending_table(table_l)
126
+ table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
127
+ # right part
128
+ table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
129
+ table_r = table_manager.remapping_table_to_blending_table(table_r)
130
+ table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
131
+ # merge
132
+ frames = []
133
+ for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
134
+ weight_m = -1
135
+ weight = weight_l + weight_m + weight_r
136
+ frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
137
+ frames.append(frame)
138
+ frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
139
+ if save_path is not None:
140
+ for target, frame in enumerate(frames):
141
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
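FastModeRunner keeps per-frame window averages cheap by storing, for every frame, blended results over power-of-two blocks (the level field), and tree_query then decomposes an arbitrary window into O(log n) such blocks. A small sketch of that decomposition; importing the module pulls in cupy transitively via patch_match, so cupy is assumed to be installed.

# Hedged sketch: show which (node_index, node_level) blocks cover a window.
from diffsynth.extensions.FastBlend.runners.fast import TableManager

manager = TableManager()
left, right = 0, 10
nodes = manager.tree_query(left, right)
covered = set()
for node_index, node_level in nodes:
    # node (index, level) stands for the index range [index - 2**level + 1, index]
    covered.update(range(node_index - (1 << node_level) + 1, node_index + 1))
assert covered == set(range(left, right + 1))
print(nodes)  # a handful of blocks, e.g. [(10, 0), (9, 1), (7, 3)]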
diffsynth/extensions/FastBlend/runners/interpolation.py ADDED
@@ -0,0 +1,121 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class InterpolationModeRunner:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def get_index_dict(self, index_style):
13
+ index_dict = {}
14
+ for i, index in enumerate(index_style):
15
+ index_dict[index] = i
16
+ return index_dict
17
+
18
+ def get_weight(self, l, m, r):
19
+ weight_l, weight_r = abs(m - r), abs(m - l)
20
+ if weight_l + weight_r == 0:
21
+ weight_l, weight_r = 0.5, 0.5
22
+ else:
23
+ weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
24
+ return weight_l, weight_r
25
+
26
+ def get_task_group(self, index_style, n):
27
+ task_group = []
28
+ index_style = sorted(index_style)
29
+ # first frame
30
+ if index_style[0]>0:
31
+ tasks = []
32
+ for m in range(index_style[0]):
33
+ tasks.append((index_style[0], m, index_style[0]))
34
+ task_group.append(tasks)
35
+ # middle frames
36
+ for l, r in zip(index_style[:-1], index_style[1:]):
37
+ tasks = []
38
+ for m in range(l, r):
39
+ tasks.append((l, m, r))
40
+ task_group.append(tasks)
41
+ # last frame
42
+ tasks = []
43
+ for m in range(index_style[-1], n):
44
+ tasks.append((index_style[-1], m, index_style[-1]))
45
+ task_group.append(tasks)
46
+ return task_group
47
+
48
+ def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
49
+ patch_match_engine = PyramidPatchMatcher(
50
+ image_height=frames_style[0].shape[0],
51
+ image_width=frames_style[0].shape[1],
52
+ channel=3,
53
+ use_mean_target_style=False,
54
+ use_pairwise_patch_error=True,
55
+ **ebsynth_config
56
+ )
57
+ # task
58
+ index_dict = self.get_index_dict(index_style)
59
+ task_group = self.get_task_group(index_style, len(frames_guide))
60
+ # run
61
+ for tasks in task_group:
62
+ index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
63
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
64
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
65
+ source_guide, target_guide, source_style = [], [], []
66
+ for l, m, r in tasks_batch:
67
+ # l -> m
68
+ source_guide.append(frames_guide[l])
69
+ target_guide.append(frames_guide[m])
70
+ source_style.append(frames_style[index_dict[l]])
71
+ # r -> m
72
+ source_guide.append(frames_guide[r])
73
+ target_guide.append(frames_guide[m])
74
+ source_style.append(frames_style[index_dict[r]])
75
+ source_guide = np.stack(source_guide)
76
+ target_guide = np.stack(target_guide)
77
+ source_style = np.stack(source_style)
78
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
79
+ if save_path is not None:
80
+ for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
81
+ weight_l, weight_r = self.get_weight(l, m, r)
82
+ frame = frame_l * weight_l + frame_r * weight_r
83
+ frame = frame.clip(0, 255).astype("uint8")
84
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
85
+
86
+
87
+ class InterpolationModeSingleFrameRunner:
88
+ def __init__(self):
89
+ pass
90
+
91
+ def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
92
+ # check input
93
+ tracking_window_size = ebsynth_config["tracking_window_size"]
94
+ if tracking_window_size * 2 >= batch_size:
95
+ raise ValueError("batch_size should be larger than tracking_window_size * 2")
96
+ frame_style = frames_style[0]
97
+ frame_guide = frames_guide[index_style[0]]
98
+ patch_match_engine = PyramidPatchMatcher(
99
+ image_height=frame_style.shape[0],
100
+ image_width=frame_style.shape[1],
101
+ channel=3,
102
+ **ebsynth_config
103
+ )
104
+ # run
105
+ frame_id, n = 0, len(frames_guide)
106
+ for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
107
+ if i + batch_size > n:
108
+ l, r = max(n - batch_size, 0), n
109
+ else:
110
+ l, r = i, i + batch_size
111
+ source_guide = np.stack([frame_guide] * (r-l))
112
+ target_guide = np.stack([frames_guide[i] for i in range(l, r)])
113
+ source_style = np.stack([frame_style] * (r-l))
114
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
115
+ for i, frame in zip(range(l, r), target_style):
116
+ if i==frame_id:
117
+ frame = frame.clip(0, 255).astype("uint8")
118
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
119
+ frame_id += 1
120
+ if r < n and r-frame_id <= tracking_window_size:
121
+ break
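In InterpolationModeRunner, every in-between frame m is rendered twice, once from the keyframe on its left (l) and once from the keyframe on its right (r), and get_weight blends the two by distance so the nearer keyframe dominates. A quick numeric illustration of that weighting (importing the module assumes cupy is installed, since patch_match is pulled in transitively):

# get_weight(l, m, r) returns (weight_for_left_render, weight_for_right_render).
from diffsynth.extensions.FastBlend.runners.interpolation import InterpolationModeRunner

runner = InterpolationModeRunner()
print(runner.get_weight(0, 3, 10))   # (0.7, 0.3): frame 3 sits closer to keyframe 0
print(runner.get_weight(5, 5, 5))    # (0.5, 0.5): the degenerate case is handled explicitly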
diffsynth/extensions/RIFE/__init__.py ADDED
@@ -0,0 +1,241 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+
8
+ def warp(tenInput, tenFlow, device):
9
+ backwarp_tenGrid = {}
10
+ k = (str(tenFlow.device), str(tenFlow.size()))
11
+ if k not in backwarp_tenGrid:
12
+ tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
13
+ 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
14
+ tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
15
+ 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
16
+ backwarp_tenGrid[k] = torch.cat(
17
+ [tenHorizontal, tenVertical], 1).to(device)
18
+
19
+ tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
20
+ tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
21
+
22
+ g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
23
+ return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
24
+
25
+
26
+ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
27
+ return nn.Sequential(
28
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
29
+ padding=padding, dilation=dilation, bias=True),
30
+ nn.PReLU(out_planes)
31
+ )
32
+
33
+
34
+ class IFBlock(nn.Module):
35
+ def __init__(self, in_planes, c=64):
36
+ super(IFBlock, self).__init__()
37
+ self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
38
+ self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
39
+ self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
40
+ self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
41
+ self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
42
+ self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
43
+ self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))
44
+
45
+ def forward(self, x, flow, scale=1):
46
+ x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
47
+ flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
48
+ feat = self.conv0(torch.cat((x, flow), 1))
49
+ feat = self.convblock0(feat) + feat
50
+ feat = self.convblock1(feat) + feat
51
+ feat = self.convblock2(feat) + feat
52
+ feat = self.convblock3(feat) + feat
53
+ flow = self.conv1(feat)
54
+ mask = self.conv2(feat)
55
+ flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
56
+ mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
57
+ return flow, mask
58
+
59
+
60
+ class IFNet(nn.Module):
61
+ def __init__(self):
62
+ super(IFNet, self).__init__()
63
+ self.block0 = IFBlock(7+4, c=90)
64
+ self.block1 = IFBlock(7+4, c=90)
65
+ self.block2 = IFBlock(7+4, c=90)
66
+ self.block_tea = IFBlock(10+4, c=90)
67
+
68
+ def forward(self, x, scale_list=[4, 2, 1], training=False):
69
+ if training == False:
70
+ channel = x.shape[1] // 2
71
+ img0 = x[:, :channel]
72
+ img1 = x[:, channel:]
73
+ flow_list = []
74
+ merged = []
75
+ mask_list = []
76
+ warped_img0 = img0
77
+ warped_img1 = img1
78
+ flow = (x[:, :4]).detach() * 0
79
+ mask = (x[:, :1]).detach() * 0
80
+ block = [self.block0, self.block1, self.block2]
81
+ for i in range(3):
82
+ f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
83
+ f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
84
+ flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
85
+ mask = mask + (m0 + (-m1)) / 2
86
+ mask_list.append(mask)
87
+ flow_list.append(flow)
88
+ warped_img0 = warp(img0, flow[:, :2], device=x.device)
89
+ warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
90
+ merged.append((warped_img0, warped_img1))
91
+ '''
92
+ c0 = self.contextnet(img0, flow[:, :2])
93
+ c1 = self.contextnet(img1, flow[:, 2:4])
94
+ tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
95
+ res = tmp[:, 1:4] * 2 - 1
96
+ '''
97
+ for i in range(3):
98
+ mask_list[i] = torch.sigmoid(mask_list[i])
99
+ merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
100
+ return flow_list, mask_list[2], merged
101
+
102
+ def state_dict_converter(self):
103
+ return IFNetStateDictConverter()
104
+
105
+
106
+ class IFNetStateDictConverter:
107
+ def __init__(self):
108
+ pass
109
+
110
+ def from_diffusers(self, state_dict):
111
+ state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
112
+ return state_dict_
113
+
114
+ def from_civitai(self, state_dict):
115
+ return self.from_diffusers(state_dict)
116
+
117
+
118
+ class RIFEInterpolater:
119
+ def __init__(self, model, device="cuda"):
120
+ self.model = model
121
+ self.device = device
122
+ # IFNet does not support float16, so keep computations in float32
123
+ self.torch_dtype = torch.float32
124
+
125
+ @staticmethod
126
+ def from_model_manager(model_manager):
127
+ return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)
128
+
129
+ def process_image(self, image):
130
+ width, height = image.size
131
+ if width % 32 != 0 or height % 32 != 0:
132
+ width = (width + 31) // 32 * 32
134
+ height = (height + 31) // 32 * 32
134
+ image = image.resize((width, height))
135
+ image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
136
+ return image
137
+
138
+ def process_images(self, images):
139
+ images = [self.process_image(image) for image in images]
140
+ images = torch.stack(images)
141
+ return images
142
+
143
+ def decode_images(self, images):
144
+ images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
145
+ images = [Image.fromarray(image) for image in images]
146
+ return images
147
+
148
+ def add_interpolated_images(self, images, interpolated_images):
149
+ output_images = []
150
+ for image, interpolated_image in zip(images, interpolated_images):
151
+ output_images.append(image)
152
+ output_images.append(interpolated_image)
153
+ output_images.append(images[-1])
154
+ return output_images
155
+
156
+
157
+ @torch.no_grad()
158
+ def interpolate_(self, images, scale=1.0):
159
+ input_tensor = self.process_images(images)
160
+ input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
161
+ input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
162
+ flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
163
+ output_images = self.decode_images(merged[2].cpu())
164
+ if output_images[0].size != images[0].size:
165
+ output_images = [image.resize(images[0].size) for image in output_images]
166
+ return output_images
167
+
168
+
169
+ @torch.no_grad()
170
+ def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
171
+ # Preprocess
172
+ processed_images = self.process_images(images)
173
+
174
+ for iter in range(num_iter):
175
+ # Input
176
+ input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)
177
+
178
+ # Interpolate
179
+ output_tensor = []
180
+ for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
181
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
182
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
183
+ batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
184
+ flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
185
+ output_tensor.append(merged[2].cpu())
186
+
187
+ # Output
188
+ output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
189
+ processed_images = self.add_interpolated_images(processed_images, output_tensor)
190
+ processed_images = torch.stack(processed_images)
191
+
192
+ # To images
193
+ output_images = self.decode_images(processed_images)
194
+ if output_images[0].size != images[0].size:
195
+ output_images = [image.resize(images[0].size) for image in output_images]
196
+ return output_images
197
+
198
+
199
+ class RIFESmoother(RIFEInterpolater):
200
+ def __init__(self, model, device="cuda"):
201
+ super(RIFESmoother, self).__init__(model, device=device)
202
+
203
+ @staticmethod
204
+ def from_model_manager(model_manager):
205
+ return RIFESmoother(model_manager.RIFE, device=model_manager.device)
206
+
207
+ def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
208
+ output_tensor = []
209
+ for batch_id in range(0, input_tensor.shape[0], batch_size):
210
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
211
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
212
+ batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
213
+ flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
214
+ output_tensor.append(merged[2].cpu())
215
+ output_tensor = torch.concat(output_tensor, dim=0)
216
+ return output_tensor
217
+
218
+ @torch.no_grad()
219
+ def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
220
+ # Preprocess
221
+ processed_images = self.process_images(rendered_frames)
222
+
223
+ for iter in range(num_iter):
224
+ # Input
225
+ input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
226
+
227
+ # Interpolate
228
+ output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
229
+
230
+ # Blend
231
+ input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
232
+ output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
233
+
234
+ # Add to frames
235
+ processed_images[1:-1] = output_tensor
236
+
237
+ # To images
238
+ output_images = self.decode_images(processed_images)
239
+ if output_images[0].size != rendered_frames[0].size:
240
+ output_images = [image.resize(rendered_frames[0].size) for image in output_images]
241
+ return output_images
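A sketch of how the RIFE wrapper above is typically wired together with the ModelManager defined in diffsynth/models/__init__.py below: load an IFNet checkpoint, build a RIFEInterpolater from the manager, and roughly double the frame count of a list of PIL images. The checkpoint path and frame file names are hypothetical placeholders.

# Hedged sketch; "models/RIFE/flownet.pkl" is a placeholder path, not a pinned file.
from PIL import Image
from diffsynth.models import ModelManager
from diffsynth.extensions.RIFE import RIFEInterpolater

model_manager = ModelManager(device="cuda")
model_manager.load_models(["models/RIFE/flownet.pkl"])   # detected via is_RIFE()

interpolater = RIFEInterpolater.from_model_manager(model_manager)
frames = [Image.open(f"rendered/{i:05d}.png") for i in range(8)]   # hypothetical frames
frames_2x = interpolater.interpolate(frames, num_iter=1)           # inserts one frame between each pair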
diffsynth/models/__init__.py ADDED
@@ -0,0 +1,482 @@
1
+ import torch, os
2
+ from safetensors import safe_open
3
+
4
+ from .sd_text_encoder import SDTextEncoder
5
+ from .sd_unet import SDUNet
6
+ from .sd_vae_encoder import SDVAEEncoder
7
+ from .sd_vae_decoder import SDVAEDecoder
8
+ from .sd_lora import SDLoRA
9
+
10
+ from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
11
+ from .sdxl_unet import SDXLUNet
12
+ from .sdxl_vae_decoder import SDXLVAEDecoder
13
+ from .sdxl_vae_encoder import SDXLVAEEncoder
14
+
15
+ from .sd_controlnet import SDControlNet
16
+
17
+ from .sd_motion import SDMotionModel
18
+ from .sdxl_motion import SDXLMotionModel
19
+
20
+ from .svd_image_encoder import SVDImageEncoder
21
+ from .svd_unet import SVDUNet
22
+ from .svd_vae_decoder import SVDVAEDecoder
23
+ from .svd_vae_encoder import SVDVAEEncoder
24
+
25
+ from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
26
+ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
27
+
28
+ from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
29
+ from .hunyuan_dit import HunyuanDiT
30
+
31
+
32
+ class ModelManager:
33
+ def __init__(self, torch_dtype=torch.float16, device="cuda"):
34
+ self.torch_dtype = torch_dtype
35
+ self.device = device
36
+ self.model = {}
37
+ self.model_path = {}
38
+ self.textual_inversion_dict = {}
39
+
40
+ def is_stable_video_diffusion(self, state_dict):
41
+ param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
42
+ return param_name in state_dict
43
+
44
+ def is_RIFE(self, state_dict):
45
+ param_name = "block_tea.convblock3.0.1.weight"
46
+ return param_name in state_dict or ("module." + param_name) in state_dict
47
+
48
+ def is_beautiful_prompt(self, state_dict):
49
+ param_name = "transformer.h.9.self_attention.query_key_value.weight"
50
+ return param_name in state_dict
51
+
52
+ def is_stabe_diffusion_xl(self, state_dict):
53
+ param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
54
+ return param_name in state_dict
55
+
56
+ def is_stable_diffusion(self, state_dict):
57
+ if self.is_stabe_diffusion_xl(state_dict):
58
+ return False
59
+ param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
60
+ return param_name in state_dict
61
+
62
+ def is_controlnet(self, state_dict):
63
+ param_name = "control_model.time_embed.0.weight"
64
+ param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format
65
+ return param_name in state_dict or param_name_2 in state_dict
66
+
67
+ def is_animatediff(self, state_dict):
68
+ param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
69
+ return param_name in state_dict
70
+
71
+ def is_animatediff_xl(self, state_dict):
72
+ param_name = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight"
73
+ return param_name in state_dict
74
+
75
+ def is_sd_lora(self, state_dict):
76
+ param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
77
+ return param_name in state_dict
78
+
79
+ def is_translator(self, state_dict):
80
+ param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
81
+ return param_name in state_dict and len(state_dict) == 254
82
+
83
+ def is_ipadapter(self, state_dict):
84
+ return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([3072, 1024])
85
+
86
+ def is_ipadapter_image_encoder(self, state_dict):
87
+ param_name = "vision_model.encoder.layers.31.self_attn.v_proj.weight"
88
+ return param_name in state_dict and len(state_dict) == 521
89
+
90
+ def is_ipadapter_xl(self, state_dict):
91
+ return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([8192, 1280])
92
+
93
+ def is_ipadapter_xl_image_encoder(self, state_dict):
94
+ param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
95
+ return param_name in state_dict and len(state_dict) == 777
96
+
97
+ def is_hunyuan_dit_clip_text_encoder(self, state_dict):
98
+ param_name = "bert.encoder.layer.23.attention.output.dense.weight"
99
+ return param_name in state_dict
100
+
101
+ def is_hunyuan_dit_t5_text_encoder(self, state_dict):
102
+ param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
103
+ return param_name in state_dict
104
+
105
+ def is_hunyuan_dit(self, state_dict):
106
+ param_name = "final_layer.adaLN_modulation.1.weight"
107
+ return param_name in state_dict
108
+
109
+ def is_diffusers_vae(self, state_dict):
110
+ param_name = "quant_conv.weight"
111
+ return param_name in state_dict
112
+
113
+ def is_ExVideo_StableVideoDiffusion(self, state_dict):
114
+ param_name = "blocks.185.positional_embedding.embeddings"
115
+ return param_name in state_dict
116
+
117
+ def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
118
+ component_dict = {
119
+ "image_encoder": SVDImageEncoder,
120
+ "unet": SVDUNet,
121
+ "vae_decoder": SVDVAEDecoder,
122
+ "vae_encoder": SVDVAEEncoder,
123
+ }
124
+ if components is None:
125
+ components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
126
+ for component in components:
127
+ if component == "unet":
128
+ self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
129
+ self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
130
+ else:
131
+ self.model[component] = component_dict[component]()
132
+ self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
133
+ self.model[component].to(self.torch_dtype).to(self.device)
134
+ self.model_path[component] = file_path
135
+
136
+ def load_stable_diffusion(self, state_dict, components=None, file_path=""):
137
+ component_dict = {
138
+ "text_encoder": SDTextEncoder,
139
+ "unet": SDUNet,
140
+ "vae_decoder": SDVAEDecoder,
141
+ "vae_encoder": SDVAEEncoder,
142
+ "refiner": SDXLUNet,
143
+ }
144
+ if components is None:
145
+ components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
146
+ for component in components:
147
+ if component == "text_encoder":
148
+ # Add additional token embeddings to text encoder
149
+ token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]]
150
+ for keyword in self.textual_inversion_dict:
151
+ _, embeddings = self.textual_inversion_dict[keyword]
152
+ token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
153
+ token_embeddings = torch.concat(token_embeddings, dim=0)
154
+ state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
155
+ self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
156
+ self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
157
+ self.model[component].to(self.torch_dtype).to(self.device)
158
+ else:
159
+ self.model[component] = component_dict[component]()
160
+ self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
161
+ self.model[component].to(self.torch_dtype).to(self.device)
162
+ self.model_path[component] = file_path
163
+
164
+ def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
165
+ component_dict = {
166
+ "text_encoder": SDXLTextEncoder,
167
+ "text_encoder_2": SDXLTextEncoder2,
168
+ "unet": SDXLUNet,
169
+ "vae_decoder": SDXLVAEDecoder,
170
+ "vae_encoder": SDXLVAEEncoder,
171
+ }
172
+ if components is None:
173
+ components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
174
+ for component in components:
175
+ self.model[component] = component_dict[component]()
176
+ self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
177
+ if component in ["vae_decoder", "vae_encoder"]:
178
+ # These two model will output nan when float16 is enabled.
179
+ # The precision problem happens in the last three resnet blocks.
180
+ # I do not know how to solve this problem.
181
+ self.model[component].to(torch.float32).to(self.device)
182
+ else:
183
+ self.model[component].to(self.torch_dtype).to(self.device)
184
+ self.model_path[component] = file_path
185
+
186
+ def load_controlnet(self, state_dict, file_path=""):
187
+ component = "controlnet"
188
+ if component not in self.model:
189
+ self.model[component] = []
190
+ self.model_path[component] = []
191
+ model = SDControlNet()
192
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
193
+ model.to(self.torch_dtype).to(self.device)
194
+ self.model[component].append(model)
195
+ self.model_path[component].append(file_path)
196
+
197
+ def load_animatediff(self, state_dict, file_path=""):
198
+ component = "motion_modules"
199
+ model = SDMotionModel()
200
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
201
+ model.to(self.torch_dtype).to(self.device)
202
+ self.model[component] = model
203
+ self.model_path[component] = file_path
204
+
205
+ def load_animatediff_xl(self, state_dict, file_path=""):
206
+ component = "motion_modules_xl"
207
+ model = SDXLMotionModel()
208
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
209
+ model.to(self.torch_dtype).to(self.device)
210
+ self.model[component] = model
211
+ self.model_path[component] = file_path
212
+
213
+ def load_beautiful_prompt(self, state_dict, file_path=""):
214
+ component = "beautiful_prompt"
215
+ from transformers import AutoModelForCausalLM
216
+ model_folder = os.path.dirname(file_path)
217
+ model = AutoModelForCausalLM.from_pretrained(
218
+ model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
219
+ ).to(self.device).eval()
220
+ self.model[component] = model
221
+ self.model_path[component] = file_path
222
+
223
+ def load_RIFE(self, state_dict, file_path=""):
224
+ component = "RIFE"
225
+ from ..extensions.RIFE import IFNet
226
+ model = IFNet().eval()
227
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
228
+ model.to(torch.float32).to(self.device)
229
+ self.model[component] = model
230
+ self.model_path[component] = file_path
231
+
232
+ def load_sd_lora(self, state_dict, alpha):
233
+ SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
234
+ SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
235
+
236
+ def load_translator(self, state_dict, file_path=""):
237
+ # This model is lightweight, we do not place it on GPU.
238
+ component = "translator"
239
+ from transformers import AutoModelForSeq2SeqLM
240
+ model_folder = os.path.dirname(file_path)
241
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
242
+ self.model[component] = model
243
+ self.model_path[component] = file_path
244
+
245
+ def load_ipadapter(self, state_dict, file_path=""):
246
+ component = "ipadapter"
247
+ model = SDIpAdapter()
248
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
249
+ model.to(self.torch_dtype).to(self.device)
250
+ self.model[component] = model
251
+ self.model_path[component] = file_path
252
+
253
+ def load_ipadapter_image_encoder(self, state_dict, file_path=""):
254
+ component = "ipadapter_image_encoder"
255
+ model = IpAdapterCLIPImageEmbedder()
256
+ model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
257
+ model.to(self.torch_dtype).to(self.device)
258
+ self.model[component] = model
259
+ self.model_path[component] = file_path
260
+
261
+ def load_ipadapter_xl(self, state_dict, file_path=""):
262
+ component = "ipadapter_xl"
263
+ model = SDXLIpAdapter()
264
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
265
+ model.to(self.torch_dtype).to(self.device)
266
+ self.model[component] = model
267
+ self.model_path[component] = file_path
268
+
269
+ def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
270
+ component = "ipadapter_xl_image_encoder"
271
+ model = IpAdapterXLCLIPImageEmbedder()
272
+ model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
273
+ model.to(self.torch_dtype).to(self.device)
274
+ self.model[component] = model
275
+ self.model_path[component] = file_path
276
+
277
+ def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
278
+ component = "hunyuan_dit_clip_text_encoder"
279
+ model = HunyuanDiTCLIPTextEncoder()
280
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
281
+ model.to(self.torch_dtype).to(self.device)
282
+ self.model[component] = model
283
+ self.model_path[component] = file_path
284
+
285
+ def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
286
+ component = "hunyuan_dit_t5_text_encoder"
287
+ model = HunyuanDiTT5TextEncoder()
288
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
289
+ model.to(self.torch_dtype).to(self.device)
290
+ self.model[component] = model
291
+ self.model_path[component] = file_path
292
+
293
+ def load_hunyuan_dit(self, state_dict, file_path=""):
294
+ component = "hunyuan_dit"
295
+ model = HunyuanDiT()
296
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
297
+ model.to(self.torch_dtype).to(self.device)
298
+ self.model[component] = model
299
+ self.model_path[component] = file_path
300
+
301
+ def load_diffusers_vae(self, state_dict, file_path=""):
302
+ # TODO: detect SD and SDXL
303
+ component = "vae_encoder"
304
+ model = SDXLVAEEncoder()
305
+ model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
306
+ model.to(self.torch_dtype).to(self.device)
307
+ self.model[component] = model
308
+ self.model_path[component] = file_path
309
+ component = "vae_decoder"
310
+ model = SDXLVAEDecoder()
311
+ model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
312
+ model.to(self.torch_dtype).to(self.device)
313
+ self.model[component] = model
314
+ self.model_path[component] = file_path
315
+
316
+ def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
317
+ unet_state_dict = self.model["unet"].state_dict()
318
+ self.model["unet"].to("cpu")
319
+ del self.model["unet"]
320
+ add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
321
+ self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
322
+ self.model["unet"].load_state_dict(unet_state_dict, strict=False)
323
+ self.model["unet"].load_state_dict(state_dict, strict=False)
324
+ self.model["unet"].to(self.torch_dtype).to(self.device)
325
+
326
+ def search_for_embeddings(self, state_dict):
327
+ embeddings = []
328
+ for k in state_dict:
329
+ if isinstance(state_dict[k], torch.Tensor):
330
+ embeddings.append(state_dict[k])
331
+ elif isinstance(state_dict[k], dict):
332
+ embeddings += self.search_for_embeddings(state_dict[k])
333
+ return embeddings
334
+
335
+ def load_textual_inversions(self, folder):
336
+ # Store additional tokens here
337
+ self.textual_inversion_dict = {}
338
+
339
+ # Load every textual inversion file
340
+ for file_name in os.listdir(folder):
341
+ if file_name.endswith(".txt"):
342
+ continue
343
+ keyword = os.path.splitext(file_name)[0]
344
+ state_dict = load_state_dict(os.path.join(folder, file_name))
345
+
346
+ # Search for embeddings
347
+ for embeddings in self.search_for_embeddings(state_dict):
348
+ if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
349
+ tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
350
+ self.textual_inversion_dict[keyword] = (tokens, embeddings)
351
+ break
352
+
353
+ def load_model(self, file_path, components=None, lora_alphas=[]):
354
+ state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
355
+ if self.is_stable_video_diffusion(state_dict):
356
+ self.load_stable_video_diffusion(state_dict, file_path=file_path)
357
+ elif self.is_animatediff(state_dict):
358
+ self.load_animatediff(state_dict, file_path=file_path)
359
+ elif self.is_animatediff_xl(state_dict):
360
+ self.load_animatediff_xl(state_dict, file_path=file_path)
361
+ elif self.is_controlnet(state_dict):
362
+ self.load_controlnet(state_dict, file_path=file_path)
363
+ elif self.is_stabe_diffusion_xl(state_dict):
364
+ self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
365
+ elif self.is_stable_diffusion(state_dict):
366
+ self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
367
+ elif self.is_sd_lora(state_dict):
368
+ self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
369
+ elif self.is_beautiful_prompt(state_dict):
370
+ self.load_beautiful_prompt(state_dict, file_path=file_path)
371
+ elif self.is_RIFE(state_dict):
372
+ self.load_RIFE(state_dict, file_path=file_path)
373
+ elif self.is_translator(state_dict):
374
+ self.load_translator(state_dict, file_path=file_path)
375
+ elif self.is_ipadapter(state_dict):
376
+ self.load_ipadapter(state_dict, file_path=file_path)
377
+ elif self.is_ipadapter_image_encoder(state_dict):
378
+ self.load_ipadapter_image_encoder(state_dict, file_path=file_path)
379
+ elif self.is_ipadapter_xl(state_dict):
380
+ self.load_ipadapter_xl(state_dict, file_path=file_path)
381
+ elif self.is_ipadapter_xl_image_encoder(state_dict):
382
+ self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
383
+ elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
384
+ self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
385
+ elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
386
+ self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
387
+ elif self.is_hunyuan_dit(state_dict):
388
+ self.load_hunyuan_dit(state_dict, file_path=file_path)
389
+ elif self.is_diffusers_vae(state_dict):
390
+ self.load_diffusers_vae(state_dict, file_path=file_path)
391
+ elif self.is_ExVideo_StableVideoDiffusion(state_dict):
392
+ self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
393
+
394
+ def load_models(self, file_path_list, lora_alphas=[]):
395
+ for file_path in file_path_list:
396
+ self.load_model(file_path, lora_alphas=lora_alphas)
397
+
398
+ def to(self, device):
399
+ for component in self.model:
400
+ if isinstance(self.model[component], list):
401
+ for model in self.model[component]:
402
+ model.to(device)
403
+ else:
404
+ self.model[component].to(device)
405
+ torch.cuda.empty_cache()
406
+
407
+ def get_model_with_model_path(self, model_path):
408
+ for component in self.model_path:
409
+ if isinstance(self.model_path[component], str):
410
+ if os.path.samefile(self.model_path[component], model_path):
411
+ return self.model[component]
412
+ elif isinstance(self.model_path[component], list):
413
+ for i, model_path_ in enumerate(self.model_path[component]):
414
+ if os.path.samefile(model_path_, model_path):
415
+ return self.model[component][i]
416
+ raise ValueError(f"Please load model {model_path} before you use it.")
417
+
418
+ def __getattr__(self, __name):
419
+ if __name in self.model:
420
+ return self.model[__name]
421
+ else:
422
+ return super().__getattribute__(__name)
423
+
424
+
425
+ def load_state_dict(file_path, torch_dtype=None):
426
+ if file_path.endswith(".safetensors"):
427
+ return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
428
+ else:
429
+ return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
430
+
431
+
432
+ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
433
+ state_dict = {}
434
+ with safe_open(file_path, framework="pt", device="cpu") as f:
435
+ for k in f.keys():
436
+ state_dict[k] = f.get_tensor(k)
437
+ if torch_dtype is not None:
438
+ state_dict[k] = state_dict[k].to(torch_dtype)
439
+ return state_dict
440
+
441
+
442
+ def load_state_dict_from_bin(file_path, torch_dtype=None):
443
+ state_dict = torch.load(file_path, map_location="cpu")
444
+ if torch_dtype is not None:
445
+ for i in state_dict:
446
+ if isinstance(state_dict[i], torch.Tensor):
447
+ state_dict[i] = state_dict[i].to(torch_dtype)
448
+ return state_dict
449
+
450
+
451
+ def search_parameter(param, state_dict):
452
+ for name, param_ in state_dict.items():
453
+ if param.numel() == param_.numel():
454
+ if param.shape == param_.shape:
455
+ if torch.dist(param, param_) < 1e-6:
456
+ return name
457
+ else:
458
+ if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
459
+ return name
460
+ return None
461
+
462
+
463
+ def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
464
+ matched_keys = set()
465
+ with torch.no_grad():
466
+ for name in source_state_dict:
467
+ rename = search_parameter(source_state_dict[name], target_state_dict)
468
+ if rename is not None:
469
+ print(f'"{name}": "{rename}",')
470
+ matched_keys.add(rename)
471
+ elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
472
+ length = source_state_dict[name].shape[0] // 3
473
+ rename = []
474
+ for i in range(3):
475
+ rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
476
+ if None not in rename:
477
+ print(f'"{name}": {rename},')
478
+ for rename_ in rename:
479
+ matched_keys.add(rename_)
480
+ for name in target_state_dict:
481
+ if name not in matched_keys:
482
+ print("Cannot find", name, target_state_dict[name].shape)
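ModelManager identifies what a checkpoint is purely from tell-tale parameter names in its state dict (the is_* probes above), so loading is just a matter of passing file paths. A minimal sketch, with hypothetical local paths, of loading a Stable Diffusion checkpoint together with an AnimateDiff motion module and then retrieving components:

# Hedged sketch; both checkpoint paths are hypothetical placeholders.
import torch
from diffsynth.models import ModelManager

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models([
    "models/stable_diffusion/sd_base.safetensors",   # recognised by is_stable_diffusion()
    "models/AnimateDiff/motion_module.ckpt",         # recognised by is_animatediff()
])

unet = model_manager.unet                      # components resolved via __getattr__
motion_modules = model_manager.motion_modules
text_encoder = model_manager.text_encoder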
diffsynth/models/attention.py ADDED
@@ -0,0 +1,89 @@
1
+ import torch
2
+ from einops import rearrange
3
+
4
+
5
+ def low_version_attention(query, key, value, attn_bias=None):
6
+ scale = 1 / query.shape[-1] ** 0.5
7
+ query = query * scale
8
+ attn = torch.matmul(query, key.transpose(-2, -1))
9
+ if attn_bias is not None:
10
+ attn = attn + attn_bias
11
+ attn = attn.softmax(-1)
12
+ return attn @ value
13
+
14
+
15
+ class Attention(torch.nn.Module):
16
+
17
+ def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
18
+ super().__init__()
19
+ dim_inner = head_dim * num_heads
20
+ kv_dim = kv_dim if kv_dim is not None else q_dim
21
+ self.num_heads = num_heads
22
+ self.head_dim = head_dim
23
+
24
+ self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
25
+ self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
26
+ self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
27
+ self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
28
+
29
+ def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
30
+ batch_size = q.shape[0]
31
+ ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
32
+ ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
33
+ ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
34
+ hidden_states = hidden_states + scale * ip_hidden_states
35
+ return hidden_states
36
+
37
+ def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
38
+ if encoder_hidden_states is None:
39
+ encoder_hidden_states = hidden_states
40
+
41
+ batch_size = encoder_hidden_states.shape[0]
42
+
43
+ q = self.to_q(hidden_states)
44
+ k = self.to_k(encoder_hidden_states)
45
+ v = self.to_v(encoder_hidden_states)
46
+
47
+ q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
48
+ k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
49
+ v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
50
+
51
+ if qkv_preprocessor is not None:
52
+ q, k, v = qkv_preprocessor(q, k, v)
53
+
54
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
55
+ if ipadapter_kwargs is not None:
56
+ hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
57
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
58
+ hidden_states = hidden_states.to(q.dtype)
59
+
60
+ hidden_states = self.to_out(hidden_states)
61
+
62
+ return hidden_states
63
+
64
+ def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
65
+ if encoder_hidden_states is None:
66
+ encoder_hidden_states = hidden_states
67
+
68
+ q = self.to_q(hidden_states)
69
+ k = self.to_k(encoder_hidden_states)
70
+ v = self.to_v(encoder_hidden_states)
71
+
72
+ q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
73
+ k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
74
+ v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
75
+
76
+ if attn_mask is not None:
77
+ hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
78
+ else:
79
+ import xformers.ops as xops
80
+ hidden_states = xops.memory_efficient_attention(q, k, v)
81
+ hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
82
+
83
+ hidden_states = hidden_states.to(q.dtype)
84
+ hidden_states = self.to_out(hidden_states)
85
+
86
+ return hidden_states
87
+
88
+ def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
89
+ return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
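The Attention module above routes everything through `torch_forward`, which relies on `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.x). A minimal sketch of self- and cross-attention with illustrative shapes:

    import torch
    from diffsynth.models.attention import Attention

    # Self-attention: queries, keys and values all derive from hidden_states.
    self_attn = Attention(q_dim=320, num_heads=8, head_dim=40, bias_out=True)
    x = torch.randn(2, 64, 320)                     # (batch, tokens, q_dim)
    print(self_attn(x).shape)                       # torch.Size([2, 64, 320])

    # Cross-attention: keys/values come from a text embedding of a different width.
    cross_attn = Attention(q_dim=320, num_heads=8, head_dim=40, kv_dim=768,
                           bias_q=True, bias_kv=True, bias_out=True)
    text_emb = torch.randn(2, 77, 768)              # (batch, tokens, kv_dim)
    print(cross_attn(x, encoder_hidden_states=text_emb).shape)  # torch.Size([2, 64, 320])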
diffsynth/models/hunyuan_dit.py ADDED
@@ -0,0 +1,451 @@
1
+ from .attention import Attention
2
+ from .tiler import TileWorker
3
+ from einops import repeat, rearrange
4
+ import math
5
+ import torch
6
+
7
+
8
+ class HunyuanDiTRotaryEmbedding(torch.nn.Module):
9
+
10
+ def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
11
+ super().__init__()
12
+ self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
13
+ self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
14
+ self.rotary_emb_on_k = rotary_emb_on_k
15
+ self.k_cache, self.v_cache = [], []
16
+
17
+ def reshape_for_broadcast(self, freqs_cis, x):
18
+ ndim = x.ndim
19
+ shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
20
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
21
+
22
+ def rotate_half(self, x):
23
+ x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
24
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
25
+
26
+ def apply_rotary_emb(self, xq, xk, freqs_cis):
27
+ xk_out = None
28
+ cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
29
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
30
+ xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
31
+ if xk is not None:
32
+ xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
33
+ return xq_out, xk_out
34
+
35
+ def forward(self, q, k, v, freqs_cis_img, to_cache=False):
36
+ # norm
37
+ q = self.q_norm(q)
38
+ k = self.k_norm(k)
39
+
40
+ # RoPE
41
+ if self.rotary_emb_on_k:
42
+ q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
43
+ else:
44
+ q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)
45
+
46
+ if to_cache:
47
+ self.k_cache.append(k)
48
+ self.v_cache.append(v)
49
+ elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
50
+ k = torch.concat([k] + self.k_cache, dim=2)
51
+ v = torch.concat([v] + self.v_cache, dim=2)
52
+ self.k_cache, self.v_cache = [], []
53
+ return q, k, v
54
+
55
+
56
+ class FP32_Layernorm(torch.nn.LayerNorm):
57
+ def forward(self, inputs):
58
+ origin_dtype = inputs.dtype
59
+ return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)
60
+
61
+
62
+ class FP32_SiLU(torch.nn.SiLU):
63
+ def forward(self, inputs):
64
+ origin_dtype = inputs.dtype
65
+ return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)
66
+
67
+
68
+ class HunyuanDiTFinalLayer(torch.nn.Module):
69
+ def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
70
+ super().__init__()
71
+ self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
72
+ self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
73
+ self.adaLN_modulation = torch.nn.Sequential(
74
+ FP32_SiLU(),
75
+ torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
76
+ )
77
+
78
+ def modulate(self, x, shift, scale):
79
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
80
+
81
+ def forward(self, hidden_states, condition_emb):
82
+ shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
83
+ hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
84
+ hidden_states = self.linear(hidden_states)
85
+ return hidden_states
86
+
87
+
88
+ class HunyuanDiTBlock(torch.nn.Module):
89
+
90
+ def __init__(
91
+ self,
92
+ hidden_dim=1408,
93
+ condition_dim=1408,
94
+ num_heads=16,
95
+ mlp_ratio=4.3637,
96
+ text_dim=1024,
97
+ skip_connection=False
98
+ ):
99
+ super().__init__()
100
+ self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
101
+ self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
102
+ self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
103
+ self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
104
+ self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
105
+ self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
106
+ self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
107
+ self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
108
+ self.mlp = torch.nn.Sequential(
109
+ torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
110
+ torch.nn.GELU(approximate="tanh"),
111
+ torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
112
+ )
113
+ if skip_connection:
114
+ self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
115
+ self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
116
+ else:
117
+ self.skip_norm, self.skip_linear = None, None
118
+
119
+ def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
120
+ # Long Skip Connection
121
+ if self.skip_norm is not None and self.skip_linear is not None:
122
+ hidden_states = torch.cat([hidden_states, residual], dim=-1)
123
+ hidden_states = self.skip_norm(hidden_states)
124
+ hidden_states = self.skip_linear(hidden_states)
125
+
126
+ # Self-Attention
127
+ shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
128
+ attn_input = self.norm1(hidden_states) + shift_msa
129
+ hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))
130
+
131
+ # Cross-Attention
132
+ attn_input = self.norm3(hidden_states)
133
+ hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))
134
+
135
+ # FFN Layer
136
+ mlp_input = self.norm2(hidden_states)
137
+ hidden_states = hidden_states + self.mlp(mlp_input)
138
+ return hidden_states
139
+
140
+
141
+ class AttentionPool(torch.nn.Module):
142
+ def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
143
+ super().__init__()
144
+ self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
145
+ self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
146
+ self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
147
+ self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
148
+ self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
149
+ self.num_heads = num_heads
150
+
151
+ def forward(self, x):
152
+ x = x.permute(1, 0, 2) # NLC -> LNC
153
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
154
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
155
+ x, _ = torch.nn.functional.multi_head_attention_forward(
156
+ query=x[:1], key=x, value=x,
157
+ embed_dim_to_check=x.shape[-1],
158
+ num_heads=self.num_heads,
159
+ q_proj_weight=self.q_proj.weight,
160
+ k_proj_weight=self.k_proj.weight,
161
+ v_proj_weight=self.v_proj.weight,
162
+ in_proj_weight=None,
163
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
164
+ bias_k=None,
165
+ bias_v=None,
166
+ add_zero_attn=False,
167
+ dropout_p=0,
168
+ out_proj_weight=self.c_proj.weight,
169
+ out_proj_bias=self.c_proj.bias,
170
+ use_separate_proj_weight=True,
171
+ training=self.training,
172
+ need_weights=False
173
+ )
174
+ return x.squeeze(0)
175
+
176
+
177
+ class PatchEmbed(torch.nn.Module):
178
+ def __init__(
179
+ self,
180
+ patch_size=(2, 2),
181
+ in_chans=4,
182
+ embed_dim=1408,
183
+ bias=True,
184
+ ):
185
+ super().__init__()
186
+ self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
187
+
188
+ def forward(self, x):
189
+ x = self.proj(x)
190
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
191
+ return x
192
+
193
+
194
+ def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
195
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
196
+ if not repeat_only:
197
+ half = dim // 2
198
+ freqs = torch.exp(
199
+ -math.log(max_period)
200
+ * torch.arange(start=0, end=half, dtype=torch.float32)
201
+ / half
202
+ ).to(device=t.device) # size: [dim/2], an exponentially decaying curve
203
+ args = t[:, None].float() * freqs[None]
204
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
205
+ if dim % 2:
206
+ embedding = torch.cat(
207
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
208
+ )
209
+ else:
210
+ embedding = repeat(t, "b -> b d", d=dim)
211
+ return embedding
212
+
213
+
214
+ class TimestepEmbedder(torch.nn.Module):
215
+ def __init__(self, hidden_size=1408, frequency_embedding_size=256):
216
+ super().__init__()
217
+ self.mlp = torch.nn.Sequential(
218
+ torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
219
+ torch.nn.SiLU(),
220
+ torch.nn.Linear(hidden_size, hidden_size, bias=True),
221
+ )
222
+ self.frequency_embedding_size = frequency_embedding_size
223
+
224
+ def forward(self, t):
225
+ t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
226
+ t_emb = self.mlp(t_freq)
227
+ return t_emb
228
+
229
+
230
+ class HunyuanDiT(torch.nn.Module):
231
+ def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
232
+ super().__init__()
233
+
234
+ # Embedders
235
+ self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
236
+ self.t5_embedder = torch.nn.Sequential(
237
+ torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
238
+ FP32_SiLU(),
239
+ torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
240
+ )
241
+ self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
242
+ self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
243
+ self.patch_embedder = PatchEmbed(in_chans=in_channels)
244
+ self.timestep_embedder = TimestepEmbedder()
245
+ self.extra_embedder = torch.nn.Sequential(
246
+ torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
247
+ FP32_SiLU(),
248
+ torch.nn.Linear(hidden_dim * 4, hidden_dim),
249
+ )
250
+
251
+ # Transformer blocks
252
+ self.num_layers_down = num_layers_down
253
+ self.num_layers_up = num_layers_up
254
+ self.blocks = torch.nn.ModuleList(
255
+ [HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
256
+ [HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
257
+ )
258
+
259
+ # Output layers
260
+ self.final_layer = HunyuanDiTFinalLayer()
261
+ self.out_channels = out_channels
262
+
263
+ def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
264
+ text_emb_mask = text_emb_mask.bool()
265
+ text_emb_mask_t5 = text_emb_mask_t5.bool()
266
+ text_emb_t5 = self.t5_embedder(text_emb_t5)
267
+ text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
268
+ text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
269
+ text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
270
+ return text_emb
271
+
272
+ def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
273
+ # Text embedding
274
+ pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)
275
+
276
+ # Timestep embedding
277
+ timestep_emb = self.timestep_embedder(timestep)
278
+
279
+ # Size embedding
280
+ size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
281
+ size_emb = size_emb.view(-1, 6 * 256)
282
+
283
+ # Style embedding
284
+ style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)
285
+
286
+ # Concatenate all extra vectors
287
+ extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
288
+ condition_emb = timestep_emb + self.extra_embedder(extra_emb)
289
+
290
+ return condition_emb
291
+
292
+ def unpatchify(self, x, h, w):
293
+ return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)
294
+
295
+ def build_mask(self, data, is_bound):
296
+ _, _, H, W = data.shape
297
+ h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
298
+ w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
299
+ border_width = (H + W) // 4
300
+ pad = torch.ones_like(h) * border_width
301
+ mask = torch.stack([
302
+ pad if is_bound[0] else h + 1,
303
+ pad if is_bound[1] else H - h,
304
+ pad if is_bound[2] else w + 1,
305
+ pad if is_bound[3] else W - w
306
+ ]).min(dim=0).values
307
+ mask = mask.clip(1, border_width)
308
+ mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
309
+ mask = rearrange(mask, "H W -> 1 H W")
310
+ return mask
311
+
312
+ def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
313
+ B, C, H, W = hidden_states.shape
314
+
315
+ weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
316
+ values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)
317
+
318
+ # Split tasks
319
+ tasks = []
320
+ for h in range(0, H, tile_stride):
321
+ for w in range(0, W, tile_stride):
322
+ if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
323
+ continue
324
+ h_, w_ = h + tile_size, w + tile_size
325
+ if h_ > H: h, h_ = H - tile_size, H
326
+ if w_ > W: w, w_ = W - tile_size, W
327
+ tasks.append((h, h_, w, w_))
328
+
329
+ # Run
330
+ for hl, hr, wl, wr in tasks:
331
+ hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
332
+ hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
333
+ if residual is not None:
334
+ residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
335
+ residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
336
+ else:
337
+ residual_batch = None
338
+
339
+ # Forward
340
+ hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
341
+ hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)
342
+
343
+ mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
344
+ values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
345
+ weight[:, :, hl:hr, wl:wr] += mask
346
+ values /= weight
347
+ return values
348
+
349
+ def forward(
350
+ self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
351
+ tiled=False, tile_size=64, tile_stride=32,
352
+ to_cache=False,
353
+ use_gradient_checkpointing=False,
354
+ ):
355
+ # Embeddings
356
+ text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
357
+ condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])
358
+
359
+ # Input
360
+ height, width = hidden_states.shape[-2], hidden_states.shape[-1]
361
+ hidden_states = self.patch_embedder(hidden_states)
362
+
363
+ # Blocks
364
+ def create_custom_forward(module):
365
+ def custom_forward(*inputs):
366
+ return module(*inputs)
367
+ return custom_forward
368
+ if tiled:
369
+ hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
370
+ residuals = []
371
+ for block_id, block in enumerate(self.blocks):
372
+ residual = residuals.pop() if block_id >= self.num_layers_down else None
373
+ hidden_states = self.tiled_block_forward(
374
+ block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
375
+ torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
376
+ tile_size=tile_size, tile_stride=tile_stride
377
+ )
378
+ if block_id < self.num_layers_down - 2:
379
+ residuals.append(hidden_states)
380
+ hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
381
+ else:
382
+ residuals = []
383
+ for block_id, block in enumerate(self.blocks):
384
+ residual = residuals.pop() if block_id >= self.num_layers_down else None
385
+ if self.training and use_gradient_checkpointing:
386
+ hidden_states = torch.utils.checkpoint.checkpoint(
387
+ create_custom_forward(block),
388
+ hidden_states, condition_emb, text_emb, freq_cis_img, residual,
389
+ use_reentrant=False,
390
+ )
391
+ else:
392
+ hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
393
+ if block_id < self.num_layers_down - 2:
394
+ residuals.append(hidden_states)
395
+
396
+ # Output
397
+ hidden_states = self.final_layer(hidden_states, condition_emb)
398
+ hidden_states = self.unpatchify(hidden_states, height//2, width//2)
399
+ hidden_states, _ = hidden_states.chunk(2, dim=1)
400
+ return hidden_states
401
+
402
+ def state_dict_converter(self):
403
+ return HunyuanDiTStateDictConverter()
404
+
405
+
406
+
407
+ class HunyuanDiTStateDictConverter():
408
+ def __init__(self):
409
+ pass
410
+
411
+ def from_diffusers(self, state_dict):
412
+ state_dict_ = {}
413
+ for name, param in state_dict.items():
414
+ name_ = name
415
+ name_ = name_.replace(".default_modulation.", ".modulation.")
416
+ name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
417
+ name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
418
+ name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
419
+ name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
420
+ name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
421
+ name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
422
+ name_ = name_.replace(".q_proj.", ".to_q.")
423
+ name_ = name_.replace(".out_proj.", ".to_out.")
424
+ name_ = name_.replace("text_embedding_padding", "text_emb_padding")
425
+ name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
426
+ name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
427
+ name_ = name_.replace("pooler.", "t5_pooler.")
428
+ name_ = name_.replace("x_embedder.", "patch_embedder.")
429
+ name_ = name_.replace("t_embedder.", "timestep_embedder.")
430
+ name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
431
+ name_ = name_.replace("style_embedder.weight", "style_embedder")
432
+ if ".kv_proj." in name_:
433
+ param_k = param[:param.shape[0]//2]
434
+ param_v = param[param.shape[0]//2:]
435
+ state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
436
+ state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
437
+ elif ".Wqkv." in name_:
438
+ param_q = param[:param.shape[0]//3]
439
+ param_k = param[param.shape[0]//3:param.shape[0]//3*2]
440
+ param_v = param[param.shape[0]//3*2:]
441
+ state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
442
+ state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
443
+ state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
444
+ elif "style_embedder" in name_:
445
+ state_dict_[name_] = param.squeeze()
446
+ else:
447
+ state_dict_[name_] = param
448
+ return state_dict_
449
+
450
+ def from_civitai(self, state_dict):
451
+ return self.from_diffusers(state_dict)
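The sinusoidal `timestep_embedding` and `TimestepEmbedder` above are self-contained, so their output shapes are easy to check in isolation (a minimal sketch; values are illustrative only):

    import torch
    from diffsynth.models.hunyuan_dit import timestep_embedding, TimestepEmbedder

    t = torch.tensor([0, 250, 999])                 # diffusion timesteps
    emb = timestep_embedding(t, dim=256)            # concatenated cos/sin features
    print(emb.shape)                                # torch.Size([3, 256])

    embedder = TimestepEmbedder(hidden_size=1408, frequency_embedding_size=256)
    print(embedder(t).shape)                        # torch.Size([3, 1408])

The same 256-dimensional embedding is reused for the six size values in `prepare_extra_emb`, which is why `extra_embedder` expects an input of 256 * 6 + 1024 + hidden_dim features.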
diffsynth/models/hunyuan_dit_text_encoder.py ADDED
@@ -0,0 +1,161 @@
1
+ from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
2
+ import torch
3
+
4
+
5
+
6
+ class HunyuanDiTCLIPTextEncoder(BertModel):
7
+ def __init__(self):
8
+ config = BertConfig(
9
+ _name_or_path = "",
10
+ architectures = ["BertModel"],
11
+ attention_probs_dropout_prob = 0.1,
12
+ bos_token_id = 0,
13
+ classifier_dropout = None,
14
+ directionality = "bidi",
15
+ eos_token_id = 2,
16
+ hidden_act = "gelu",
17
+ hidden_dropout_prob = 0.1,
18
+ hidden_size = 1024,
19
+ initializer_range = 0.02,
20
+ intermediate_size = 4096,
21
+ layer_norm_eps = 1e-12,
22
+ max_position_embeddings = 512,
23
+ model_type = "bert",
24
+ num_attention_heads = 16,
25
+ num_hidden_layers = 24,
26
+ output_past = True,
27
+ pad_token_id = 0,
28
+ pooler_fc_size = 768,
29
+ pooler_num_attention_heads = 12,
30
+ pooler_num_fc_layers = 3,
31
+ pooler_size_per_head = 128,
32
+ pooler_type = "first_token_transform",
33
+ position_embedding_type = "absolute",
34
+ torch_dtype = "float32",
35
+ transformers_version = "4.37.2",
36
+ type_vocab_size = 2,
37
+ use_cache = True,
38
+ vocab_size = 47020
39
+ )
40
+ super().__init__(config, add_pooling_layer=False)
41
+ self.eval()
42
+
43
+ def forward(self, input_ids, attention_mask, clip_skip=1):
44
+ input_shape = input_ids.size()
45
+
46
+ batch_size, seq_length = input_shape
47
+ device = input_ids.device
48
+
49
+ past_key_values_length = 0
50
+
51
+ if attention_mask is None:
52
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
53
+
54
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
55
+
56
+ embedding_output = self.embeddings(
57
+ input_ids=input_ids,
58
+ position_ids=None,
59
+ token_type_ids=None,
60
+ inputs_embeds=None,
61
+ past_key_values_length=0,
62
+ )
63
+ encoder_outputs = self.encoder(
64
+ embedding_output,
65
+ attention_mask=extended_attention_mask,
66
+ head_mask=None,
67
+ encoder_hidden_states=None,
68
+ encoder_attention_mask=None,
69
+ past_key_values=None,
70
+ use_cache=False,
71
+ output_attentions=False,
72
+ output_hidden_states=True,
73
+ return_dict=True,
74
+ )
75
+ all_hidden_states = encoder_outputs.hidden_states
76
+ prompt_emb = all_hidden_states[-clip_skip]
77
+ if clip_skip > 1:
78
+ mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
79
+ prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
80
+ return prompt_emb
81
+
82
+ def state_dict_converter(self):
83
+ return HunyuanDiTCLIPTextEncoderStateDictConverter()
84
+
85
+
86
+
87
+ class HunyuanDiTT5TextEncoder(T5EncoderModel):
88
+ def __init__(self):
89
+ config = T5Config(
90
+ _name_or_path = "../HunyuanDiT/t2i/mt5",
91
+ architectures = ["MT5ForConditionalGeneration"],
92
+ classifier_dropout = 0.0,
93
+ d_ff = 5120,
94
+ d_kv = 64,
95
+ d_model = 2048,
96
+ decoder_start_token_id = 0,
97
+ dense_act_fn = "gelu_new",
98
+ dropout_rate = 0.1,
99
+ eos_token_id = 1,
100
+ feed_forward_proj = "gated-gelu",
101
+ initializer_factor = 1.0,
102
+ is_encoder_decoder = True,
103
+ is_gated_act = True,
104
+ layer_norm_epsilon = 1e-06,
105
+ model_type = "t5",
106
+ num_decoder_layers = 24,
107
+ num_heads = 32,
108
+ num_layers = 24,
109
+ output_past = True,
110
+ pad_token_id = 0,
111
+ relative_attention_max_distance = 128,
112
+ relative_attention_num_buckets = 32,
113
+ tie_word_embeddings = False,
114
+ tokenizer_class = "T5Tokenizer",
115
+ transformers_version = "4.37.2",
116
+ use_cache = True,
117
+ vocab_size = 250112
118
+ )
119
+ super().__init__(config)
120
+ self.eval()
121
+
122
+ def forward(self, input_ids, attention_mask, clip_skip=1):
123
+ outputs = super().forward(
124
+ input_ids=input_ids,
125
+ attention_mask=attention_mask,
126
+ output_hidden_states=True,
127
+ )
128
+ prompt_emb = outputs.hidden_states[-clip_skip]
129
+ if clip_skip > 1:
130
+ mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
131
+ prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
132
+ return prompt_emb
133
+
134
+ def state_dict_converter(self):
135
+ return HunyuanDiTT5TextEncoderStateDictConverter()
136
+
137
+
138
+
139
+ class HunyuanDiTCLIPTextEncoderStateDictConverter():
140
+ def __init__(self):
141
+ pass
142
+
143
+ def from_diffusers(self, state_dict):
144
+ state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
145
+ return state_dict_
146
+
147
+ def from_civitai(self, state_dict):
148
+ return self.from_diffusers(state_dict)
149
+
150
+
151
+ class HunyuanDiTT5TextEncoderStateDictConverter():
152
+ def __init__(self):
153
+ pass
154
+
155
+ def from_diffusers(self, state_dict):
156
+ state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
157
+ state_dict_["shared.weight"] = state_dict["shared.weight"]
158
+ return state_dict_
159
+
160
+ def from_civitai(self, state_dict):
161
+ return self.from_diffusers(state_dict)
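Both encoders expose the same `forward(input_ids, attention_mask, clip_skip=1)` interface and return the hidden states consumed by `HunyuanDiT.prepare_text_emb`. A minimal sketch with randomly initialised weights and random token ids (in practice the weights are loaded through the converters above and the ids come from the matching BERT/mT5 tokenizers):

    import torch
    from diffsynth.models.hunyuan_dit_text_encoder import (
        HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
    )

    clip_encoder = HunyuanDiTCLIPTextEncoder()
    t5_encoder = HunyuanDiTT5TextEncoder()

    clip_ids = torch.randint(0, 47020, (1, 77))          # BERT-style ids, length 77
    clip_mask = torch.ones(1, 77, dtype=torch.long)
    t5_ids = torch.randint(0, 250112, (1, 256))          # mT5 ids, length 256
    t5_mask = torch.ones(1, 256, dtype=torch.long)

    with torch.no_grad():
        clip_emb = clip_encoder(clip_ids, clip_mask)     # shape (1, 77, 1024)
        t5_emb = t5_encoder(t5_ids, t5_mask)             # shape (1, 256, 2048)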
diffsynth/models/sd_controlnet.py ADDED
@@ -0,0 +1,587 @@
1
+ import torch
2
+ from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
3
+ from .tiler import TileWorker
4
+
5
+
6
+ class ControlNetConditioningLayer(torch.nn.Module):
7
+ def __init__(self, channels = (3, 16, 32, 96, 256, 320)):
8
+ super().__init__()
9
+ self.blocks = torch.nn.ModuleList([])
10
+ self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
11
+ self.blocks.append(torch.nn.SiLU())
12
+ for i in range(1, len(channels) - 2):
13
+ self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1))
14
+ self.blocks.append(torch.nn.SiLU())
15
+ self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2))
16
+ self.blocks.append(torch.nn.SiLU())
17
+ self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1))
18
+
19
+ def forward(self, conditioning):
20
+ for block in self.blocks:
21
+ conditioning = block(conditioning)
22
+ return conditioning
23
+
24
+
25
+ class SDControlNet(torch.nn.Module):
26
+ def __init__(self, global_pool=False):
27
+ super().__init__()
28
+ self.time_proj = Timesteps(320)
29
+ self.time_embedding = torch.nn.Sequential(
30
+ torch.nn.Linear(320, 1280),
31
+ torch.nn.SiLU(),
32
+ torch.nn.Linear(1280, 1280)
33
+ )
34
+ self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
35
+
36
+ self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
37
+
38
+ self.blocks = torch.nn.ModuleList([
39
+ # CrossAttnDownBlock2D
40
+ ResnetBlock(320, 320, 1280),
41
+ AttentionBlock(8, 40, 320, 1, 768),
42
+ PushBlock(),
43
+ ResnetBlock(320, 320, 1280),
44
+ AttentionBlock(8, 40, 320, 1, 768),
45
+ PushBlock(),
46
+ DownSampler(320),
47
+ PushBlock(),
48
+ # CrossAttnDownBlock2D
49
+ ResnetBlock(320, 640, 1280),
50
+ AttentionBlock(8, 80, 640, 1, 768),
51
+ PushBlock(),
52
+ ResnetBlock(640, 640, 1280),
53
+ AttentionBlock(8, 80, 640, 1, 768),
54
+ PushBlock(),
55
+ DownSampler(640),
56
+ PushBlock(),
57
+ # CrossAttnDownBlock2D
58
+ ResnetBlock(640, 1280, 1280),
59
+ AttentionBlock(8, 160, 1280, 1, 768),
60
+ PushBlock(),
61
+ ResnetBlock(1280, 1280, 1280),
62
+ AttentionBlock(8, 160, 1280, 1, 768),
63
+ PushBlock(),
64
+ DownSampler(1280),
65
+ PushBlock(),
66
+ # DownBlock2D
67
+ ResnetBlock(1280, 1280, 1280),
68
+ PushBlock(),
69
+ ResnetBlock(1280, 1280, 1280),
70
+ PushBlock(),
71
+ # UNetMidBlock2DCrossAttn
72
+ ResnetBlock(1280, 1280, 1280),
73
+ AttentionBlock(8, 160, 1280, 1, 768),
74
+ ResnetBlock(1280, 1280, 1280),
75
+ PushBlock()
76
+ ])
77
+
78
+ self.controlnet_blocks = torch.nn.ModuleList([
79
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
80
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
81
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
82
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
83
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
84
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
85
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
86
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
87
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
88
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
89
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
90
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
91
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
92
+ ])
93
+
94
+ self.global_pool = global_pool
95
+
96
+ def forward(
97
+ self,
98
+ sample, timestep, encoder_hidden_states, conditioning,
99
+ tiled=False, tile_size=64, tile_stride=32,
100
+ ):
101
+ # 1. time
102
+ time_emb = self.time_proj(timestep[None]).to(sample.dtype)
103
+ time_emb = self.time_embedding(time_emb)
104
+ time_emb = time_emb.repeat(sample.shape[0], 1)
105
+
106
+ # 2. pre-process
107
+ height, width = sample.shape[2], sample.shape[3]
108
+ hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
109
+ text_emb = encoder_hidden_states
110
+ res_stack = [hidden_states]
111
+
112
+ # 3. blocks
113
+ for i, block in enumerate(self.blocks):
114
+ if tiled and not isinstance(block, PushBlock):
115
+ _, _, inter_height, _ = hidden_states.shape
116
+ resize_scale = inter_height / height
117
+ hidden_states = TileWorker().tiled_forward(
118
+ lambda x: block(x, time_emb, text_emb, res_stack)[0],
119
+ hidden_states,
120
+ int(tile_size * resize_scale),
121
+ int(tile_stride * resize_scale),
122
+ tile_device=hidden_states.device,
123
+ tile_dtype=hidden_states.dtype
124
+ )
125
+ else:
126
+ hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
127
+
128
+ # 4. ControlNet blocks
129
+ controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
130
+
131
+ # pool
132
+ if self.global_pool:
133
+ controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
134
+
135
+ return controlnet_res_stack
136
+
137
+ def state_dict_converter(self):
138
+ return SDControlNetStateDictConverter()
139
+
140
+
141
+ class SDControlNetStateDictConverter:
142
+ def __init__(self):
143
+ pass
144
+
145
+ def from_diffusers(self, state_dict):
146
+ # architecture
147
+ block_types = [
148
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
149
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
150
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
151
+ 'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
152
+ 'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
153
+ 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
154
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
155
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
156
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
157
+ ]
158
+
159
+ # controlnet_rename_dict
160
+ controlnet_rename_dict = {
161
+ "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
162
+ "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
163
+ "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
164
+ "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
165
+ "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
166
+ "controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
167
+ "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
168
+ "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
169
+ "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
170
+ "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
171
+ "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
172
+ "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
173
+ "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
174
+ "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
175
+ "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
176
+ "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
177
+ }
178
+
179
+ # Rename each parameter
180
+ name_list = sorted([name for name in state_dict])
181
+ rename_dict = {}
182
+ block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
183
+ last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
184
+ for name in name_list:
185
+ names = name.split(".")
186
+ if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
187
+ pass
188
+ elif name in controlnet_rename_dict:
189
+ names = controlnet_rename_dict[name].split(".")
190
+ elif names[0] == "controlnet_down_blocks":
191
+ names[0] = "controlnet_blocks"
192
+ elif names[0] == "controlnet_mid_block":
193
+ names = ["controlnet_blocks", "12", names[-1]]
194
+ elif names[0] in ["time_embedding", "add_embedding"]:
195
+ if names[0] == "add_embedding":
196
+ names[0] = "add_time_embedding"
197
+ names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
198
+ elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
199
+ if names[0] == "mid_block":
200
+ names.insert(1, "0")
201
+ block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
202
+ block_type_with_id = ".".join(names[:4])
203
+ if block_type_with_id != last_block_type_with_id[block_type]:
204
+ block_id[block_type] += 1
205
+ last_block_type_with_id[block_type] = block_type_with_id
206
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
207
+ block_id[block_type] += 1
208
+ block_type_with_id = ".".join(names[:4])
209
+ names = ["blocks", str(block_id[block_type])] + names[4:]
210
+ if "ff" in names:
211
+ ff_index = names.index("ff")
212
+ component = ".".join(names[ff_index:ff_index+3])
213
+ component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
214
+ names = names[:ff_index] + [component] + names[ff_index+3:]
215
+ if "to_out" in names:
216
+ names.pop(names.index("to_out") + 1)
217
+ else:
218
+ raise ValueError(f"Unknown parameters: {name}")
219
+ rename_dict[name] = ".".join(names)
220
+
221
+ # Convert state_dict
222
+ state_dict_ = {}
223
+ for name, param in state_dict.items():
224
+ if ".proj_in." in name or ".proj_out." in name:
225
+ param = param.squeeze()
226
+ if rename_dict[name] in [
227
+ "controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
228
+ "controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
229
+ ]:
230
+ continue
231
+ state_dict_[rename_dict[name]] = param
232
+ return state_dict_
233
+
234
+ def from_civitai(self, state_dict):
235
+ if "mid_block.resnets.1.time_emb_proj.weight" in state_dict:
236
+ # For controlnets in diffusers format
237
+ return self.from_diffusers(state_dict)
238
+ rename_dict = {
239
+ "control_model.time_embed.0.weight": "time_embedding.0.weight",
240
+ "control_model.time_embed.0.bias": "time_embedding.0.bias",
241
+ "control_model.time_embed.2.weight": "time_embedding.2.weight",
242
+ "control_model.time_embed.2.bias": "time_embedding.2.bias",
243
+ "control_model.input_blocks.0.0.weight": "conv_in.weight",
244
+ "control_model.input_blocks.0.0.bias": "conv_in.bias",
245
+ "control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
246
+ "control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
247
+ "control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
248
+ "control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
249
+ "control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
250
+ "control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
251
+ "control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
252
+ "control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
253
+ "control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
254
+ "control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
255
+ "control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight",
256
+ "control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias",
257
+ "control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
258
+ "control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias",
259
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight",
260
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight",
261
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight",
262
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight",
263
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias",
264
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight",
265
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias",
266
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight",
267
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias",
268
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight",
269
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight",
270
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight",
271
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight",
272
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias",
273
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight",
274
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias",
275
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight",
276
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias",
277
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight",
278
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias",
279
+ "control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight",
280
+ "control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias",
281
+ "control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight",
282
+ "control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias",
283
+ "control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight",
284
+ "control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias",
285
+ "control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight",
286
+ "control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias",
287
+ "control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight",
288
+ "control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias",
289
+ "control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight",
290
+ "control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias",
291
+ "control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight",
292
+ "control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias",
293
+ "control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight",
294
+ "control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias",
295
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight",
296
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight",
297
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight",
298
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight",
299
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias",
300
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight",
301
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias",
302
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight",
303
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias",
304
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight",
305
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight",
306
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight",
307
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight",
308
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias",
309
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight",
310
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias",
311
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight",
312
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias",
313
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight",
314
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias",
315
+ "control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight",
316
+ "control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias",
317
+ "control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight",
318
+ "control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias",
319
+ "control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight",
320
+ "control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias",
321
+ "control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight",
322
+ "control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias",
323
+ "control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight",
324
+ "control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias",
325
+ "control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight",
326
+ "control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias",
327
+ "control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight",
328
+ "control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias",
329
+ "control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight",
330
+ "control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias",
331
+ "control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight",
332
+ "control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias",
333
+ "control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight",
334
+ "control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias",
335
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight",
336
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight",
337
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight",
338
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight",
339
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias",
340
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight",
341
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias",
342
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight",
343
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias",
344
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight",
345
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight",
346
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight",
347
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight",
348
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias",
349
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight",
350
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias",
351
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight",
352
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias",
353
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight",
354
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias",
355
+ "control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight",
356
+ "control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias",
357
+ "control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight",
358
+ "control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias",
359
+ "control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight",
360
+ "control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias",
361
+ "control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight",
362
+ "control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias",
363
+ "control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight",
364
+ "control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias",
365
+ "control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight",
366
+ "control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias",
367
+ "control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight",
368
+ "control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias",
369
+ "control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight",
370
+ "control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias",
371
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight",
372
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight",
373
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight",
374
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight",
375
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias",
376
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight",
377
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias",
378
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight",
379
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias",
380
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight",
381
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight",
382
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight",
383
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight",
384
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias",
385
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight",
386
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias",
387
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight",
388
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias",
389
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight",
390
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias",
391
+ "control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight",
392
+ "control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias",
393
+ "control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight",
394
+ "control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias",
395
+ "control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight",
396
+ "control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias",
397
+ "control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight",
398
+ "control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias",
399
+ "control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight",
400
+ "control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias",
401
+ "control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight",
402
+ "control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias",
403
+ "control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight",
404
+ "control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias",
405
+ "control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight",
406
+ "control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias",
407
+ "control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight",
408
+ "control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias",
409
+ "control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight",
410
+ "control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias",
411
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight",
412
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight",
413
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight",
414
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight",
415
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias",
416
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight",
417
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias",
418
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight",
419
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias",
420
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight",
421
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight",
422
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight",
423
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight",
424
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias",
425
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight",
426
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias",
427
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight",
428
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias",
429
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight",
430
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias",
431
+ "control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight",
432
+ "control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias",
433
+ "control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight",
434
+ "control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias",
435
+ "control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight",
436
+ "control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias",
437
+ "control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight",
438
+ "control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias",
439
+ "control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight",
440
+ "control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias",
441
+ "control_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight",
442
+ "control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias",
443
+ "control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight",
444
+ "control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias",
445
+ "control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight",
446
+ "control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias",
447
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight",
448
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight",
449
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight",
450
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight",
451
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias",
452
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight",
453
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias",
454
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight",
455
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias",
456
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight",
457
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight",
458
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight",
459
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight",
460
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias",
461
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight",
462
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias",
463
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight",
464
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias",
465
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight",
466
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias",
467
+ "control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight",
468
+ "control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias",
469
+ "control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight",
470
+ "control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias",
471
+ "control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight",
472
+ "control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias",
473
+ "control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight",
474
+ "control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias",
475
+ "control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
476
+ "control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
477
+ "control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight",
478
+ "control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias",
479
+ "control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight",
480
+ "control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias",
481
+ "control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight",
482
+ "control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias",
483
+ "control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight",
484
+ "control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias",
485
+ "control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight",
486
+ "control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias",
487
+ "control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight",
488
+ "control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias",
489
+ "control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight",
490
+ "control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias",
491
+ "control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight",
492
+ "control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias",
493
+ "control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight",
494
+ "control_model.zero_convs.1.0.bias": "controlnet_blocks.0.bias",
495
+ "control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight",
496
+ "control_model.zero_convs.2.0.bias": "controlnet_blocks.0.bias",
497
+ "control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight",
498
+ "control_model.zero_convs.3.0.bias": "controlnet_blocks.0.bias",
499
+ "control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight",
500
+ "control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias",
501
+ "control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight",
502
+ "control_model.zero_convs.5.0.bias": "controlnet_blocks.4.bias",
503
+ "control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight",
504
+ "control_model.zero_convs.6.0.bias": "controlnet_blocks.4.bias",
505
+ "control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight",
506
+ "control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias",
507
+ "control_model.zero_convs.8.0.weight": "controlnet_blocks.8.weight",
508
+ "control_model.zero_convs.8.0.bias": "controlnet_blocks.7.bias",
509
+ "control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight",
510
+ "control_model.zero_convs.9.0.bias": "controlnet_blocks.7.bias",
511
+ "control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight",
512
+ "control_model.zero_convs.10.0.bias": "controlnet_blocks.7.bias",
513
+ "control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight",
514
+ "control_model.zero_convs.11.0.bias": "controlnet_blocks.7.bias",
515
+ "control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight",
516
+ "control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias",
517
+ "control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight",
518
+ "control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias",
519
+ "control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight",
520
+ "control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias",
521
+ "control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight",
522
+ "control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias",
523
+ "control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight",
524
+ "control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias",
525
+ "control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight",
526
+ "control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias",
527
+ "control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight",
528
+ "control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias",
529
+ "control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight",
530
+ "control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias",
531
+ "control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight",
532
+ "control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias",
533
+ "control_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight",
534
+ "control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias",
535
+ "control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight",
536
+ "control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias",
537
+ "control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight",
538
+ "control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias",
539
+ "control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight",
540
+ "control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias",
541
+ "control_model.middle_block.1.norm.weight": "blocks.29.norm.weight",
542
+ "control_model.middle_block.1.norm.bias": "blocks.29.norm.bias",
543
+ "control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
544
+ "control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
545
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight",
546
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight",
547
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight",
548
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight",
549
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias",
550
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight",
551
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias",
552
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight",
553
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias",
554
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight",
555
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight",
556
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight",
557
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight",
558
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias",
559
+ "control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight",
560
+ "control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias",
561
+ "control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight",
562
+ "control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias",
563
+ "control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight",
564
+ "control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias",
565
+ "control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
566
+ "control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias",
567
+ "control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight",
568
+ "control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias",
569
+ "control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight",
570
+ "control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias",
571
+ "control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
572
+ "control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
573
+ "control_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight",
574
+ "control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias",
575
+ "control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight",
576
+ "control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias",
577
+ "control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight",
578
+ "control_model.middle_block_out.0.bias": "controlnet_blocks.7.bias",
579
+ }
580
+ state_dict_ = {}
581
+ for name in state_dict:
582
+ if name in rename_dict:
583
+ param = state_dict[name]
584
+ if ".proj_in." in name or ".proj_out." in name:
585
+ param = param.squeeze()
586
+ state_dict_[rename_dict[name]] = param
587
+ return state_dict_
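Note: the converter above is a pure key-renaming pass with a single shape fix-up: proj_in / proj_out parameters, which civitai/LDM-style checkpoints typically store as 1x1 convolutions, are squeezed into Linear-shaped matrices. A minimal, self-contained sketch of that pass, using a toy rename table and random tensors rather than a real checkpoint:

import torch

# Toy rename table; the real table above maps every "control_model.*" key
# onto the internal "blocks.*" / "controlnet_blocks.*" layout.
rename_dict = {
    "control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
}
state_dict = {
    # proj_out is stored as a 1x1 conv weight of shape (out, in, 1, 1)
    "control_model.middle_block.1.proj_out.weight": torch.randn(1280, 1280, 1, 1),
    "control_model.some_unmapped.key": torch.randn(3),  # silently skipped, as in from_civitai
}
state_dict_ = {}
for name in state_dict:
    if name in rename_dict:
        param = state_dict[name]
        if ".proj_in." in name or ".proj_out." in name:
            param = param.squeeze()  # (1280, 1280, 1, 1) -> (1280, 1280), a Linear weight
        state_dict_[rename_dict[name]] = param
print({k: tuple(v.shape) for k, v in state_dict_.items()})
# {'blocks.29.proj_out.weight': (1280, 1280)}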
diffsynth/models/sd_ipadapter.py ADDED
@@ -0,0 +1,56 @@
1
+ from .svd_image_encoder import SVDImageEncoder
2
+ from .sdxl_ipadapter import IpAdapterImageProjModel, IpAdapterModule, SDXLIpAdapterStateDictConverter
3
+ from transformers import CLIPImageProcessor
4
+ import torch
5
+
6
+
7
+ class IpAdapterCLIPImageEmbedder(SVDImageEncoder):
8
+ def __init__(self):
9
+ super().__init__()
10
+ self.image_processor = CLIPImageProcessor()
11
+
12
+ def forward(self, image):
13
+ pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
14
+ pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
15
+ return super().forward(pixel_values)
16
+
17
+
18
+ class SDIpAdapter(torch.nn.Module):
19
+ def __init__(self):
20
+ super().__init__()
21
+ shape_list = [(768, 320)] * 2 + [(768, 640)] * 2 + [(768, 1280)] * 5 + [(768, 640)] * 3 + [(768, 320)] * 3 + [(768, 1280)] * 1
22
+ self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
23
+ self.image_proj = IpAdapterImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
24
+ self.set_full_adapter()
25
+
26
+ def set_full_adapter(self):
27
+ block_ids = [1, 4, 9, 12, 17, 20, 40, 43, 46, 50, 53, 56, 60, 63, 66, 29]
28
+ self.call_block_id = {(i, 0): j for j, i in enumerate(block_ids)}
29
+
30
+ def set_less_adapter(self):
31
+ # IP-Adapter for SD v1.5 doesn't support this feature.
32
+ self.set_full_adapter()
33
+
34
+ def forward(self, hidden_states, scale=1.0):
35
+ hidden_states = self.image_proj(hidden_states)
36
+ hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
37
+ ip_kv_dict = {}
38
+ for (block_id, transformer_id) in self.call_block_id:
39
+ ipadapter_id = self.call_block_id[(block_id, transformer_id)]
40
+ ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
41
+ if block_id not in ip_kv_dict:
42
+ ip_kv_dict[block_id] = {}
43
+ ip_kv_dict[block_id][transformer_id] = {
44
+ "ip_k": ip_k,
45
+ "ip_v": ip_v,
46
+ "scale": scale
47
+ }
48
+ return ip_kv_dict
49
+
50
+ def state_dict_converter(self):
51
+ return SDIpAdapterStateDictConverter()
52
+
53
+
54
+ class SDIpAdapterStateDictConverter(SDXLIpAdapterStateDictConverter):
55
+ def __init__(self):
56
+ pass
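Each (768, N) entry in shape_list above pairs the 768-dim projected image tokens with the channel width of one UNet attention layer (320 / 640 / 1280), and forward collects the resulting key/value pairs into ip_kv_dict keyed by UNet block id. The real IpAdapterModule and IpAdapterImageProjModel are imported from sdxl_ipadapter.py; the stand-in below is only an illustrative sketch of the idea (project image tokens into one extra key/value pair per layer), not the actual implementation:

import torch

# Hypothetical stand-in for IpAdapterModule, for illustration only.
class ToyIpAdapterModule(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.to_k = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.to_v = torch.nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, image_tokens):
        # One extra key/value pair that a cross-attention layer can append
        # to its text-derived keys/values, weighted by the IP-Adapter scale.
        return self.to_k(image_tokens), self.to_v(image_tokens)

image_tokens = torch.randn(1, 4, 768)   # 4 extra context tokens from the image projection
module = ToyIpAdapterModule(768, 320)   # corresponds to one (768, 320) entry of shape_list
ip_k, ip_v = module(image_tokens)
print(ip_k.shape, ip_v.shape)           # torch.Size([1, 4, 320]) torch.Size([1, 4, 320])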
diffsynth/models/sd_lora.py ADDED
@@ -0,0 +1,60 @@
1
+ import torch
2
+ from .sd_unet import SDUNetStateDictConverter, SDUNet
3
+ from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder
4
+
5
+
6
+ class SDLoRA:
7
+ def __init__(self):
8
+ pass
9
+
10
+ def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
11
+ special_keys = {
12
+ "down.blocks": "down_blocks",
13
+ "up.blocks": "up_blocks",
14
+ "mid.block": "mid_block",
15
+ "proj.in": "proj_in",
16
+ "proj.out": "proj_out",
17
+ "transformer.blocks": "transformer_blocks",
18
+ "to.q": "to_q",
19
+ "to.k": "to_k",
20
+ "to.v": "to_v",
21
+ "to.out": "to_out",
22
+ }
23
+ state_dict_ = {}
24
+ for key in state_dict:
25
+ if ".lora_up" not in key:
26
+ continue
27
+ if not key.startswith(lora_prefix):
28
+ continue
29
+ weight_up = state_dict[key].to(device=device, dtype=torch.float16)
30
+ weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device=device, dtype=torch.float16)
31
+ if len(weight_up.shape) == 4:
32
+ weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
33
+ weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
34
+ lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
35
+ else:
36
+ lora_weight = alpha * torch.mm(weight_up, weight_down)
37
+ target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
38
+ for special_key in special_keys:
39
+ target_name = target_name.replace(special_key, special_keys[special_key])
40
+ state_dict_[target_name] = lora_weight.cpu()
41
+ return state_dict_
42
+
43
+ def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
44
+ state_dict_unet = unet.state_dict()
45
+ state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
46
+ state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
47
+ if len(state_dict_lora) > 0:
48
+ for name in state_dict_lora:
49
+ state_dict_unet[name] += state_dict_lora[name].to(device=device)
50
+ unet.load_state_dict(state_dict_unet)
51
+
52
+ def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"):
53
+ state_dict_text_encoder = text_encoder.state_dict()
54
+ state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
55
+ state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
56
+ if len(state_dict_lora) > 0:
57
+ for name in state_dict_lora:
58
+ state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
59
+ text_encoder.load_state_dict(state_dict_text_encoder)
60
+
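The merge performed by convert_state_dict and add_lora_to_unet is the standard LoRA update: multiply out the low-rank pair, scale by alpha, and add the result onto the base weight (note that, as written above, no alpha / rank normalization is applied). A small self-contained sketch with random tensors:

import torch

# LoRA merge for a Linear layer: W <- W + alpha * (up @ down).
rank, out_dim, in_dim, alpha = 4, 320, 768, 1.0
weight_up = torch.randn(out_dim, rank)      # ".lora_up" matrix
weight_down = torch.randn(rank, in_dim)     # ".lora_down" matrix
base_weight = torch.randn(out_dim, in_dim)  # weight already inside the UNet / text encoder

lora_weight = alpha * torch.mm(weight_up, weight_down)  # (320, 768) dense update
merged = base_weight + lora_weight                      # what load_state_dict receives
print(merged.shape)  # torch.Size([320, 768])

For convolution weights the same product is taken after squeezing the trailing 1x1 dimensions, and the result is unsqueezed back to a 4-D conv weight, exactly as in the len(weight_up.shape) == 4 branch above.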
diffsynth/models/sd_motion.py ADDED
@@ -0,0 +1,198 @@
1
+ from .sd_unet import SDUNet, Attention, GEGLU
2
+ import torch
3
+ from einops import rearrange, repeat
4
+
5
+
6
+ class TemporalTransformerBlock(torch.nn.Module):
7
+
8
+ def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
9
+ super().__init__()
10
+
11
+ # 1. Self-Attn
12
+ self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
13
+ self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
14
+ self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
15
+
16
+ # 2. Cross-Attn
17
+ self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
18
+ self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
19
+ self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
20
+
21
+ # 3. Feed-forward
22
+ self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True)
23
+ self.act_fn = GEGLU(dim, dim * 4)
24
+ self.ff = torch.nn.Linear(dim * 4, dim)
25
+
26
+
27
+ def forward(self, hidden_states, batch_size=1):
28
+
29
+ # 1. Self-Attention
30
+ norm_hidden_states = self.norm1(hidden_states)
31
+ norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
32
+ attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
33
+ attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
34
+ hidden_states = attn_output + hidden_states
35
+
36
+ # 2. Cross-Attention
37
+ norm_hidden_states = self.norm2(hidden_states)
38
+ norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
39
+ attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
40
+ attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
41
+ hidden_states = attn_output + hidden_states
42
+
43
+ # 3. Feed-forward
44
+ norm_hidden_states = self.norm3(hidden_states)
45
+ ff_output = self.act_fn(norm_hidden_states)
46
+ ff_output = self.ff(ff_output)
47
+ hidden_states = ff_output + hidden_states
48
+
49
+ return hidden_states
50
+
51
+
52
+ class TemporalBlock(torch.nn.Module):
53
+
54
+ def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
55
+ super().__init__()
56
+ inner_dim = num_attention_heads * attention_head_dim
57
+
58
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
59
+ self.proj_in = torch.nn.Linear(in_channels, inner_dim)
60
+
61
+ self.transformer_blocks = torch.nn.ModuleList([
62
+ TemporalTransformerBlock(
63
+ inner_dim,
64
+ num_attention_heads,
65
+ attention_head_dim
66
+ )
67
+ for d in range(num_layers)
68
+ ])
69
+
70
+ self.proj_out = torch.nn.Linear(inner_dim, in_channels)
71
+
72
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1):
73
+ batch, _, height, width = hidden_states.shape
74
+ residual = hidden_states
75
+
76
+ hidden_states = self.norm(hidden_states)
77
+ inner_dim = hidden_states.shape[1]
78
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
79
+ hidden_states = self.proj_in(hidden_states)
80
+
81
+ for block in self.transformer_blocks:
82
+ hidden_states = block(
83
+ hidden_states,
84
+ batch_size=batch_size
85
+ )
86
+
87
+ hidden_states = self.proj_out(hidden_states)
88
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
89
+ hidden_states = hidden_states + residual
90
+
91
+ return hidden_states, time_emb, text_emb, res_stack
92
+
93
+
94
+ class SDMotionModel(torch.nn.Module):
95
+ def __init__(self):
96
+ super().__init__()
97
+ self.motion_modules = torch.nn.ModuleList([
98
+ TemporalBlock(8, 40, 320, eps=1e-6),
99
+ TemporalBlock(8, 40, 320, eps=1e-6),
100
+ TemporalBlock(8, 80, 640, eps=1e-6),
101
+ TemporalBlock(8, 80, 640, eps=1e-6),
102
+ TemporalBlock(8, 160, 1280, eps=1e-6),
103
+ TemporalBlock(8, 160, 1280, eps=1e-6),
104
+ TemporalBlock(8, 160, 1280, eps=1e-6),
105
+ TemporalBlock(8, 160, 1280, eps=1e-6),
106
+ TemporalBlock(8, 160, 1280, eps=1e-6),
107
+ TemporalBlock(8, 160, 1280, eps=1e-6),
108
+ TemporalBlock(8, 160, 1280, eps=1e-6),
109
+ TemporalBlock(8, 160, 1280, eps=1e-6),
110
+ TemporalBlock(8, 160, 1280, eps=1e-6),
111
+ TemporalBlock(8, 160, 1280, eps=1e-6),
112
+ TemporalBlock(8, 160, 1280, eps=1e-6),
113
+ TemporalBlock(8, 80, 640, eps=1e-6),
114
+ TemporalBlock(8, 80, 640, eps=1e-6),
115
+ TemporalBlock(8, 80, 640, eps=1e-6),
116
+ TemporalBlock(8, 40, 320, eps=1e-6),
117
+ TemporalBlock(8, 40, 320, eps=1e-6),
118
+ TemporalBlock(8, 40, 320, eps=1e-6),
119
+ ])
120
+ self.call_block_id = {
121
+ 1: 0,
122
+ 4: 1,
123
+ 9: 2,
124
+ 12: 3,
125
+ 17: 4,
126
+ 20: 5,
127
+ 24: 6,
128
+ 26: 7,
129
+ 29: 8,
130
+ 32: 9,
131
+ 34: 10,
132
+ 36: 11,
133
+ 40: 12,
134
+ 43: 13,
135
+ 46: 14,
136
+ 50: 15,
137
+ 53: 16,
138
+ 56: 17,
139
+ 60: 18,
140
+ 63: 19,
141
+ 66: 20
142
+ }
143
+
144
+ def forward(self):
145
+ pass
146
+
147
+ def state_dict_converter(self):
148
+ return SDMotionModelStateDictConverter()
149
+
150
+
151
+ class SDMotionModelStateDictConverter:
152
+ def __init__(self):
153
+ pass
154
+
155
+ def from_diffusers(self, state_dict):
156
+ rename_dict = {
157
+ "norm": "norm",
158
+ "proj_in": "proj_in",
159
+ "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
160
+ "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
161
+ "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
162
+ "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
163
+ "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
164
+ "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
165
+ "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
166
+ "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
167
+ "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
168
+ "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
169
+ "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
170
+ "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
171
+ "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
172
+ "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
173
+ "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
174
+ "proj_out": "proj_out",
175
+ }
176
+ name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
177
+ name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
178
+ name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
179
+ state_dict_ = {}
180
+ last_prefix, module_id = "", -1
181
+ for name in name_list:
182
+ names = name.split(".")
183
+ prefix_index = names.index("temporal_transformer") + 1
184
+ prefix = ".".join(names[:prefix_index])
185
+ if prefix != last_prefix:
186
+ last_prefix = prefix
187
+ module_id += 1
188
+ middle_name = ".".join(names[prefix_index:-1])
189
+ suffix = names[-1]
190
+ if "pos_encoder" in names:
191
+ rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
192
+ else:
193
+ rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
194
+ state_dict_[rename] = state_dict[name]
195
+ return state_dict_
196
+
197
+ def from_civitai(self, state_dict):
198
+ return self.from_diffusers(state_dict)
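The core trick in TemporalTransformerBlock is the einops reshape: spatial positions are folded into the batch axis while frames become the sequence axis, so self-attention mixes information across time independently at each spatial location (the learned positional embeddings pe1 / pe2 are added along that frame axis). A standalone sketch of the reshape with dummy tensors:

import torch
from einops import rearrange

# "(b f) h c -> (b h) f c": a batch of frames becomes a batch of spatial
# positions, with the 16 frames now sitting on the attention sequence axis.
batch_size, frames, tokens, channels = 1, 16, 64, 320
hidden_states = torch.randn(batch_size * frames, tokens, channels)

temporal = rearrange(hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
print(temporal.shape)   # torch.Size([64, 16, 320]) -> attention runs over 16 frames

restored = rearrange(temporal, "(b h) f c -> (b f) h c", b=batch_size)
print(restored.shape)   # torch.Size([16, 64, 320]) -> back to the UNet layout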
diffsynth/models/sd_text_encoder.py ADDED
@@ -0,0 +1,320 @@
1
+ import torch
2
+ from .attention import Attention
3
+
4
+
5
+ class CLIPEncoderLayer(torch.nn.Module):
6
+ def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
7
+ super().__init__()
8
+ self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
9
+ self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
10
+ self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
11
+ self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
12
+ self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
13
+
14
+ self.use_quick_gelu = use_quick_gelu
15
+
16
+ def quickGELU(self, x):
17
+ return x * torch.sigmoid(1.702 * x)
18
+
19
+ def forward(self, hidden_states, attn_mask=None):
20
+ residual = hidden_states
21
+
22
+ hidden_states = self.layer_norm1(hidden_states)
23
+ hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
24
+ hidden_states = residual + hidden_states
25
+
26
+ residual = hidden_states
27
+ hidden_states = self.layer_norm2(hidden_states)
28
+ hidden_states = self.fc1(hidden_states)
29
+ if self.use_quick_gelu:
30
+ hidden_states = self.quickGELU(hidden_states)
31
+ else:
32
+ hidden_states = torch.nn.functional.gelu(hidden_states)
33
+ hidden_states = self.fc2(hidden_states)
34
+ hidden_states = residual + hidden_states
35
+
36
+ return hidden_states
37
+
38
+
39
+ class SDTextEncoder(torch.nn.Module):
40
+ def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
41
+ super().__init__()
42
+
43
+ # token_embedding
44
+ self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
45
+
46
+ # position_embeds (This is a fixed tensor)
47
+ self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
48
+
49
+ # encoders
50
+ self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
51
+
52
+ # attn_mask
53
+ self.attn_mask = self.attention_mask(max_position_embeddings)
54
+
55
+ # final_layer_norm
56
+ self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
57
+
58
+ def attention_mask(self, length):
59
+ mask = torch.empty(length, length)
60
+ mask.fill_(float("-inf"))
61
+ mask.triu_(1)
62
+ return mask
63
+
64
+ def forward(self, input_ids, clip_skip=1):
65
+ embeds = self.token_embedding(input_ids) + self.position_embeds
66
+ attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
67
+ for encoder_id, encoder in enumerate(self.encoders):
68
+ embeds = encoder(embeds, attn_mask=attn_mask)
69
+ if encoder_id + clip_skip == len(self.encoders):
70
+ break
71
+ embeds = self.final_layer_norm(embeds)
72
+ return embeds
73
+
74
+ def state_dict_converter(self):
75
+ return SDTextEncoderStateDictConverter()
76
+
77
+
78
+ class SDTextEncoderStateDictConverter:
79
+ def __init__(self):
80
+ pass
81
+
82
+ def from_diffusers(self, state_dict):
83
+ rename_dict = {
84
+ "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
85
+ "text_model.embeddings.position_embedding.weight": "position_embeds",
86
+ "text_model.final_layer_norm.weight": "final_layer_norm.weight",
87
+ "text_model.final_layer_norm.bias": "final_layer_norm.bias"
88
+ }
89
+ attn_rename_dict = {
90
+ "self_attn.q_proj": "attn.to_q",
91
+ "self_attn.k_proj": "attn.to_k",
92
+ "self_attn.v_proj": "attn.to_v",
93
+ "self_attn.out_proj": "attn.to_out",
94
+ "layer_norm1": "layer_norm1",
95
+ "layer_norm2": "layer_norm2",
96
+ "mlp.fc1": "fc1",
97
+ "mlp.fc2": "fc2",
98
+ }
99
+ state_dict_ = {}
100
+ for name in state_dict:
101
+ if name in rename_dict:
102
+ param = state_dict[name]
103
+ if name == "text_model.embeddings.position_embedding.weight":
104
+ param = param.reshape((1, param.shape[0], param.shape[1]))
105
+ state_dict_[rename_dict[name]] = param
106
+ elif name.startswith("text_model.encoder.layers."):
107
+ param = state_dict[name]
108
+ names = name.split(".")
109
+ layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
110
+ name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
111
+ state_dict_[name_] = param
112
+ return state_dict_
113
+
114
+ def from_civitai(self, state_dict):
115
+ rename_dict = {
116
+ "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
117
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
118
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
119
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
120
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
121
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
122
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
123
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
124
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
125
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
126
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
127
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
128
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
129
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
130
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
131
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
132
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
133
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
134
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
135
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
136
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
137
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
138
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
139
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
140
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
141
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
142
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
143
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
144
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
145
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
146
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
147
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
148
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
149
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
150
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
151
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
152
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
153
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
154
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
155
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
156
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
157
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
158
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
159
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
160
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
161
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
162
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
163
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
164
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
165
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
166
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
167
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
168
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
169
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
170
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
171
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
172
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
173
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
174
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
175
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
176
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
177
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
178
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
179
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
180
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
181
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
182
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
183
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
184
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
185
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
186
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
187
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
188
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
189
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
190
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
191
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
192
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
193
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
194
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
195
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
196
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
197
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
198
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
199
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
200
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
201
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
202
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
203
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
204
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
205
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
206
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
207
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
208
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
209
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
210
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
211
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
212
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
213
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
214
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
215
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
216
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
217
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
218
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
219
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
220
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
221
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
222
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
223
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
224
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
225
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
226
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
227
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
228
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
229
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
230
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
231
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
232
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
233
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
234
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
235
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
236
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
237
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
238
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
239
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
240
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
241
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
242
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
243
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
244
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
245
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
246
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
247
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
248
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
249
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
250
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
251
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
252
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
253
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
254
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
255
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
256
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
257
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
258
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
259
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
260
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
261
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
262
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
263
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
264
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
265
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
266
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
267
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
268
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
269
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
270
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
271
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
272
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
273
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
274
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
275
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
276
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
277
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
278
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
279
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
280
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
281
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
282
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
283
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
284
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
285
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
286
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
287
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
288
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
289
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
290
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
291
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
292
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
293
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
294
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
295
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
296
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
297
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
298
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
299
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
300
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
301
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
302
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
303
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
304
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
305
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
306
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
307
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
308
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
309
+ "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
310
+ "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
311
+ "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
312
+ }
313
+ state_dict_ = {}
314
+ for name in state_dict:
315
+ if name in rename_dict:
316
+ param = state_dict[name]
317
+ if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
318
+ param = param.reshape((1, param.shape[0], param.shape[1]))
319
+ state_dict_[rename_dict[name]] = param
320
+ return state_dict_
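Note: the from_civitai / from_diffusers converters above only remap parameter names (plus the position-embedding reshape); they do not load weights themselves. A minimal usage sketch, assuming a checkpoint path of your own and that sd_text_encoder.py defines an SDTextEncoder class exposing the same state_dict_converter() helper as the other models in this commit:

import torch
from diffsynth.models.sd_text_encoder import SDTextEncoder  # assumed class name

# "model.ckpt" is a placeholder for a Stable Diffusion 1.x checkpoint in the
# original (civitai/LDM) layout; some checkpoints nest weights under "state_dict".
raw = torch.load("model.ckpt", map_location="cpu")
raw = raw.get("state_dict", raw)

text_encoder = SDTextEncoder()
converted = text_encoder.state_dict_converter().from_civitai(raw)
text_encoder.load_state_dict(converted)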
diffsynth/models/sd_unet.py ADDED
The diff for this file is too large to render. See raw diff
 
diffsynth/models/sd_vae_decoder.py ADDED
@@ -0,0 +1,332 @@
1
+ import torch
2
+ from .attention import Attention
3
+ from .sd_unet import ResnetBlock, UpSampler
4
+ from .tiler import TileWorker
5
+
6
+
7
+ class VAEAttentionBlock(torch.nn.Module):
8
+
9
+ def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
10
+ super().__init__()
11
+ inner_dim = num_attention_heads * attention_head_dim
12
+
13
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
14
+
15
+ self.transformer_blocks = torch.nn.ModuleList([
16
+ Attention(
17
+ inner_dim,
18
+ num_attention_heads,
19
+ attention_head_dim,
20
+ bias_q=True,
21
+ bias_kv=True,
22
+ bias_out=True
23
+ )
24
+ for d in range(num_layers)
25
+ ])
26
+
27
+ def forward(self, hidden_states, time_emb, text_emb, res_stack):
28
+ batch, _, height, width = hidden_states.shape
29
+ residual = hidden_states
30
+
31
+ hidden_states = self.norm(hidden_states)
32
+ inner_dim = hidden_states.shape[1]
33
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
34
+
35
+ for block in self.transformer_blocks:
36
+ hidden_states = block(hidden_states)
37
+
38
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
39
+ hidden_states = hidden_states + residual
40
+
41
+ return hidden_states, time_emb, text_emb, res_stack
42
+
43
+
44
+ class SDVAEDecoder(torch.nn.Module):
45
+ def __init__(self):
46
+ super().__init__()
47
+ self.scaling_factor = 0.18215
48
+ self.post_quant_conv = torch.nn.Conv2d(4, 4, kernel_size=1)
49
+ self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)
50
+
51
+ self.blocks = torch.nn.ModuleList([
52
+ # UNetMidBlock2D
53
+ ResnetBlock(512, 512, eps=1e-6),
54
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
55
+ ResnetBlock(512, 512, eps=1e-6),
56
+ # UpDecoderBlock2D
57
+ ResnetBlock(512, 512, eps=1e-6),
58
+ ResnetBlock(512, 512, eps=1e-6),
59
+ ResnetBlock(512, 512, eps=1e-6),
60
+ UpSampler(512),
61
+ # UpDecoderBlock2D
62
+ ResnetBlock(512, 512, eps=1e-6),
63
+ ResnetBlock(512, 512, eps=1e-6),
64
+ ResnetBlock(512, 512, eps=1e-6),
65
+ UpSampler(512),
66
+ # UpDecoderBlock2D
67
+ ResnetBlock(512, 256, eps=1e-6),
68
+ ResnetBlock(256, 256, eps=1e-6),
69
+ ResnetBlock(256, 256, eps=1e-6),
70
+ UpSampler(256),
71
+ # UpDecoderBlock2D
72
+ ResnetBlock(256, 128, eps=1e-6),
73
+ ResnetBlock(128, 128, eps=1e-6),
74
+ ResnetBlock(128, 128, eps=1e-6),
75
+ ])
76
+
77
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
78
+ self.conv_act = torch.nn.SiLU()
79
+ self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
80
+
81
+ def tiled_forward(self, sample, tile_size=64, tile_stride=32):
82
+ hidden_states = TileWorker().tiled_forward(
83
+ lambda x: self.forward(x),
84
+ sample,
85
+ tile_size,
86
+ tile_stride,
87
+ tile_device=sample.device,
88
+ tile_dtype=sample.dtype
89
+ )
90
+ return hidden_states
91
+
92
+ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
93
+ # For VAE Decoder, we do not need to apply the tiler on each layer.
94
+ if tiled:
95
+ return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
96
+
97
+ # 1. pre-process
98
+ sample = sample / self.scaling_factor
99
+ hidden_states = self.post_quant_conv(sample)
100
+ hidden_states = self.conv_in(hidden_states)
101
+ time_emb = None
102
+ text_emb = None
103
+ res_stack = None
104
+
105
+ # 2. blocks
106
+ for i, block in enumerate(self.blocks):
107
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
108
+
109
+ # 3. output
110
+ hidden_states = self.conv_norm_out(hidden_states)
111
+ hidden_states = self.conv_act(hidden_states)
112
+ hidden_states = self.conv_out(hidden_states)
113
+
114
+ return hidden_states
115
+
116
+ def state_dict_converter(self):
117
+ return SDVAEDecoderStateDictConverter()
118
+
119
+
120
+ class SDVAEDecoderStateDictConverter:
121
+ def __init__(self):
122
+ pass
123
+
124
+ def from_diffusers(self, state_dict):
125
+ # architecture
126
+ block_types = [
127
+ 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock',
128
+ 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
129
+ 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
130
+ 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
131
+ 'ResnetBlock', 'ResnetBlock', 'ResnetBlock'
132
+ ]
133
+
134
+ # Rename each parameter
135
+ local_rename_dict = {
136
+ "post_quant_conv": "post_quant_conv",
137
+ "decoder.conv_in": "conv_in",
138
+ "decoder.mid_block.attentions.0.group_norm": "blocks.1.norm",
139
+ "decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q",
140
+ "decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k",
141
+ "decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v",
142
+ "decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out",
143
+ "decoder.mid_block.resnets.0.norm1": "blocks.0.norm1",
144
+ "decoder.mid_block.resnets.0.conv1": "blocks.0.conv1",
145
+ "decoder.mid_block.resnets.0.norm2": "blocks.0.norm2",
146
+ "decoder.mid_block.resnets.0.conv2": "blocks.0.conv2",
147
+ "decoder.mid_block.resnets.1.norm1": "blocks.2.norm1",
148
+ "decoder.mid_block.resnets.1.conv1": "blocks.2.conv1",
149
+ "decoder.mid_block.resnets.1.norm2": "blocks.2.norm2",
150
+ "decoder.mid_block.resnets.1.conv2": "blocks.2.conv2",
151
+ "decoder.conv_norm_out": "conv_norm_out",
152
+ "decoder.conv_out": "conv_out",
153
+ }
154
+ name_list = sorted([name for name in state_dict])
155
+ rename_dict = {}
156
+ block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2}
157
+ last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
158
+ for name in name_list:
159
+ names = name.split(".")
160
+ name_prefix = ".".join(names[:-1])
161
+ if name_prefix in local_rename_dict:
162
+ rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
163
+ elif name.startswith("decoder.up_blocks"):
164
+ block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
165
+ block_type_with_id = ".".join(names[:5])
166
+ if block_type_with_id != last_block_type_with_id[block_type]:
167
+ block_id[block_type] += 1
168
+ last_block_type_with_id[block_type] = block_type_with_id
169
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
170
+ block_id[block_type] += 1
171
+ block_type_with_id = ".".join(names[:5])
172
+ names = ["blocks", str(block_id[block_type])] + names[5:]
173
+ rename_dict[name] = ".".join(names)
174
+
175
+ # Convert state_dict
176
+ state_dict_ = {}
177
+ for name, param in state_dict.items():
178
+ if name in rename_dict:
179
+ state_dict_[rename_dict[name]] = param
180
+ return state_dict_
181
+
182
+ def from_civitai(self, state_dict):
183
+ rename_dict = {
184
+ "first_stage_model.decoder.conv_in.bias": "conv_in.bias",
185
+ "first_stage_model.decoder.conv_in.weight": "conv_in.weight",
186
+ "first_stage_model.decoder.conv_out.bias": "conv_out.bias",
187
+ "first_stage_model.decoder.conv_out.weight": "conv_out.weight",
188
+ "first_stage_model.decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
189
+ "first_stage_model.decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
190
+ "first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
191
+ "first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
192
+ "first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
193
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
194
+ "first_stage_model.decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
195
+ "first_stage_model.decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
196
+ "first_stage_model.decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
197
+ "first_stage_model.decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
198
+ "first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
199
+ "first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
200
+ "first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
201
+ "first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
202
+ "first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
203
+ "first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
204
+ "first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
205
+ "first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
206
+ "first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
207
+ "first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
208
+ "first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
209
+ "first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
210
+ "first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
211
+ "first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
212
+ "first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
213
+ "first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
214
+ "first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias",
215
+ "first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight",
216
+ "first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
217
+ "first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
218
+ "first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
219
+ "first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
220
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
221
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
222
+ "first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
223
+ "first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
224
+ "first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
225
+ "first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
226
+ "first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
227
+ "first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
228
+ "first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
229
+ "first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
230
+ "first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
231
+ "first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
232
+ "first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
233
+ "first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
234
+ "first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
235
+ "first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
236
+ "first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
237
+ "first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
238
+ "first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
239
+ "first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
240
+ "first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
241
+ "first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
242
+ "first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
243
+ "first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
244
+ "first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
245
+ "first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
246
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
247
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
248
+ "first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
249
+ "first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
250
+ "first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
251
+ "first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
252
+ "first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
253
+ "first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
254
+ "first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
255
+ "first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
256
+ "first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
257
+ "first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
258
+ "first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
259
+ "first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
260
+ "first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
261
+ "first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
262
+ "first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
263
+ "first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
264
+ "first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
265
+ "first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
266
+ "first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
267
+ "first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
268
+ "first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
269
+ "first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
270
+ "first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
271
+ "first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
272
+ "first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
273
+ "first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
274
+ "first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
275
+ "first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
276
+ "first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
277
+ "first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
278
+ "first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
279
+ "first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
280
+ "first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
281
+ "first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
282
+ "first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
283
+ "first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
284
+ "first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
285
+ "first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
286
+ "first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
287
+ "first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
288
+ "first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
289
+ "first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
290
+ "first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
291
+ "first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
292
+ "first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
293
+ "first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
294
+ "first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
295
+ "first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
296
+ "first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
297
+ "first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
298
+ "first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
299
+ "first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
300
+ "first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
301
+ "first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
302
+ "first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
303
+ "first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
304
+ "first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
305
+ "first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
306
+ "first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
307
+ "first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
308
+ "first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
309
+ "first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
310
+ "first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
311
+ "first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
312
+ "first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
313
+ "first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
314
+ "first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
315
+ "first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
316
+ "first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
317
+ "first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
318
+ "first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
319
+ "first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
320
+ "first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
321
+ "first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
322
+ "first_stage_model.post_quant_conv.bias": "post_quant_conv.bias",
323
+ "first_stage_model.post_quant_conv.weight": "post_quant_conv.weight",
324
+ }
325
+ state_dict_ = {}
326
+ for name in state_dict:
327
+ if name in rename_dict:
328
+ param = state_dict[name]
329
+ if "transformer_blocks" in rename_dict[name]:
330
+ param = param.squeeze()
331
+ state_dict_[rename_dict[name]] = param
332
+ return state_dict_
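Note: SDVAEDecoder offers two decoding paths: the plain forward pass, and tiled_forward, which hands the latent to TileWorker in overlapping tiles to bound peak memory on large images. A minimal sketch with randomly initialized weights (in practice they would be loaded through from_diffusers or from_civitai above); the latent shape is illustrative:

import torch
from diffsynth.models.sd_vae_decoder import SDVAEDecoder

decoder = SDVAEDecoder().eval()
latent = torch.randn(1, 4, 64, 64)            # (batch, 4, H/8, W/8) latent

with torch.no_grad():
    image = decoder(latent)                   # plain pass -> (1, 3, 512, 512)
    image_tiled = decoder(latent, tiled=True,           # tiled pass, same output shape;
                          tile_size=64, tile_stride=32)  # intended for larger latents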
diffsynth/models/sd_vae_encoder.py ADDED
@@ -0,0 +1,278 @@
1
+ import torch
2
+ from .sd_unet import ResnetBlock, DownSampler
3
+ from .sd_vae_decoder import VAEAttentionBlock
4
+ from .tiler import TileWorker
5
+ from einops import rearrange
6
+
7
+
8
+ class SDVAEEncoder(torch.nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+ self.scaling_factor = 0.18215
12
+ self.quant_conv = torch.nn.Conv2d(8, 8, kernel_size=1)
13
+ self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
14
+
15
+ self.blocks = torch.nn.ModuleList([
16
+ # DownEncoderBlock2D
17
+ ResnetBlock(128, 128, eps=1e-6),
18
+ ResnetBlock(128, 128, eps=1e-6),
19
+ DownSampler(128, padding=0, extra_padding=True),
20
+ # DownEncoderBlock2D
21
+ ResnetBlock(128, 256, eps=1e-6),
22
+ ResnetBlock(256, 256, eps=1e-6),
23
+ DownSampler(256, padding=0, extra_padding=True),
24
+ # DownEncoderBlock2D
25
+ ResnetBlock(256, 512, eps=1e-6),
26
+ ResnetBlock(512, 512, eps=1e-6),
27
+ DownSampler(512, padding=0, extra_padding=True),
28
+ # DownEncoderBlock2D
29
+ ResnetBlock(512, 512, eps=1e-6),
30
+ ResnetBlock(512, 512, eps=1e-6),
31
+ # UNetMidBlock2D
32
+ ResnetBlock(512, 512, eps=1e-6),
33
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
34
+ ResnetBlock(512, 512, eps=1e-6),
35
+ ])
36
+
37
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
38
+ self.conv_act = torch.nn.SiLU()
39
+ self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1)
40
+
41
+ def tiled_forward(self, sample, tile_size=64, tile_stride=32):
42
+ hidden_states = TileWorker().tiled_forward(
43
+ lambda x: self.forward(x),
44
+ sample,
45
+ tile_size,
46
+ tile_stride,
47
+ tile_device=sample.device,
48
+ tile_dtype=sample.dtype
49
+ )
50
+ return hidden_states
51
+
52
+ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
53
+ # For the VAE Encoder, we do not need to apply the tiler on each layer.
54
+ if tiled:
55
+ return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
56
+
57
+ # 1. pre-process
58
+ hidden_states = self.conv_in(sample)
59
+ time_emb = None
60
+ text_emb = None
61
+ res_stack = None
62
+
63
+ # 2. blocks
64
+ for i, block in enumerate(self.blocks):
65
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
66
+
67
+ # 3. output
68
+ hidden_states = self.conv_norm_out(hidden_states)
69
+ hidden_states = self.conv_act(hidden_states)
70
+ hidden_states = self.conv_out(hidden_states)
71
+ hidden_states = self.quant_conv(hidden_states)
72
+ hidden_states = hidden_states[:, :4]
73
+ hidden_states *= self.scaling_factor
74
+
75
+ return hidden_states
76
+
77
+ def encode_video(self, sample, batch_size=8):
78
+ B = sample.shape[0]
79
+ hidden_states = []
80
+
81
+ for i in range(0, sample.shape[2], batch_size):
82
+
83
+ j = min(i + batch_size, sample.shape[2])
84
+ sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
85
+
86
+ hidden_states_batch = self(sample_batch)
87
+ hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
88
+
89
+ hidden_states.append(hidden_states_batch)
90
+
91
+ hidden_states = torch.concat(hidden_states, dim=2)
92
+ return hidden_states
93
+
94
+ def state_dict_converter(self):
95
+ return SDVAEEncoderStateDictConverter()
96
+
97
+
98
+ class SDVAEEncoderStateDictConverter:
99
+ def __init__(self):
100
+ pass
101
+
102
+ def from_diffusers(self, state_dict):
103
+ # architecture
104
+ block_types = [
105
+ 'ResnetBlock', 'ResnetBlock', 'DownSampler',
106
+ 'ResnetBlock', 'ResnetBlock', 'DownSampler',
107
+ 'ResnetBlock', 'ResnetBlock', 'DownSampler',
108
+ 'ResnetBlock', 'ResnetBlock',
109
+ 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock'
110
+ ]
111
+
112
+ # Rename each parameter
113
+ local_rename_dict = {
114
+ "quant_conv": "quant_conv",
115
+ "encoder.conv_in": "conv_in",
116
+ "encoder.mid_block.attentions.0.group_norm": "blocks.12.norm",
117
+ "encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q",
118
+ "encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k",
119
+ "encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v",
120
+ "encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out",
121
+ "encoder.mid_block.resnets.0.norm1": "blocks.11.norm1",
122
+ "encoder.mid_block.resnets.0.conv1": "blocks.11.conv1",
123
+ "encoder.mid_block.resnets.0.norm2": "blocks.11.norm2",
124
+ "encoder.mid_block.resnets.0.conv2": "blocks.11.conv2",
125
+ "encoder.mid_block.resnets.1.norm1": "blocks.13.norm1",
126
+ "encoder.mid_block.resnets.1.conv1": "blocks.13.conv1",
127
+ "encoder.mid_block.resnets.1.norm2": "blocks.13.norm2",
128
+ "encoder.mid_block.resnets.1.conv2": "blocks.13.conv2",
129
+ "encoder.conv_norm_out": "conv_norm_out",
130
+ "encoder.conv_out": "conv_out",
131
+ }
132
+ name_list = sorted([name for name in state_dict])
133
+ rename_dict = {}
134
+ block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1}
135
+ last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
136
+ for name in name_list:
137
+ names = name.split(".")
138
+ name_prefix = ".".join(names[:-1])
139
+ if name_prefix in local_rename_dict:
140
+ rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
141
+ elif name.startswith("encoder.down_blocks"):
142
+ block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
143
+ block_type_with_id = ".".join(names[:5])
144
+ if block_type_with_id != last_block_type_with_id[block_type]:
145
+ block_id[block_type] += 1
146
+ last_block_type_with_id[block_type] = block_type_with_id
147
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
148
+ block_id[block_type] += 1
149
+ block_type_with_id = ".".join(names[:5])
150
+ names = ["blocks", str(block_id[block_type])] + names[5:]
151
+ rename_dict[name] = ".".join(names)
152
+
153
+ # Convert state_dict
154
+ state_dict_ = {}
155
+ for name, param in state_dict.items():
156
+ if name in rename_dict:
157
+ state_dict_[rename_dict[name]] = param
158
+ return state_dict_
159
+
160
+ def from_civitai(self, state_dict):
161
+ rename_dict = {
162
+ "first_stage_model.encoder.conv_in.bias": "conv_in.bias",
163
+ "first_stage_model.encoder.conv_in.weight": "conv_in.weight",
164
+ "first_stage_model.encoder.conv_out.bias": "conv_out.bias",
165
+ "first_stage_model.encoder.conv_out.weight": "conv_out.weight",
166
+ "first_stage_model.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
167
+ "first_stage_model.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
168
+ "first_stage_model.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
169
+ "first_stage_model.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
170
+ "first_stage_model.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
171
+ "first_stage_model.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
172
+ "first_stage_model.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
173
+ "first_stage_model.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
174
+ "first_stage_model.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
175
+ "first_stage_model.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
176
+ "first_stage_model.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
177
+ "first_stage_model.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
178
+ "first_stage_model.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
179
+ "first_stage_model.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
180
+ "first_stage_model.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
181
+ "first_stage_model.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
182
+ "first_stage_model.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
183
+ "first_stage_model.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
184
+ "first_stage_model.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
185
+ "first_stage_model.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
186
+ "first_stage_model.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
187
+ "first_stage_model.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
188
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
189
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
190
+ "first_stage_model.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
191
+ "first_stage_model.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
192
+ "first_stage_model.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
193
+ "first_stage_model.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
194
+ "first_stage_model.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
195
+ "first_stage_model.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
196
+ "first_stage_model.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
197
+ "first_stage_model.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
198
+ "first_stage_model.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
199
+ "first_stage_model.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
200
+ "first_stage_model.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
201
+ "first_stage_model.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
202
+ "first_stage_model.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
203
+ "first_stage_model.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
204
+ "first_stage_model.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
205
+ "first_stage_model.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
206
+ "first_stage_model.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
207
+ "first_stage_model.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
208
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
209
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
210
+ "first_stage_model.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
211
+ "first_stage_model.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
212
+ "first_stage_model.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
213
+ "first_stage_model.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
214
+ "first_stage_model.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
215
+ "first_stage_model.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
216
+ "first_stage_model.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
217
+ "first_stage_model.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
218
+ "first_stage_model.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
219
+ "first_stage_model.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
220
+ "first_stage_model.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
221
+ "first_stage_model.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
222
+ "first_stage_model.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
223
+ "first_stage_model.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
224
+ "first_stage_model.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
225
+ "first_stage_model.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
226
+ "first_stage_model.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
227
+ "first_stage_model.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
228
+ "first_stage_model.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
229
+ "first_stage_model.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
230
+ "first_stage_model.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
231
+ "first_stage_model.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
232
+ "first_stage_model.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
233
+ "first_stage_model.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
234
+ "first_stage_model.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
235
+ "first_stage_model.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
236
+ "first_stage_model.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
237
+ "first_stage_model.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
238
+ "first_stage_model.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
239
+ "first_stage_model.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
240
+ "first_stage_model.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
241
+ "first_stage_model.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
242
+ "first_stage_model.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
243
+ "first_stage_model.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
244
+ "first_stage_model.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
245
+ "first_stage_model.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
246
+ "first_stage_model.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
247
+ "first_stage_model.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
248
+ "first_stage_model.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
249
+ "first_stage_model.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
250
+ "first_stage_model.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
251
+ "first_stage_model.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
252
+ "first_stage_model.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
253
+ "first_stage_model.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
254
+ "first_stage_model.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
255
+ "first_stage_model.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
256
+ "first_stage_model.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
257
+ "first_stage_model.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
258
+ "first_stage_model.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
259
+ "first_stage_model.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
260
+ "first_stage_model.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
261
+ "first_stage_model.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
262
+ "first_stage_model.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
263
+ "first_stage_model.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
264
+ "first_stage_model.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
265
+ "first_stage_model.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
266
+ "first_stage_model.encoder.norm_out.bias": "conv_norm_out.bias",
267
+ "first_stage_model.encoder.norm_out.weight": "conv_norm_out.weight",
268
+ "first_stage_model.quant_conv.bias": "quant_conv.bias",
269
+ "first_stage_model.quant_conv.weight": "quant_conv.weight",
270
+ }
271
+ state_dict_ = {}
272
+ for name in state_dict:
273
+ if name in rename_dict:
274
+ param = state_dict[name]
275
+ if "transformer_blocks" in rename_dict[name]:
276
+ param = param.squeeze()
277
+ state_dict_[rename_dict[name]] = param
278
+ return state_dict_
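Note: encode_video above expects pixel frames laid out as (batch, channels, frames, height, width); it rearranges them to (batch*frames, channels, height, width), encodes batch_size frames per forward pass, and stacks the 4-channel latents back along the frame axis. A minimal sketch with random weights and an illustrative clip size:

import torch
from diffsynth.models.sd_vae_encoder import SDVAEEncoder

encoder = SDVAEEncoder().eval()
video = torch.randn(1, 3, 16, 256, 256)   # 16 RGB frames at 256x256, roughly in [-1, 1]

with torch.no_grad():
    latents = encoder.encode_video(video, batch_size=8)   # 8 frames per forward pass

print(latents.shape)   # torch.Size([1, 4, 16, 32, 32])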
diffsynth/models/sdxl_ipadapter.py ADDED
@@ -0,0 +1,121 @@
1
+ from .svd_image_encoder import SVDImageEncoder
2
+ from transformers import CLIPImageProcessor
3
+ import torch
4
+
5
+
6
+ class IpAdapterXLCLIPImageEmbedder(SVDImageEncoder):
7
+ def __init__(self):
8
+ super().__init__(embed_dim=1664, encoder_intermediate_size=8192, projection_dim=1280, num_encoder_layers=48, num_heads=16, head_dim=104)
9
+ self.image_processor = CLIPImageProcessor()
10
+
11
+ def forward(self, image):
12
+ pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
13
+ pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
14
+ return super().forward(pixel_values)
15
+
16
+
17
+ class IpAdapterImageProjModel(torch.nn.Module):
18
+ def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4):
19
+ super().__init__()
20
+ self.cross_attention_dim = cross_attention_dim
21
+ self.clip_extra_context_tokens = clip_extra_context_tokens
22
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
23
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
24
+
25
+ def forward(self, image_embeds):
26
+ clip_extra_context_tokens = self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
27
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
28
+ return clip_extra_context_tokens
29
+
30
+
31
+ class IpAdapterModule(torch.nn.Module):
32
+ def __init__(self, input_dim, output_dim):
33
+ super().__init__()
34
+ self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
35
+ self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
36
+
37
+ def forward(self, hidden_states):
38
+ ip_k = self.to_k_ip(hidden_states)
39
+ ip_v = self.to_v_ip(hidden_states)
40
+ return ip_k, ip_v
41
+
42
+
43
+ class SDXLIpAdapter(torch.nn.Module):
44
+ def __init__(self):
45
+ super().__init__()
46
+ shape_list = [(2048, 640)] * 4 + [(2048, 1280)] * 50 + [(2048, 640)] * 6 + [(2048, 1280)] * 10
47
+ self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
48
+ self.image_proj = IpAdapterImageProjModel()
49
+ self.set_full_adapter()
50
+
51
+ def set_full_adapter(self):
52
+ map_list = sum([
53
+ [(7, i) for i in range(2)],
54
+ [(10, i) for i in range(2)],
55
+ [(15, i) for i in range(10)],
56
+ [(18, i) for i in range(10)],
57
+ [(25, i) for i in range(10)],
58
+ [(28, i) for i in range(10)],
59
+ [(31, i) for i in range(10)],
60
+ [(35, i) for i in range(2)],
61
+ [(38, i) for i in range(2)],
62
+ [(41, i) for i in range(2)],
63
+ [(21, i) for i in range(10)],
64
+ ], [])
65
+ self.call_block_id = {i: j for j, i in enumerate(map_list)}
66
+
67
+ def set_less_adapter(self):
68
+ map_list = sum([
69
+ [(7, i) for i in range(2)],
70
+ [(10, i) for i in range(2)],
71
+ [(15, i) for i in range(10)],
72
+ [(18, i) for i in range(10)],
73
+ [(25, i) for i in range(10)],
74
+ [(28, i) for i in range(10)],
75
+ [(31, i) for i in range(10)],
76
+ [(35, i) for i in range(2)],
77
+ [(38, i) for i in range(2)],
78
+ [(41, i) for i in range(2)],
79
+ [(21, i) for i in range(10)],
80
+ ], [])
81
+ self.call_block_id = {i: j for j, i in enumerate(map_list) if j>=34 and j<44}
82
+
83
+ def forward(self, hidden_states, scale=1.0):
84
+ hidden_states = self.image_proj(hidden_states)
85
+ hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
86
+ ip_kv_dict = {}
87
+ for (block_id, transformer_id) in self.call_block_id:
88
+ ipadapter_id = self.call_block_id[(block_id, transformer_id)]
89
+ ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
90
+ if block_id not in ip_kv_dict:
91
+ ip_kv_dict[block_id] = {}
92
+ ip_kv_dict[block_id][transformer_id] = {
93
+ "ip_k": ip_k,
94
+ "ip_v": ip_v,
95
+ "scale": scale
96
+ }
97
+ return ip_kv_dict
98
+
99
+ def state_dict_converter(self):
100
+ return SDXLIpAdapterStateDictConverter()
101
+
102
+
103
+ class SDXLIpAdapterStateDictConverter:
104
+ def __init__(self):
105
+ pass
106
+
107
+ def from_diffusers(self, state_dict):
108
+ state_dict_ = {}
109
+ for name in state_dict["ip_adapter"]:
110
+ names = name.split(".")
111
+ layer_id = str(int(names[0]) // 2)
112
+ name_ = ".".join(["ipadapter_modules"] + [layer_id] + names[1:])
113
+ state_dict_[name_] = state_dict["ip_adapter"][name]
114
+ for name in state_dict["image_proj"]:
115
+ name_ = "image_proj." + name
116
+ state_dict_[name_] = state_dict["image_proj"][name]
117
+ return state_dict_
118
+
119
+ def from_civitai(self, state_dict):
120
+ return self.from_diffusers(state_dict)
121
+
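Note: SDXLIpAdapter.forward above turns a single image embedding into per-attention-layer key/value tensors, organised as {unet_block_id: {transformer_id: {"ip_k", "ip_v", "scale"}}} according to call_block_id. A minimal sketch in which the (1, 1280) tensor stands in for the output of IpAdapterXLCLIPImageEmbedder:

import torch
from diffsynth.models.sdxl_ipadapter import SDXLIpAdapter

adapter = SDXLIpAdapter()                 # defaults to set_full_adapter()
image_emb = torch.randn(1, 1280)          # placeholder for a CLIP image embedding

ip_kv = adapter(image_emb, scale=0.6)     # scale is stored alongside each K/V pair
print(sorted(ip_kv.keys()))               # UNet block ids from call_block_id

# adapter.set_less_adapter() limits the injection to a subset of the attention layers.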
diffsynth/models/sdxl_motion.py ADDED
@@ -0,0 +1,103 @@
1
+ from .sd_motion import TemporalBlock
2
+ import torch
3
+
4
+
5
+
6
+ class SDXLMotionModel(torch.nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+ self.motion_modules = torch.nn.ModuleList([
10
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
11
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
12
+
13
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
14
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
15
+
16
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
17
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
18
+
19
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
20
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
21
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
22
+
23
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
24
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
25
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
26
+
27
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
28
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
29
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
30
+ ])
31
+ self.call_block_id = {
32
+ 0: 0,
33
+ 2: 1,
34
+ 7: 2,
35
+ 10: 3,
36
+ 15: 4,
37
+ 18: 5,
38
+ 25: 6,
39
+ 28: 7,
40
+ 31: 8,
41
+ 35: 9,
42
+ 38: 10,
43
+ 41: 11,
44
+ 44: 12,
45
+ 46: 13,
46
+ 48: 14,
47
+ }
48
+
49
+ def forward(self):
50
+ pass
51
+
52
+ def state_dict_converter(self):
53
+ return SDMotionModelStateDictConverter()
54
+
55
+
56
+ class SDMotionModelStateDictConverter:
57
+ def __init__(self):
58
+ pass
59
+
60
+ def from_diffusers(self, state_dict):
61
+ rename_dict = {
62
+ "norm": "norm",
63
+ "proj_in": "proj_in",
64
+ "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
65
+ "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
66
+ "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
67
+ "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
68
+ "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
69
+ "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
70
+ "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
71
+ "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
72
+ "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
73
+ "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
74
+ "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
75
+ "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
76
+ "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
77
+ "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
78
+ "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
79
+ "proj_out": "proj_out",
80
+ }
81
+ name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
82
+ name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
83
+ name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
84
+ state_dict_ = {}
85
+ last_prefix, module_id = "", -1
86
+ for name in name_list:
87
+ names = name.split(".")
88
+ prefix_index = names.index("temporal_transformer") + 1
89
+ prefix = ".".join(names[:prefix_index])
90
+ if prefix != last_prefix:
91
+ last_prefix = prefix
92
+ module_id += 1
93
+ middle_name = ".".join(names[prefix_index:-1])
94
+ suffix = names[-1]
95
+ if "pos_encoder" in names:
96
+ rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
97
+ else:
98
+ rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
99
+ state_dict_[rename] = state_dict[name]
100
+ return state_dict_
101
+
102
+ def from_civitai(self, state_dict):
103
+ return self.from_diffusers(state_dict)
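Note: SDXLMotionModel itself is only a container for the TemporalBlock modules (forward is a no-op here); its converter remaps AnimateDiff-style motion-module keys into the motion_modules.{i} layout above. A minimal loading sketch; the checkpoint filename is a placeholder and strict=False is used defensively, since third-party checkpoint layouts vary:

import torch
from diffsynth.models.sdxl_motion import SDXLMotionModel

motion = SDXLMotionModel()
raw = torch.load("mm_sdxl.ckpt", map_location="cpu")       # placeholder motion-module checkpoint
converted = motion.state_dict_converter().from_diffusers(raw)
motion.load_state_dict(converted, strict=False)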
diffsynth/models/sdxl_text_encoder.py ADDED
@@ -0,0 +1,757 @@
1
+ import torch
2
+ from .sd_text_encoder import CLIPEncoderLayer
3
+
4
+
5
+ class SDXLTextEncoder(torch.nn.Module):
6
+ def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=11, encoder_intermediate_size=3072):
7
+ super().__init__()
8
+
9
+ # token_embedding
10
+ self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
11
+
12
+ # position_embeds (This is a fixed tensor)
13
+ self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
14
+
15
+ # encoders
16
+ self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
17
+
18
+ # attn_mask
19
+ self.attn_mask = self.attention_mask(max_position_embeddings)
20
+
21
+ # The text encoder is different from the one in Stable Diffusion 1.x.
22
+ # It does not include final_layer_norm.
23
+
24
+ def attention_mask(self, length):
25
+ mask = torch.empty(length, length)
26
+ mask.fill_(float("-inf"))
27
+ mask.triu_(1)
28
+ return mask
29
+
30
+ def forward(self, input_ids, clip_skip=1):
31
+ embeds = self.token_embedding(input_ids) + self.position_embeds
32
+ attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
33
+ for encoder_id, encoder in enumerate(self.encoders):
34
+ embeds = encoder(embeds, attn_mask=attn_mask)
35
+ if encoder_id + clip_skip == len(self.encoders):
36
+ break
37
+ return embeds
38
+
39
+ def state_dict_converter(self):
40
+ return SDXLTextEncoderStateDictConverter()
41
+
42
+
43
+ class SDXLTextEncoder2(torch.nn.Module):
44
+ def __init__(self, embed_dim=1280, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=32, encoder_intermediate_size=5120):
45
+ super().__init__()
46
+
47
+ # token_embedding
48
+ self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
49
+
50
+ # position_embeds (This is a fixed tensor)
51
+ self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
52
+
53
+ # encoders
54
+ self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=20, head_dim=64, use_quick_gelu=False) for _ in range(num_encoder_layers)])
55
+
56
+ # attn_mask
57
+ self.attn_mask = self.attention_mask(max_position_embeddings)
58
+
59
+ # final_layer_norm
60
+ self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
61
+
62
+ # text_projection
63
+ self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
64
+
65
+ def attention_mask(self, length):
66
+ mask = torch.empty(length, length)
67
+ mask.fill_(float("-inf"))
68
+ mask.triu_(1)
69
+ return mask
70
+
71
+ def forward(self, input_ids, clip_skip=2):
72
+ embeds = self.token_embedding(input_ids) + self.position_embeds
73
+ attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
74
+ for encoder_id, encoder in enumerate(self.encoders):
75
+ embeds = encoder(embeds, attn_mask=attn_mask)
76
+ if encoder_id + clip_skip == len(self.encoders):
77
+ hidden_states = embeds
78
+ embeds = self.final_layer_norm(embeds)
79
+ pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
80
+ pooled_embeds = self.text_projection(pooled_embeds)
81
+ return pooled_embeds, hidden_states
82
+
83
+ def state_dict_converter(self):
84
+ return SDXLTextEncoder2StateDictConverter()
85
+
86
+
87
+ class SDXLTextEncoderStateDictConverter:
88
+ def __init__(self):
89
+ pass
90
+
91
+ def from_diffusers(self, state_dict):
92
+ rename_dict = {
93
+ "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
94
+ "text_model.embeddings.position_embedding.weight": "position_embeds",
95
+ "text_model.final_layer_norm.weight": "final_layer_norm.weight",
96
+ "text_model.final_layer_norm.bias": "final_layer_norm.bias"
97
+ }
98
+ attn_rename_dict = {
99
+ "self_attn.q_proj": "attn.to_q",
100
+ "self_attn.k_proj": "attn.to_k",
101
+ "self_attn.v_proj": "attn.to_v",
102
+ "self_attn.out_proj": "attn.to_out",
103
+ "layer_norm1": "layer_norm1",
104
+ "layer_norm2": "layer_norm2",
105
+ "mlp.fc1": "fc1",
106
+ "mlp.fc2": "fc2",
107
+ }
108
+ state_dict_ = {}
109
+ for name in state_dict:
110
+ if name in rename_dict:
111
+ param = state_dict[name]
112
+ if name == "text_model.embeddings.position_embedding.weight":
113
+ param = param.reshape((1, param.shape[0], param.shape[1]))
114
+ state_dict_[rename_dict[name]] = param
115
+ elif name.startswith("text_model.encoder.layers."):
116
+ param = state_dict[name]
117
+ names = name.split(".")
118
+ layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
119
+ name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
120
+ state_dict_[name_] = param
121
+ return state_dict_
122
+
123
+ def from_civitai(self, state_dict):
124
+ rename_dict = {
125
+ "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "position_embeds",
126
+ "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
127
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
128
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
129
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
130
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
131
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
132
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
133
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
134
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
135
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
136
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
137
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
138
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
139
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
140
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
141
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
142
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
143
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
144
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
145
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
146
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
147
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
148
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
149
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
150
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
151
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
152
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
153
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
154
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
155
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
156
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
157
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
158
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
159
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
160
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
161
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
162
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
163
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
164
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
165
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
166
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
167
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
168
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
169
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
170
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
171
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
172
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
173
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
174
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
175
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
176
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
177
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
178
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
179
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
180
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
181
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
182
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
183
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
184
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
185
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
186
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
187
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
188
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
189
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
190
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
191
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
192
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
193
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
194
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
195
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
196
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
197
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
198
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
199
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
200
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
201
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
202
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
203
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
204
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
205
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
206
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
207
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
208
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
209
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
210
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
211
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
212
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
213
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
214
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
215
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
216
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
217
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
218
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
219
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
220
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
221
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
222
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
223
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
224
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
225
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
226
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
227
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
228
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
229
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
230
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
231
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
232
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
233
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
234
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
235
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
236
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
237
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
238
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
239
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
240
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
241
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
242
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
243
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
244
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
245
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
246
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
247
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
248
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
249
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
250
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
251
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
252
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
253
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
254
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
255
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
256
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
257
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
258
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
259
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
260
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
261
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
262
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
263
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
264
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
265
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
266
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
267
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
268
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
269
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
270
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
271
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
272
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
273
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
274
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
275
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
276
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
277
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
278
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
279
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
280
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
281
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
282
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
283
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
284
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
285
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
286
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
287
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
288
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
289
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
290
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
291
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
292
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
293
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
294
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
295
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
296
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
297
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
298
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
299
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
300
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
301
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
302
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
303
+ }
304
+ state_dict_ = {}
305
+ for name in state_dict:
306
+ if name in rename_dict:
307
+ param = state_dict[name]
308
+ if name == "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight":
309
+ param = param.reshape((1, param.shape[0], param.shape[1]))
310
+ state_dict_[rename_dict[name]] = param
311
+ return state_dict_
312
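The conversion above (the civitai-format table for the first text encoder, conditioner.embedders.0) is a pure key-rename pass over the flat checkpoint dict: keys missing from the table are dropped, and only the position-embedding table is reshaped to a batched (1, seq_len, dim) layout. A minimal, self-contained sketch of that pattern, using a hypothetical two-entry table instead of the full mapping above:

import torch

# Hypothetical miniature rename table; the real table above has one entry per parameter.
toy_rename = {
    "old.embeddings.position_embedding.weight": "position_embeds",
    "old.final_layer_norm.weight": "final_layer_norm.weight",
}

def convert(state_dict):
    converted = {}
    for name, param in state_dict.items():
        if name not in toy_rename:
            continue  # unknown keys are silently dropped, as in the converter above
        if name == "old.embeddings.position_embedding.weight":
            # (seq_len, dim) -> (1, seq_len, dim), matching the reshape above
            param = param.reshape((1, param.shape[0], param.shape[1]))
        converted[toy_rename[name]] = param
    return converted

toy_ckpt = {
    "old.embeddings.position_embedding.weight": torch.zeros(77, 768),
    "old.final_layer_norm.weight": torch.ones(768),
    "old.unrelated.weight": torch.ones(3),
}
print({k: tuple(v.shape) for k, v in convert(toy_ckpt).items()})
# {'position_embeds': (1, 77, 768), 'final_layer_norm.weight': (768,)}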
+
313
+
314
+ class SDXLTextEncoder2StateDictConverter:
315
+ def __init__(self):
316
+ pass
317
+
318
+ def from_diffusers(self, state_dict):
319
+ rename_dict = {
320
+ "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
321
+ "text_model.embeddings.position_embedding.weight": "position_embeds",
322
+ "text_model.final_layer_norm.weight": "final_layer_norm.weight",
323
+ "text_model.final_layer_norm.bias": "final_layer_norm.bias",
324
+ "text_projection.weight": "text_projection.weight"
325
+ }
326
+ attn_rename_dict = {
327
+ "self_attn.q_proj": "attn.to_q",
328
+ "self_attn.k_proj": "attn.to_k",
329
+ "self_attn.v_proj": "attn.to_v",
330
+ "self_attn.out_proj": "attn.to_out",
331
+ "layer_norm1": "layer_norm1",
332
+ "layer_norm2": "layer_norm2",
333
+ "mlp.fc1": "fc1",
334
+ "mlp.fc2": "fc2",
335
+ }
336
+ state_dict_ = {}
337
+ for name in state_dict:
338
+ if name in rename_dict:
339
+ param = state_dict[name]
340
+ if name == "text_model.embeddings.position_embedding.weight":
341
+ param = param.reshape((1, param.shape[0], param.shape[1]))
342
+ state_dict_[rename_dict[name]] = param
343
+ elif name.startswith("text_model.encoder.layers."):
344
+ param = state_dict[name]
345
+ names = name.split(".")
346
+ layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
347
+ name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
348
+ state_dict_[name_] = param
349
+ return state_dict_
350
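Unlike the hand-enumerated civitai tables, from_diffusers above derives each per-layer key programmatically: the Diffusers name is split into the layer index, the sub-module path, and the parameter tail, and the sub-module path is looked up in attn_rename_dict. The same string transformation in isolation, using the mapping shown above:

attn_rename_dict = {
    "self_attn.q_proj": "attn.to_q",
    "self_attn.k_proj": "attn.to_k",
    "self_attn.v_proj": "attn.to_v",
    "self_attn.out_proj": "attn.to_out",
    "layer_norm1": "layer_norm1",
    "layer_norm2": "layer_norm2",
    "mlp.fc1": "fc1",
    "mlp.fc2": "fc2",
}

def rename_layer_key(name):
    # e.g. "text_model.encoder.layers.7.self_attn.q_proj.weight" -> "encoders.7.attn.to_q.weight"
    names = name.split(".")
    layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
    return ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])

print(rename_layer_key("text_model.encoder.layers.7.self_attn.q_proj.weight"))  # encoders.7.attn.to_q.weight
print(rename_layer_key("text_model.encoder.layers.31.mlp.fc2.bias"))            # encoders.31.fc2.bias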
+
351
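In the from_civitai table below, the fused attn.in_proj_weight / in_proj_bias tensors map to a list of three target names (to_q, to_k, to_v), so the conversion loop for this encoder (later in the file, not shown in this hunk) evidently has to split each fused tensor three ways. A hedged sketch of how such a split is commonly done; the helper name and the equal-thirds stacking assumption are illustrative, not taken from this file:

import torch

def split_fused_qkv(param, target_names):
    # Assumes the fused projection is stacked as [q; k; v] along dim 0
    # (the usual convention for fused attention projections), so each
    # target name receives an equal slice.
    chunks = torch.chunk(param, len(target_names), dim=0)
    return dict(zip(target_names, chunks))

fused_bias = torch.arange(3 * 1280, dtype=torch.float32)  # toy stand-in for an in_proj_bias
targets = ["encoders.0.attn.to_q.bias", "encoders.0.attn.to_k.bias", "encoders.0.attn.to_v.bias"]
print({k: tuple(v.shape) for k, v in split_fused_qkv(fused_bias, targets).items()})
# each slice has shape (1280,)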
+ def from_civitai(self, state_dict):
352
+ rename_dict = {
353
+ "conditioner.embedders.1.model.ln_final.bias": "final_layer_norm.bias",
354
+ "conditioner.embedders.1.model.ln_final.weight": "final_layer_norm.weight",
355
+ "conditioner.embedders.1.model.positional_embedding": "position_embeds",
356
+ "conditioner.embedders.1.model.token_embedding.weight": "token_embedding.weight",
357
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
358
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
359
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
360
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
361
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
362
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
363
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
364
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
365
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
366
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
367
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
368
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
369
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
370
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
371
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
372
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
373
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
374
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
375
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
376
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
377
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
378
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
379
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
380
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
381
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
382
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
383
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
384
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
385
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
386
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
387
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
388
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
389
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
390
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
391
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
392
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
393
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
394
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
395
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
396
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
397
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
398
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
399
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
400
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
401
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
402
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
403
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
404
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
405
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
406
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
407
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
408
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
409
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
410
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
411
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
412
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
413
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
414
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
415
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
416
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
417
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
418
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
419
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
420
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
421
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
422
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
423
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
424
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
425
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
426
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
427
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
428
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
429
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
430
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
431
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
432
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
433
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
434
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
435
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
436
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
437
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
438
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
439
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
440
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
441
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
442
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
443
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
444
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
445
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
446
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
447
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
448
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
449
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
450
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
451
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
452
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
453
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
454
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
455
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
456
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
457
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
458
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
459
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
460
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
461
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
462
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
463
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
464
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
465
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
466
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
467
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
468
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
469
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
470
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
471
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
472
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
473
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
474
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
475
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
476
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
477
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
478
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
479
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
480
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
481
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
482
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
483
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
484
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
485
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
486
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
487
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
488
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
489
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
490
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
491
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
492
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
493
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
494
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
495
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
496
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
497
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
498
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
499
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
500
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
501
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
502
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
503
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
504
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
505
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
506
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
507
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
508
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
509
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
510
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
511
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
512
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
513
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
514
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
515
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
516
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
517
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
518
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
519
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
520
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
521
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
522
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
523
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
524
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
525
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
526
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
527
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
528
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
529
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
530
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
531
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
532
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
533
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
534
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
535
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
536
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
537
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
538
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
539
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
540
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
541
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
542
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
543
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
544
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
545
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
546
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
547
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
548
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
549
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
550
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
551
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
552
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
553
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
554
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
555
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
556
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
557
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
558
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
559
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
560
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
561
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
562
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
563
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
564
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
565
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
566
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
567
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
568
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
569
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
570
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
571
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
572
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
573
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
574
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
575
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
576
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
577
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
578
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
579
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
580
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
581
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
582
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
583
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
584
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
585
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
586
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
587
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
588
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
589
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
590
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
591
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
592
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
593
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
594
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
595
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
596
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
597
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
598
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
599
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
600
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
601
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
602
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
603
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
604
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
605
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
606
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
607
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
608
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
609
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
610
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
611
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
612
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
613
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
614
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
615
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
616
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
617
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
618
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
619
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
620
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
621
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
622
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
623
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
624
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
625
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
626
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
627
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
628
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
629
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
630
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
631
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
632
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
633
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
634
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
635
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
636
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
637
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
638
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
639
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
640
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
641
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
642
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
643
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
644
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
645
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
646
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
647
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
648
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
649
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
650
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
651
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
652
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
653
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
654
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
655
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
656
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
657
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
658
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
659
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
660
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
661
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
662
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
663
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
664
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
665
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
666
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
667
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
668
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
669
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
670
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
671
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
672
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
673
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
674
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
675
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
676
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
677
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
678
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
679
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
680
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
681
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
682
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
683
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
684
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
685
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
686
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
687
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
688
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
689
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
690
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
691
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
692
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
693
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
694
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
695
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
696
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
697
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
698
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
699
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
700
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
701
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
702
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
703
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
704
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
705
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
706
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
707
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
708
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
709
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
710
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
711
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
712
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
713
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
714
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
715
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
716
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
717
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
718
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
719
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
720
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
721
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
722
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
723
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
724
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
725
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
726
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
727
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
728
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
729
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
730
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
731
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
732
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
733
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
734
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
735
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
736
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
737
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
738
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
739
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
740
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
741
+ "conditioner.embedders.1.model.text_projection": "text_projection.weight",
742
+ }
743
+ state_dict_ = {}
744
+ for name in state_dict:
745
+ if name in rename_dict:
746
+ param = state_dict[name]
747
+ if name == "conditioner.embedders.1.model.positional_embedding":
748
+ param = param.reshape((1, param.shape[0], param.shape[1]))
749
+ elif name == "conditioner.embedders.1.model.text_projection":
750
+ param = param.T
751
+ if isinstance(rename_dict[name], str):
752
+ state_dict_[rename_dict[name]] = param
753
+ else:
754
+ length = param.shape[0] // 3
755
+ for i, rename in enumerate(rename_dict[name]):
756
+ state_dict_[rename] = param[i*length: i*length+length]
757
+ return state_dict_
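Note on the q/k/v split above: OpenCLIP checkpoints store each attention layer's query/key/value projections as one fused in_proj tensor of shape (3 * embed_dim, embed_dim), so the converter chunks it into three equal slices. A minimal, self-contained sketch of the same chunking (the tensor here is random and purely illustrative; 1280 is the width of SDXL's second, OpenCLIP ViT-bigG text encoder):

import torch

embed_dim = 1280
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)  # fused q/k/v, as stored in the checkpoint
length = in_proj_weight.shape[0] // 3
to_q = in_proj_weight[0 * length: 1 * length]
to_k = in_proj_weight[1 * length: 2 * length]
to_v = in_proj_weight[2 * length: 3 * length]
assert to_q.shape == to_k.shape == to_v.shape == (embed_dim, embed_dim)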
diffsynth/models/sdxl_unet.py ADDED
The diff for this file is too large to render. See raw diff
 
diffsynth/models/sdxl_vae_decoder.py ADDED
@@ -0,0 +1,15 @@
1
+ from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter
2
+
3
+
4
+ class SDXLVAEDecoder(SDVAEDecoder):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.scaling_factor = 0.13025
8
+
9
+ def state_dict_converter(self):
10
+ return SDXLVAEDecoderStateDictConverter()
11
+
12
+
13
+ class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
14
+ def __init__(self):
15
+ super().__init__()
diffsynth/models/sdxl_vae_encoder.py ADDED
@@ -0,0 +1,15 @@
1
+ from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
2
+
3
+
4
+ class SDXLVAEEncoder(SDVAEEncoder):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.scaling_factor = 0.13025
8
+
9
+ def state_dict_converter(self):
10
+ return SDXLVAEEncoderStateDictConverter()
11
+
12
+
13
+ class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
14
+ def __init__(self):
15
+ super().__init__()
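Both SDXL VAE subclasses above differ from their SD counterparts only in scaling_factor (0.13025 for the SDXL VAE, versus 0.18215 in SD 1.x). As a general convention (not necessarily how these classes apply it internally), latents are multiplied by the factor after encoding and divided by it before decoding; a purely illustrative sketch:

import torch

scaling_factor = 0.13025                    # SDXL VAE
raw_latents = torch.randn(1, 4, 128, 128)   # hypothetical encoder output for a 1024x1024 image
scaled = raw_latents * scaling_factor       # what the diffusion model operates on
restored = scaled / scaling_factor          # what would be handed back to the decoder
assert torch.allclose(raw_latents, restored)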
diffsynth/models/svd_image_encoder.py ADDED
@@ -0,0 +1,504 @@
1
+ import torch
2
+ from .sd_text_encoder import CLIPEncoderLayer
3
+
4
+
5
+ class CLIPVisionEmbeddings(torch.nn.Module):
6
+ def __init__(self, embed_dim=1280, image_size=224, patch_size=14, num_channels=3):
7
+ super().__init__()
8
+
9
+ # class_embeds (This is a fixed tensor)
10
+ self.class_embedding = torch.nn.Parameter(torch.randn(1, 1, embed_dim))
11
+
12
+ # patch_embeds
13
+ self.patch_embedding = torch.nn.Conv2d(in_channels=num_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)
14
+
15
+ # position_embeds (This is a fixed tensor)
16
+ self.position_embeds = torch.nn.Parameter(torch.zeros(1, (image_size // patch_size) ** 2 + 1, embed_dim))
17
+
18
+ def forward(self, pixel_values):
19
+ batch_size = pixel_values.shape[0]
20
+ patch_embeds = self.patch_embedding(pixel_values)
21
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
22
+ class_embeds = self.class_embedding.repeat(batch_size, 1, 1)
23
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + self.position_embeds
24
+ return embeddings
25
+
26
+
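A quick shape check for the defaults above (224x224 input, 14x14 patches, width 1280), just to make the token count explicit; this is an illustration assuming the CLIPVisionEmbeddings class above is importable, not part of the file:

import torch

emb = CLIPVisionEmbeddings()              # embed_dim=1280, image_size=224, patch_size=14
pixels = torch.randn(2, 3, 224, 224)
tokens = emb(pixels)
# 256 patch tokens (16 x 16) plus 1 class token, each of width 1280
assert tokens.shape == (2, (224 // 14) ** 2 + 1, 1280)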
27
+ class SVDImageEncoder(torch.nn.Module):
28
+ def __init__(self, embed_dim=1280, layer_norm_eps=1e-5, num_encoder_layers=32, encoder_intermediate_size=5120, projection_dim=1024, num_heads=16, head_dim=80):
29
+ super().__init__()
30
+ self.embeddings = CLIPVisionEmbeddings(embed_dim=embed_dim)
31
+ self.pre_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps)
32
+ self.encoders = torch.nn.ModuleList([
33
+ CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=num_heads, head_dim=head_dim, use_quick_gelu=False)
34
+ for _ in range(num_encoder_layers)])
35
+ self.post_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps)
36
+ self.visual_projection = torch.nn.Linear(embed_dim, projection_dim, bias=False)
37
+
38
+ def forward(self, pixel_values):
39
+ embeds = self.embeddings(pixel_values)
40
+ embeds = self.pre_layernorm(embeds)
41
+ for encoder_id, encoder in enumerate(self.encoders):
42
+ embeds = encoder(embeds)
43
+ embeds = self.post_layernorm(embeds[:, 0, :])
44
+ embeds = self.visual_projection(embeds)
45
+ return embeds
46
+
47
+ def state_dict_converter(self):
48
+ return SVDImageEncoderStateDictConverter()
49
+
50
+
51
+ class SVDImageEncoderStateDictConverter:
52
+ def __init__(self):
53
+ pass
54
+
55
+ def from_diffusers(self, state_dict):
56
+ rename_dict = {
57
+ "vision_model.embeddings.patch_embedding.weight": "embeddings.patch_embedding.weight",
58
+ "vision_model.embeddings.class_embedding": "embeddings.class_embedding",
59
+ "vision_model.embeddings.position_embedding.weight": "embeddings.position_embeds",
60
+ "vision_model.pre_layrnorm.weight": "pre_layernorm.weight",
61
+ "vision_model.pre_layrnorm.bias": "pre_layernorm.bias",
62
+ "vision_model.post_layernorm.weight": "post_layernorm.weight",
63
+ "vision_model.post_layernorm.bias": "post_layernorm.bias",
64
+ "visual_projection.weight": "visual_projection.weight"
65
+ }
66
+ attn_rename_dict = {
67
+ "self_attn.q_proj": "attn.to_q",
68
+ "self_attn.k_proj": "attn.to_k",
69
+ "self_attn.v_proj": "attn.to_v",
70
+ "self_attn.out_proj": "attn.to_out",
71
+ "layer_norm1": "layer_norm1",
72
+ "layer_norm2": "layer_norm2",
73
+ "mlp.fc1": "fc1",
74
+ "mlp.fc2": "fc2",
75
+ }
76
+ state_dict_ = {}
77
+ for name in state_dict:
78
+ if name in rename_dict:
79
+ param = state_dict[name]
80
+ if name == "vision_model.embeddings.class_embedding":
81
+ param = state_dict[name].view(1, 1, -1)
82
+ elif name == "vision_model.embeddings.position_embedding.weight":
83
+ param = state_dict[name].unsqueeze(0)
84
+ state_dict_[rename_dict[name]] = param
85
+ elif name.startswith("vision_model.encoder.layers."):
86
+ param = state_dict[name]
87
+ names = name.split(".")
88
+ layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
89
+ name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
90
+ state_dict_[name_] = param
91
+ return state_dict_
92
+
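To make the layer-key rewrite in from_diffusers concrete, here is how one Diffusers key is decomposed and reassembled (purely illustrative, mirroring the loop above with a one-entry subset of attn_rename_dict):

name = "vision_model.encoder.layers.7.self_attn.q_proj.weight"
names = name.split(".")
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
# layer_id == "7", layer_type == "self_attn.q_proj", tail == "weight"
attn_rename_dict = {"self_attn.q_proj": "attn.to_q"}
new_name = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
assert new_name == "encoders.7.attn.to_q.weight"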
93
+ def from_civitai(self, state_dict):
94
+ rename_dict = {
95
+ "conditioner.embedders.0.open_clip.model.visual.class_embedding": "embeddings.class_embedding",
96
+ "conditioner.embedders.0.open_clip.model.visual.conv1.weight": "embeddings.patch_embedding.weight",
97
+ "conditioner.embedders.0.open_clip.model.visual.ln_post.bias": "post_layernorm.bias",
98
+ "conditioner.embedders.0.open_clip.model.visual.ln_post.weight": "post_layernorm.weight",
99
+ "conditioner.embedders.0.open_clip.model.visual.ln_pre.bias": "pre_layernorm.bias",
100
+ "conditioner.embedders.0.open_clip.model.visual.ln_pre.weight": "pre_layernorm.weight",
101
+ "conditioner.embedders.0.open_clip.model.visual.positional_embedding": "embeddings.position_embeds",
102
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
103
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
104
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
105
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
106
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
107
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
108
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
109
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
110
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
111
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
112
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
113
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
114
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
115
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
116
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
117
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
118
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
119
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
120
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
121
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
122
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
123
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
124
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
125
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
126
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
127
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
128
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
129
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
130
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
131
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
132
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
133
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
134
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
135
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
136
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
137
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
138
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
139
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
140
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
141
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
142
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
143
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
144
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
145
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
146
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
147
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
148
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
149
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
150
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
151
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
152
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
153
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
154
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
155
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
156
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
157
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
158
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
159
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
160
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
161
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
162
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
163
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
164
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
165
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
166
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
167
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
168
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
169
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
170
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
171
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
172
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
173
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
174
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
175
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
176
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
177
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
178
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
179
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
180
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
181
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
182
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
183
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
184
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
185
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
186
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
187
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
188
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
189
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
190
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
191
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
192
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
193
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
194
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
195
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
196
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
197
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
198
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
199
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
200
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
201
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
202
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
203
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
204
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
205
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
206
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
207
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
208
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
209
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
210
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
211
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
212
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
213
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
214
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
215
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
216
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
217
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
218
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
219
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
220
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
221
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
222
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
223
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
224
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
225
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
226
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
227
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
228
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
229
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
230
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
231
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
232
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
233
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
234
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
235
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
236
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
237
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
238
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
239
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
240
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
241
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
242
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
243
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
244
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
245
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
246
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
247
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
248
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
249
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
250
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
251
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
252
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
253
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
254
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
255
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
256
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
257
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
258
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
259
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
260
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
261
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
262
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
263
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
264
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
265
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
266
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
267
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
268
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
269
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
270
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
271
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
272
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
273
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
274
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
275
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
276
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
277
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
278
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
279
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
280
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
281
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
282
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
283
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
284
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
285
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
286
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
287
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
288
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
289
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
290
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
291
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
292
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
293
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
294
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
295
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
296
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
297
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
298
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
299
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
300
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
301
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
302
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
303
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
304
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
305
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
306
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
307
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
308
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
309
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
310
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
311
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
312
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
313
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
314
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
315
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
316
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
317
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
318
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
319
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
320
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
321
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
322
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
323
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
324
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
325
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
326
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
327
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
328
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
329
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
330
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
331
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
332
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
333
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
334
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
335
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
336
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
337
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
338
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
339
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
340
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
341
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
342
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
343
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
344
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
345
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
346
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
347
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
348
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
349
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
350
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
351
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
352
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
353
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
354
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
355
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
356
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
357
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
358
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
359
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
360
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
361
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
362
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
363
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
364
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
365
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
366
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
367
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
368
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
369
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
370
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
371
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
372
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
373
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
374
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
375
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
376
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
377
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
378
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
379
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
380
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
381
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
382
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
383
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
384
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
385
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
386
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
387
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
388
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
389
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
390
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
391
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
392
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
393
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
394
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
395
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
396
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
397
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
398
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
399
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
400
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
401
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
402
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
403
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
404
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
405
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
406
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
407
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
408
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
409
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
410
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
411
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
412
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
413
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
414
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
415
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
416
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
417
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
418
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
419
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
420
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
421
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
422
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
423
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
424
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
425
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
426
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
427
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
428
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
429
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
430
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
431
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
432
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
433
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
434
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
435
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
436
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
437
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
438
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
439
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
440
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
441
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
442
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
443
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
444
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
445
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
446
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
447
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
448
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
449
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
450
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
451
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
452
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
453
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
454
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
455
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
456
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
457
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
458
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
459
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
460
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
461
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
462
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
463
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
464
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
465
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
466
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
467
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
468
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
469
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
470
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
471
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
472
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
473
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
474
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
475
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
476
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
477
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
478
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
479
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
480
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
481
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
482
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
483
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
484
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
485
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
486
+ "conditioner.embedders.0.open_clip.model.visual.proj": "visual_projection.weight",
487
+ }
488
+ state_dict_ = {}
489
+ for name in state_dict:
490
+ if name in rename_dict:
491
+ param = state_dict[name]
492
+ if name == "conditioner.embedders.0.open_clip.model.visual.class_embedding":
493
+ param = param.reshape((1, 1, param.shape[0]))
494
+ elif name == "conditioner.embedders.0.open_clip.model.visual.positional_embedding":
495
+ param = param.reshape((1, param.shape[0], param.shape[1]))
496
+ elif name == "conditioner.embedders.0.open_clip.model.visual.proj":
497
+ param = param.T
498
+ if isinstance(rename_dict[name], str):
499
+ state_dict_[rename_dict[name]] = param
500
+ else:
501
+ length = param.shape[0] // 3
502
+ for i, rename in enumerate(rename_dict[name]):
503
+ state_dict_[rename] = param[i*length: i*length+length]
504
+ return state_dict_
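Note on the q/k/v split above: OpenCLIP stores the attention projections as a single fused `in_proj_weight` / `in_proj_bias`, stacked along dim 0 in q, k, v order, so the converter slices each fused tensor into thirds before assigning the separate `to_q` / `to_k` / `to_v` names. A minimal, self-contained sketch of the same split (the `dim` value and variable names are illustrative, not taken from the repository):

import torch

# Hypothetical fused projection: rows [0, dim) -> q, [dim, 2*dim) -> k, [2*dim, 3*dim) -> v.
dim = 1280
in_proj_weight = torch.randn(3 * dim, dim)

# Equivalent to the manual slicing in the loop above (param[i*length : i*length+length]).
q_w, k_w, v_w = in_proj_weight.chunk(3, dim=0)
assert q_w.shape == (dim, dim) and k_w.shape == (dim, dim) and v_w.shape == (dim, dim)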
diffsynth/models/svd_unet.py ADDED
The diff for this file is too large to render. See raw diff
diffsynth/models/svd_vae_decoder.py ADDED
@@ -0,0 +1,577 @@
1
+ import torch
2
+ from .attention import Attention
3
+ from .sd_unet import ResnetBlock, UpSampler
4
+ from .tiler import TileWorker
5
+ from einops import rearrange, repeat
6
+
7
+
8
+ class VAEAttentionBlock(torch.nn.Module):
9
+
10
+ def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
11
+ super().__init__()
12
+ inner_dim = num_attention_heads * attention_head_dim
13
+
14
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
15
+
16
+ self.transformer_blocks = torch.nn.ModuleList([
17
+ Attention(
18
+ inner_dim,
19
+ num_attention_heads,
20
+ attention_head_dim,
21
+ bias_q=True,
22
+ bias_kv=True,
23
+ bias_out=True
24
+ )
25
+ for d in range(num_layers)
26
+ ])
27
+
28
+ def forward(self, hidden_states, time_emb, text_emb, res_stack):
29
+ batch, _, height, width = hidden_states.shape
30
+ residual = hidden_states
31
+
32
+ hidden_states = self.norm(hidden_states)
33
+ inner_dim = hidden_states.shape[1]
34
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
35
+
36
+ for block in self.transformer_blocks:
37
+ hidden_states = block(hidden_states)
38
+
39
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
40
+ hidden_states = hidden_states + residual
41
+
42
+ return hidden_states, time_emb, text_emb, res_stack
43
+
44
+
45
+ class TemporalResnetBlock(torch.nn.Module):
46
+
47
+ def __init__(self, in_channels, out_channels, groups=32, eps=1e-5):
48
+ super().__init__()
49
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
50
+ self.conv1 = torch.nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
51
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
52
+ self.conv2 = torch.nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
53
+ self.nonlinearity = torch.nn.SiLU()
54
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([0.5]))
55
+
56
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
57
+ x_spatial = hidden_states
58
+ x = rearrange(hidden_states, "T C H W -> 1 C T H W")
59
+ x = self.norm1(x)
60
+ x = self.nonlinearity(x)
61
+ x = self.conv1(x)
62
+ x = self.norm2(x)
63
+ x = self.nonlinearity(x)
64
+ x = self.conv2(x)
65
+ x_temporal = hidden_states + x[0].permute(1, 0, 2, 3)
66
+ alpha = torch.sigmoid(self.mix_factor)
67
+ hidden_states = alpha * x_temporal + (1 - alpha) * x_spatial
68
+ return hidden_states, time_emb, text_emb, res_stack
69
+
70
+
71
+ class SVDVAEDecoder(torch.nn.Module):
72
+ def __init__(self):
73
+ super().__init__()
74
+ self.scaling_factor = 0.18215
75
+ self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)
76
+
77
+ self.blocks = torch.nn.ModuleList([
78
+ # UNetMidBlock
79
+ ResnetBlock(512, 512, eps=1e-6),
80
+ TemporalResnetBlock(512, 512, eps=1e-6),
81
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
82
+ ResnetBlock(512, 512, eps=1e-6),
83
+ TemporalResnetBlock(512, 512, eps=1e-6),
84
+ # UpDecoderBlock
85
+ ResnetBlock(512, 512, eps=1e-6),
86
+ TemporalResnetBlock(512, 512, eps=1e-6),
87
+ ResnetBlock(512, 512, eps=1e-6),
88
+ TemporalResnetBlock(512, 512, eps=1e-6),
89
+ ResnetBlock(512, 512, eps=1e-6),
90
+ TemporalResnetBlock(512, 512, eps=1e-6),
91
+ UpSampler(512),
92
+ # UpDecoderBlock
93
+ ResnetBlock(512, 512, eps=1e-6),
94
+ TemporalResnetBlock(512, 512, eps=1e-6),
95
+ ResnetBlock(512, 512, eps=1e-6),
96
+ TemporalResnetBlock(512, 512, eps=1e-6),
97
+ ResnetBlock(512, 512, eps=1e-6),
98
+ TemporalResnetBlock(512, 512, eps=1e-6),
99
+ UpSampler(512),
100
+ # UpDecoderBlock
101
+ ResnetBlock(512, 256, eps=1e-6),
102
+ TemporalResnetBlock(256, 256, eps=1e-6),
103
+ ResnetBlock(256, 256, eps=1e-6),
104
+ TemporalResnetBlock(256, 256, eps=1e-6),
105
+ ResnetBlock(256, 256, eps=1e-6),
106
+ TemporalResnetBlock(256, 256, eps=1e-6),
107
+ UpSampler(256),
108
+ # UpDecoderBlock
109
+ ResnetBlock(256, 128, eps=1e-6),
110
+ TemporalResnetBlock(128, 128, eps=1e-6),
111
+ ResnetBlock(128, 128, eps=1e-6),
112
+ TemporalResnetBlock(128, 128, eps=1e-6),
113
+ ResnetBlock(128, 128, eps=1e-6),
114
+ TemporalResnetBlock(128, 128, eps=1e-6),
115
+ ])
116
+
117
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
118
+ self.conv_act = torch.nn.SiLU()
119
+ self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
120
+ self.time_conv_out = torch.nn.Conv3d(3, 3, kernel_size=(3, 1, 1), padding=(1, 0, 0))
121
+
122
+
123
+ def forward(self, sample):
124
+ # 1. pre-process
125
+ hidden_states = rearrange(sample, "C T H W -> T C H W")
126
+ hidden_states = hidden_states / self.scaling_factor
127
+ hidden_states = self.conv_in(hidden_states)
128
+ time_emb, text_emb, res_stack = None, None, None
129
+
130
+ # 2. blocks
131
+ for i, block in enumerate(self.blocks):
132
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
133
+
134
+ # 3. output
135
+ hidden_states = self.conv_norm_out(hidden_states)
136
+ hidden_states = self.conv_act(hidden_states)
137
+ hidden_states = self.conv_out(hidden_states)
138
+ hidden_states = rearrange(hidden_states, "T C H W -> C T H W")
139
+ hidden_states = self.time_conv_out(hidden_states)
140
+
141
+ return hidden_states
142
+
143
+
144
+ def build_mask(self, data, is_bound):
145
+ _, T, H, W = data.shape
146
+ t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
147
+ h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
148
+ w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
149
+ border_width = (T + H + W) // 6
150
+ pad = torch.ones_like(t) * border_width
151
+ mask = torch.stack([
152
+ pad if is_bound[0] else t + 1,
153
+ pad if is_bound[1] else T - t,
154
+ pad if is_bound[2] else h + 1,
155
+ pad if is_bound[3] else H - h,
156
+ pad if is_bound[4] else w + 1,
157
+ pad if is_bound[5] else W - w
158
+ ]).min(dim=0).values
159
+ mask = mask.clip(1, border_width)
160
+ mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
161
+ mask = rearrange(mask, "T H W -> 1 T H W")
162
+ return mask
163
+
164
+
165
+ def decode_video(
166
+ self, sample,
167
+ batch_time=8, batch_height=128, batch_width=128,
168
+ stride_time=4, stride_height=32, stride_width=32,
169
+ progress_bar=lambda x:x
170
+ ):
171
+ sample = sample.permute(1, 0, 2, 3)
172
+ data_device = sample.device
173
+ computation_device = self.conv_in.weight.device
174
+ torch_dtype = sample.dtype
175
+ _, T, H, W = sample.shape
176
+
177
+ weight = torch.zeros((1, T, H*8, W*8), dtype=torch_dtype, device=data_device)
178
+ values = torch.zeros((3, T, H*8, W*8), dtype=torch_dtype, device=data_device)
179
+
180
+ # Split tasks
181
+ tasks = []
182
+ for t in range(0, T, stride_time):
183
+ for h in range(0, H, stride_height):
184
+ for w in range(0, W, stride_width):
185
+ if (t-stride_time >= 0 and t-stride_time+batch_time >= T)\
186
+ or (h-stride_height >= 0 and h-stride_height+batch_height >= H)\
187
+ or (w-stride_width >= 0 and w-stride_width+batch_width >= W):
188
+ continue
189
+ tasks.append((t, t+batch_time, h, h+batch_height, w, w+batch_width))
190
+
191
+ # Run
192
+ for tl, tr, hl, hr, wl, wr in progress_bar(tasks):
193
+ sample_batch = sample[:, tl:tr, hl:hr, wl:wr].to(computation_device)
194
+ sample_batch = self.forward(sample_batch).to(data_device)
195
+ mask = self.build_mask(sample_batch, is_bound=(tl==0, tr>=T, hl==0, hr>=H, wl==0, wr>=W))
196
+ values[:, tl:tr, hl*8:hr*8, wl*8:wr*8] += sample_batch * mask
197
+ weight[:, tl:tr, hl*8:hr*8, wl*8:wr*8] += mask
198
+ values /= weight
199
+ return values
200
+
201
+
202
+ def state_dict_converter(self):
203
+ return SVDVAEDecoderStateDictConverter()
204
+
205
+
206
+ class SVDVAEDecoderStateDictConverter:
207
+ def __init__(self):
208
+ pass
209
+
210
+ def from_diffusers(self, state_dict):
211
+ static_rename_dict = {
212
+ "decoder.conv_in": "conv_in",
213
+ "decoder.mid_block.attentions.0.group_norm": "blocks.2.norm",
214
+ "decoder.mid_block.attentions.0.to_q": "blocks.2.transformer_blocks.0.to_q",
215
+ "decoder.mid_block.attentions.0.to_k": "blocks.2.transformer_blocks.0.to_k",
216
+ "decoder.mid_block.attentions.0.to_v": "blocks.2.transformer_blocks.0.to_v",
217
+ "decoder.mid_block.attentions.0.to_out.0": "blocks.2.transformer_blocks.0.to_out",
218
+ "decoder.up_blocks.0.upsamplers.0.conv": "blocks.11.conv",
219
+ "decoder.up_blocks.1.upsamplers.0.conv": "blocks.18.conv",
220
+ "decoder.up_blocks.2.upsamplers.0.conv": "blocks.25.conv",
221
+ "decoder.conv_norm_out": "conv_norm_out",
222
+ "decoder.conv_out": "conv_out",
223
+ "decoder.time_conv_out": "time_conv_out"
224
+ }
225
+ prefix_rename_dict = {
226
+ "decoder.mid_block.resnets.0.spatial_res_block": "blocks.0",
227
+ "decoder.mid_block.resnets.0.temporal_res_block": "blocks.1",
228
+ "decoder.mid_block.resnets.0.time_mixer": "blocks.1",
229
+ "decoder.mid_block.resnets.1.spatial_res_block": "blocks.3",
230
+ "decoder.mid_block.resnets.1.temporal_res_block": "blocks.4",
231
+ "decoder.mid_block.resnets.1.time_mixer": "blocks.4",
232
+
233
+ "decoder.up_blocks.0.resnets.0.spatial_res_block": "blocks.5",
234
+ "decoder.up_blocks.0.resnets.0.temporal_res_block": "blocks.6",
235
+ "decoder.up_blocks.0.resnets.0.time_mixer": "blocks.6",
236
+ "decoder.up_blocks.0.resnets.1.spatial_res_block": "blocks.7",
237
+ "decoder.up_blocks.0.resnets.1.temporal_res_block": "blocks.8",
238
+ "decoder.up_blocks.0.resnets.1.time_mixer": "blocks.8",
239
+ "decoder.up_blocks.0.resnets.2.spatial_res_block": "blocks.9",
240
+ "decoder.up_blocks.0.resnets.2.temporal_res_block": "blocks.10",
241
+ "decoder.up_blocks.0.resnets.2.time_mixer": "blocks.10",
242
+
243
+ "decoder.up_blocks.1.resnets.0.spatial_res_block": "blocks.12",
244
+ "decoder.up_blocks.1.resnets.0.temporal_res_block": "blocks.13",
245
+ "decoder.up_blocks.1.resnets.0.time_mixer": "blocks.13",
246
+ "decoder.up_blocks.1.resnets.1.spatial_res_block": "blocks.14",
247
+ "decoder.up_blocks.1.resnets.1.temporal_res_block": "blocks.15",
248
+ "decoder.up_blocks.1.resnets.1.time_mixer": "blocks.15",
249
+ "decoder.up_blocks.1.resnets.2.spatial_res_block": "blocks.16",
250
+ "decoder.up_blocks.1.resnets.2.temporal_res_block": "blocks.17",
251
+ "decoder.up_blocks.1.resnets.2.time_mixer": "blocks.17",
252
+
253
+ "decoder.up_blocks.2.resnets.0.spatial_res_block": "blocks.19",
254
+ "decoder.up_blocks.2.resnets.0.temporal_res_block": "blocks.20",
255
+ "decoder.up_blocks.2.resnets.0.time_mixer": "blocks.20",
256
+ "decoder.up_blocks.2.resnets.1.spatial_res_block": "blocks.21",
257
+ "decoder.up_blocks.2.resnets.1.temporal_res_block": "blocks.22",
258
+ "decoder.up_blocks.2.resnets.1.time_mixer": "blocks.22",
259
+ "decoder.up_blocks.2.resnets.2.spatial_res_block": "blocks.23",
260
+ "decoder.up_blocks.2.resnets.2.temporal_res_block": "blocks.24",
261
+ "decoder.up_blocks.2.resnets.2.time_mixer": "blocks.24",
262
+
263
+ "decoder.up_blocks.3.resnets.0.spatial_res_block": "blocks.26",
264
+ "decoder.up_blocks.3.resnets.0.temporal_res_block": "blocks.27",
265
+ "decoder.up_blocks.3.resnets.0.time_mixer": "blocks.27",
266
+ "decoder.up_blocks.3.resnets.1.spatial_res_block": "blocks.28",
267
+ "decoder.up_blocks.3.resnets.1.temporal_res_block": "blocks.29",
268
+ "decoder.up_blocks.3.resnets.1.time_mixer": "blocks.29",
269
+ "decoder.up_blocks.3.resnets.2.spatial_res_block": "blocks.30",
270
+ "decoder.up_blocks.3.resnets.2.temporal_res_block": "blocks.31",
271
+ "decoder.up_blocks.3.resnets.2.time_mixer": "blocks.31",
272
+ }
273
+ suffix_rename_dict = {
274
+ "norm1.weight": "norm1.weight",
275
+ "conv1.weight": "conv1.weight",
276
+ "norm2.weight": "norm2.weight",
277
+ "conv2.weight": "conv2.weight",
278
+ "conv_shortcut.weight": "conv_shortcut.weight",
279
+ "norm1.bias": "norm1.bias",
280
+ "conv1.bias": "conv1.bias",
281
+ "norm2.bias": "norm2.bias",
282
+ "conv2.bias": "conv2.bias",
283
+ "conv_shortcut.bias": "conv_shortcut.bias",
284
+ "mix_factor": "mix_factor",
285
+ }
286
+
287
+ state_dict_ = {}
288
+ for name in static_rename_dict:
289
+ state_dict_[static_rename_dict[name] + ".weight"] = state_dict[name + ".weight"]
290
+ state_dict_[static_rename_dict[name] + ".bias"] = state_dict[name + ".bias"]
291
+ for prefix_name in prefix_rename_dict:
292
+ for suffix_name in suffix_rename_dict:
293
+ name = prefix_name + "." + suffix_name
294
+ name_ = prefix_rename_dict[prefix_name] + "." + suffix_rename_dict[suffix_name]
295
+ if name in state_dict:
296
+ state_dict_[name_] = state_dict[name]
297
+
298
+ return state_dict_
299
+
300
+
301
+ def from_civitai(self, state_dict):
302
+ rename_dict = {
303
+ "first_stage_model.decoder.conv_in.bias": "conv_in.bias",
304
+ "first_stage_model.decoder.conv_in.weight": "conv_in.weight",
305
+ "first_stage_model.decoder.conv_out.bias": "conv_out.bias",
306
+ "first_stage_model.decoder.conv_out.time_mix_conv.bias": "time_conv_out.bias",
307
+ "first_stage_model.decoder.conv_out.time_mix_conv.weight": "time_conv_out.weight",
308
+ "first_stage_model.decoder.conv_out.weight": "conv_out.weight",
309
+ "first_stage_model.decoder.mid.attn_1.k.bias": "blocks.2.transformer_blocks.0.to_k.bias",
310
+ "first_stage_model.decoder.mid.attn_1.k.weight": "blocks.2.transformer_blocks.0.to_k.weight",
311
+ "first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.2.norm.bias",
312
+ "first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.2.norm.weight",
313
+ "first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.2.transformer_blocks.0.to_out.bias",
314
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.2.transformer_blocks.0.to_out.weight",
315
+ "first_stage_model.decoder.mid.attn_1.q.bias": "blocks.2.transformer_blocks.0.to_q.bias",
316
+ "first_stage_model.decoder.mid.attn_1.q.weight": "blocks.2.transformer_blocks.0.to_q.weight",
317
+ "first_stage_model.decoder.mid.attn_1.v.bias": "blocks.2.transformer_blocks.0.to_v.bias",
318
+ "first_stage_model.decoder.mid.attn_1.v.weight": "blocks.2.transformer_blocks.0.to_v.weight",
319
+ "first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
320
+ "first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
321
+ "first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
322
+ "first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
323
+ "first_stage_model.decoder.mid.block_1.mix_factor": "blocks.1.mix_factor",
324
+ "first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
325
+ "first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
326
+ "first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
327
+ "first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
328
+ "first_stage_model.decoder.mid.block_1.time_stack.in_layers.0.bias": "blocks.1.norm1.bias",
329
+ "first_stage_model.decoder.mid.block_1.time_stack.in_layers.0.weight": "blocks.1.norm1.weight",
330
+ "first_stage_model.decoder.mid.block_1.time_stack.in_layers.2.bias": "blocks.1.conv1.bias",
331
+ "first_stage_model.decoder.mid.block_1.time_stack.in_layers.2.weight": "blocks.1.conv1.weight",
332
+ "first_stage_model.decoder.mid.block_1.time_stack.out_layers.0.bias": "blocks.1.norm2.bias",
333
+ "first_stage_model.decoder.mid.block_1.time_stack.out_layers.0.weight": "blocks.1.norm2.weight",
334
+ "first_stage_model.decoder.mid.block_1.time_stack.out_layers.3.bias": "blocks.1.conv2.bias",
335
+ "first_stage_model.decoder.mid.block_1.time_stack.out_layers.3.weight": "blocks.1.conv2.weight",
336
+ "first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.3.conv1.bias",
337
+ "first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.3.conv1.weight",
338
+ "first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.3.conv2.bias",
339
+ "first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.3.conv2.weight",
340
+ "first_stage_model.decoder.mid.block_2.mix_factor": "blocks.4.mix_factor",
341
+ "first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.3.norm1.bias",
342
+ "first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.3.norm1.weight",
343
+ "first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.3.norm2.bias",
344
+ "first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.3.norm2.weight",
345
+ "first_stage_model.decoder.mid.block_2.time_stack.in_layers.0.bias": "blocks.4.norm1.bias",
346
+ "first_stage_model.decoder.mid.block_2.time_stack.in_layers.0.weight": "blocks.4.norm1.weight",
347
+ "first_stage_model.decoder.mid.block_2.time_stack.in_layers.2.bias": "blocks.4.conv1.bias",
348
+ "first_stage_model.decoder.mid.block_2.time_stack.in_layers.2.weight": "blocks.4.conv1.weight",
349
+ "first_stage_model.decoder.mid.block_2.time_stack.out_layers.0.bias": "blocks.4.norm2.bias",
350
+ "first_stage_model.decoder.mid.block_2.time_stack.out_layers.0.weight": "blocks.4.norm2.weight",
351
+ "first_stage_model.decoder.mid.block_2.time_stack.out_layers.3.bias": "blocks.4.conv2.bias",
352
+ "first_stage_model.decoder.mid.block_2.time_stack.out_layers.3.weight": "blocks.4.conv2.weight",
353
+ "first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias",
354
+ "first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight",
355
+ "first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.26.conv1.bias",
356
+ "first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.26.conv1.weight",
357
+ "first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.26.conv2.bias",
358
+ "first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.26.conv2.weight",
359
+ "first_stage_model.decoder.up.0.block.0.mix_factor": "blocks.27.mix_factor",
360
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.26.conv_shortcut.bias",
361
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.26.conv_shortcut.weight",
362
+ "first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.26.norm1.bias",
363
+ "first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.26.norm1.weight",
364
+ "first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.26.norm2.bias",
365
+ "first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.26.norm2.weight",
366
+ "first_stage_model.decoder.up.0.block.0.time_stack.in_layers.0.bias": "blocks.27.norm1.bias",
367
+ "first_stage_model.decoder.up.0.block.0.time_stack.in_layers.0.weight": "blocks.27.norm1.weight",
368
+ "first_stage_model.decoder.up.0.block.0.time_stack.in_layers.2.bias": "blocks.27.conv1.bias",
369
+ "first_stage_model.decoder.up.0.block.0.time_stack.in_layers.2.weight": "blocks.27.conv1.weight",
370
+ "first_stage_model.decoder.up.0.block.0.time_stack.out_layers.0.bias": "blocks.27.norm2.bias",
371
+ "first_stage_model.decoder.up.0.block.0.time_stack.out_layers.0.weight": "blocks.27.norm2.weight",
372
+ "first_stage_model.decoder.up.0.block.0.time_stack.out_layers.3.bias": "blocks.27.conv2.bias",
373
+ "first_stage_model.decoder.up.0.block.0.time_stack.out_layers.3.weight": "blocks.27.conv2.weight",
374
+ "first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.28.conv1.bias",
375
+ "first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.28.conv1.weight",
376
+ "first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.28.conv2.bias",
377
+ "first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.28.conv2.weight",
378
+ "first_stage_model.decoder.up.0.block.1.mix_factor": "blocks.29.mix_factor",
379
+ "first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.28.norm1.bias",
380
+ "first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.28.norm1.weight",
381
+ "first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.28.norm2.bias",
382
+ "first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.28.norm2.weight",
383
+ "first_stage_model.decoder.up.0.block.1.time_stack.in_layers.0.bias": "blocks.29.norm1.bias",
384
+ "first_stage_model.decoder.up.0.block.1.time_stack.in_layers.0.weight": "blocks.29.norm1.weight",
385
+ "first_stage_model.decoder.up.0.block.1.time_stack.in_layers.2.bias": "blocks.29.conv1.bias",
386
+ "first_stage_model.decoder.up.0.block.1.time_stack.in_layers.2.weight": "blocks.29.conv1.weight",
387
+ "first_stage_model.decoder.up.0.block.1.time_stack.out_layers.0.bias": "blocks.29.norm2.bias",
388
+ "first_stage_model.decoder.up.0.block.1.time_stack.out_layers.0.weight": "blocks.29.norm2.weight",
389
+ "first_stage_model.decoder.up.0.block.1.time_stack.out_layers.3.bias": "blocks.29.conv2.bias",
390
+ "first_stage_model.decoder.up.0.block.1.time_stack.out_layers.3.weight": "blocks.29.conv2.weight",
391
+ "first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.30.conv1.bias",
392
+ "first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.30.conv1.weight",
393
+ "first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.30.conv2.bias",
394
+ "first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.30.conv2.weight",
395
+ "first_stage_model.decoder.up.0.block.2.mix_factor": "blocks.31.mix_factor",
396
+ "first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.30.norm1.bias",
397
+ "first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.30.norm1.weight",
398
+ "first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.30.norm2.bias",
399
+ "first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.30.norm2.weight",
400
+ "first_stage_model.decoder.up.0.block.2.time_stack.in_layers.0.bias": "blocks.31.norm1.bias",
401
+ "first_stage_model.decoder.up.0.block.2.time_stack.in_layers.0.weight": "blocks.31.norm1.weight",
402
+ "first_stage_model.decoder.up.0.block.2.time_stack.in_layers.2.bias": "blocks.31.conv1.bias",
403
+ "first_stage_model.decoder.up.0.block.2.time_stack.in_layers.2.weight": "blocks.31.conv1.weight",
404
+ "first_stage_model.decoder.up.0.block.2.time_stack.out_layers.0.bias": "blocks.31.norm2.bias",
405
+ "first_stage_model.decoder.up.0.block.2.time_stack.out_layers.0.weight": "blocks.31.norm2.weight",
406
+ "first_stage_model.decoder.up.0.block.2.time_stack.out_layers.3.bias": "blocks.31.conv2.bias",
407
+ "first_stage_model.decoder.up.0.block.2.time_stack.out_layers.3.weight": "blocks.31.conv2.weight",
408
+ "first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.19.conv1.bias",
409
+ "first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.19.conv1.weight",
410
+ "first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.19.conv2.bias",
411
+ "first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.19.conv2.weight",
412
+ "first_stage_model.decoder.up.1.block.0.mix_factor": "blocks.20.mix_factor",
413
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.19.conv_shortcut.bias",
414
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.19.conv_shortcut.weight",
415
+ "first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.19.norm1.bias",
416
+ "first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.19.norm1.weight",
417
+ "first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.19.norm2.bias",
418
+ "first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.19.norm2.weight",
419
+ "first_stage_model.decoder.up.1.block.0.time_stack.in_layers.0.bias": "blocks.20.norm1.bias",
420
+ "first_stage_model.decoder.up.1.block.0.time_stack.in_layers.0.weight": "blocks.20.norm1.weight",
421
+ "first_stage_model.decoder.up.1.block.0.time_stack.in_layers.2.bias": "blocks.20.conv1.bias",
422
+ "first_stage_model.decoder.up.1.block.0.time_stack.in_layers.2.weight": "blocks.20.conv1.weight",
423
+ "first_stage_model.decoder.up.1.block.0.time_stack.out_layers.0.bias": "blocks.20.norm2.bias",
424
+ "first_stage_model.decoder.up.1.block.0.time_stack.out_layers.0.weight": "blocks.20.norm2.weight",
425
+ "first_stage_model.decoder.up.1.block.0.time_stack.out_layers.3.bias": "blocks.20.conv2.bias",
426
+ "first_stage_model.decoder.up.1.block.0.time_stack.out_layers.3.weight": "blocks.20.conv2.weight",
427
+ "first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.21.conv1.bias",
428
+ "first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.21.conv1.weight",
429
+ "first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.21.conv2.bias",
430
+ "first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.21.conv2.weight",
431
+ "first_stage_model.decoder.up.1.block.1.mix_factor": "blocks.22.mix_factor",
432
+ "first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.21.norm1.bias",
433
+ "first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.21.norm1.weight",
434
+ "first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.21.norm2.bias",
435
+ "first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.21.norm2.weight",
436
+ "first_stage_model.decoder.up.1.block.1.time_stack.in_layers.0.bias": "blocks.22.norm1.bias",
437
+ "first_stage_model.decoder.up.1.block.1.time_stack.in_layers.0.weight": "blocks.22.norm1.weight",
438
+ "first_stage_model.decoder.up.1.block.1.time_stack.in_layers.2.bias": "blocks.22.conv1.bias",
439
+ "first_stage_model.decoder.up.1.block.1.time_stack.in_layers.2.weight": "blocks.22.conv1.weight",
440
+ "first_stage_model.decoder.up.1.block.1.time_stack.out_layers.0.bias": "blocks.22.norm2.bias",
441
+ "first_stage_model.decoder.up.1.block.1.time_stack.out_layers.0.weight": "blocks.22.norm2.weight",
442
+ "first_stage_model.decoder.up.1.block.1.time_stack.out_layers.3.bias": "blocks.22.conv2.bias",
443
+ "first_stage_model.decoder.up.1.block.1.time_stack.out_layers.3.weight": "blocks.22.conv2.weight",
444
+ "first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.23.conv1.bias",
445
+ "first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.23.conv1.weight",
446
+ "first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.23.conv2.bias",
447
+ "first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.23.conv2.weight",
448
+ "first_stage_model.decoder.up.1.block.2.mix_factor": "blocks.24.mix_factor",
449
+ "first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.23.norm1.bias",
450
+ "first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.23.norm1.weight",
451
+ "first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.23.norm2.bias",
452
+ "first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.23.norm2.weight",
453
+ "first_stage_model.decoder.up.1.block.2.time_stack.in_layers.0.bias": "blocks.24.norm1.bias",
454
+ "first_stage_model.decoder.up.1.block.2.time_stack.in_layers.0.weight": "blocks.24.norm1.weight",
455
+ "first_stage_model.decoder.up.1.block.2.time_stack.in_layers.2.bias": "blocks.24.conv1.bias",
456
+ "first_stage_model.decoder.up.1.block.2.time_stack.in_layers.2.weight": "blocks.24.conv1.weight",
457
+ "first_stage_model.decoder.up.1.block.2.time_stack.out_layers.0.bias": "blocks.24.norm2.bias",
458
+ "first_stage_model.decoder.up.1.block.2.time_stack.out_layers.0.weight": "blocks.24.norm2.weight",
459
+ "first_stage_model.decoder.up.1.block.2.time_stack.out_layers.3.bias": "blocks.24.conv2.bias",
460
+ "first_stage_model.decoder.up.1.block.2.time_stack.out_layers.3.weight": "blocks.24.conv2.weight",
461
+ "first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.25.conv.bias",
462
+ "first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.25.conv.weight",
463
+ "first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.12.conv1.bias",
464
+ "first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.12.conv1.weight",
465
+ "first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.12.conv2.bias",
466
+ "first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.12.conv2.weight",
467
+ "first_stage_model.decoder.up.2.block.0.mix_factor": "blocks.13.mix_factor",
468
+ "first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.12.norm1.bias",
469
+ "first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.12.norm1.weight",
470
+ "first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.12.norm2.bias",
471
+ "first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.12.norm2.weight",
472
+ "first_stage_model.decoder.up.2.block.0.time_stack.in_layers.0.bias": "blocks.13.norm1.bias",
473
+ "first_stage_model.decoder.up.2.block.0.time_stack.in_layers.0.weight": "blocks.13.norm1.weight",
474
+ "first_stage_model.decoder.up.2.block.0.time_stack.in_layers.2.bias": "blocks.13.conv1.bias",
475
+ "first_stage_model.decoder.up.2.block.0.time_stack.in_layers.2.weight": "blocks.13.conv1.weight",
476
+ "first_stage_model.decoder.up.2.block.0.time_stack.out_layers.0.bias": "blocks.13.norm2.bias",
477
+ "first_stage_model.decoder.up.2.block.0.time_stack.out_layers.0.weight": "blocks.13.norm2.weight",
478
+ "first_stage_model.decoder.up.2.block.0.time_stack.out_layers.3.bias": "blocks.13.conv2.bias",
479
+ "first_stage_model.decoder.up.2.block.0.time_stack.out_layers.3.weight": "blocks.13.conv2.weight",
480
+ "first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.14.conv1.bias",
481
+ "first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.14.conv1.weight",
482
+ "first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.14.conv2.bias",
483
+ "first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.14.conv2.weight",
484
+ "first_stage_model.decoder.up.2.block.1.mix_factor": "blocks.15.mix_factor",
485
+ "first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.14.norm1.bias",
486
+ "first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.14.norm1.weight",
487
+ "first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.14.norm2.bias",
488
+ "first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.14.norm2.weight",
489
+ "first_stage_model.decoder.up.2.block.1.time_stack.in_layers.0.bias": "blocks.15.norm1.bias",
490
+ "first_stage_model.decoder.up.2.block.1.time_stack.in_layers.0.weight": "blocks.15.norm1.weight",
491
+ "first_stage_model.decoder.up.2.block.1.time_stack.in_layers.2.bias": "blocks.15.conv1.bias",
492
+ "first_stage_model.decoder.up.2.block.1.time_stack.in_layers.2.weight": "blocks.15.conv1.weight",
493
+ "first_stage_model.decoder.up.2.block.1.time_stack.out_layers.0.bias": "blocks.15.norm2.bias",
494
+ "first_stage_model.decoder.up.2.block.1.time_stack.out_layers.0.weight": "blocks.15.norm2.weight",
495
+ "first_stage_model.decoder.up.2.block.1.time_stack.out_layers.3.bias": "blocks.15.conv2.bias",
496
+ "first_stage_model.decoder.up.2.block.1.time_stack.out_layers.3.weight": "blocks.15.conv2.weight",
497
+ "first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.16.conv1.bias",
498
+ "first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.16.conv1.weight",
499
+ "first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.16.conv2.bias",
500
+ "first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.16.conv2.weight",
501
+ "first_stage_model.decoder.up.2.block.2.mix_factor": "blocks.17.mix_factor",
502
+ "first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.16.norm1.bias",
503
+ "first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.16.norm1.weight",
504
+ "first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.16.norm2.bias",
505
+ "first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.16.norm2.weight",
506
+ "first_stage_model.decoder.up.2.block.2.time_stack.in_layers.0.bias": "blocks.17.norm1.bias",
507
+ "first_stage_model.decoder.up.2.block.2.time_stack.in_layers.0.weight": "blocks.17.norm1.weight",
508
+ "first_stage_model.decoder.up.2.block.2.time_stack.in_layers.2.bias": "blocks.17.conv1.bias",
509
+ "first_stage_model.decoder.up.2.block.2.time_stack.in_layers.2.weight": "blocks.17.conv1.weight",
510
+ "first_stage_model.decoder.up.2.block.2.time_stack.out_layers.0.bias": "blocks.17.norm2.bias",
511
+ "first_stage_model.decoder.up.2.block.2.time_stack.out_layers.0.weight": "blocks.17.norm2.weight",
512
+ "first_stage_model.decoder.up.2.block.2.time_stack.out_layers.3.bias": "blocks.17.conv2.bias",
513
+ "first_stage_model.decoder.up.2.block.2.time_stack.out_layers.3.weight": "blocks.17.conv2.weight",
514
+ "first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.18.conv.bias",
515
+ "first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.18.conv.weight",
516
+ "first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.5.conv1.bias",
517
+ "first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.5.conv1.weight",
518
+ "first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.5.conv2.bias",
519
+ "first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.5.conv2.weight",
520
+ "first_stage_model.decoder.up.3.block.0.mix_factor": "blocks.6.mix_factor",
521
+ "first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.5.norm1.bias",
522
+ "first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.5.norm1.weight",
523
+ "first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.5.norm2.bias",
524
+ "first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.5.norm2.weight",
525
+ "first_stage_model.decoder.up.3.block.0.time_stack.in_layers.0.bias": "blocks.6.norm1.bias",
526
+ "first_stage_model.decoder.up.3.block.0.time_stack.in_layers.0.weight": "blocks.6.norm1.weight",
527
+ "first_stage_model.decoder.up.3.block.0.time_stack.in_layers.2.bias": "blocks.6.conv1.bias",
528
+ "first_stage_model.decoder.up.3.block.0.time_stack.in_layers.2.weight": "blocks.6.conv1.weight",
529
+ "first_stage_model.decoder.up.3.block.0.time_stack.out_layers.0.bias": "blocks.6.norm2.bias",
530
+ "first_stage_model.decoder.up.3.block.0.time_stack.out_layers.0.weight": "blocks.6.norm2.weight",
531
+ "first_stage_model.decoder.up.3.block.0.time_stack.out_layers.3.bias": "blocks.6.conv2.bias",
532
+ "first_stage_model.decoder.up.3.block.0.time_stack.out_layers.3.weight": "blocks.6.conv2.weight",
533
+ "first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.7.conv1.bias",
534
+ "first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.7.conv1.weight",
535
+ "first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.7.conv2.bias",
536
+ "first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.7.conv2.weight",
537
+ "first_stage_model.decoder.up.3.block.1.mix_factor": "blocks.8.mix_factor",
538
+ "first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.7.norm1.bias",
539
+ "first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.7.norm1.weight",
540
+ "first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.7.norm2.bias",
541
+ "first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.7.norm2.weight",
542
+ "first_stage_model.decoder.up.3.block.1.time_stack.in_layers.0.bias": "blocks.8.norm1.bias",
543
+ "first_stage_model.decoder.up.3.block.1.time_stack.in_layers.0.weight": "blocks.8.norm1.weight",
544
+ "first_stage_model.decoder.up.3.block.1.time_stack.in_layers.2.bias": "blocks.8.conv1.bias",
545
+ "first_stage_model.decoder.up.3.block.1.time_stack.in_layers.2.weight": "blocks.8.conv1.weight",
546
+ "first_stage_model.decoder.up.3.block.1.time_stack.out_layers.0.bias": "blocks.8.norm2.bias",
547
+ "first_stage_model.decoder.up.3.block.1.time_stack.out_layers.0.weight": "blocks.8.norm2.weight",
548
+ "first_stage_model.decoder.up.3.block.1.time_stack.out_layers.3.bias": "blocks.8.conv2.bias",
549
+ "first_stage_model.decoder.up.3.block.1.time_stack.out_layers.3.weight": "blocks.8.conv2.weight",
550
+ "first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.9.conv1.bias",
551
+ "first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.9.conv1.weight",
552
+ "first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.9.conv2.bias",
553
+ "first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.9.conv2.weight",
554
+ "first_stage_model.decoder.up.3.block.2.mix_factor": "blocks.10.mix_factor",
555
+ "first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.9.norm1.bias",
556
+ "first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.9.norm1.weight",
557
+ "first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.9.norm2.bias",
558
+ "first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.9.norm2.weight",
559
+ "first_stage_model.decoder.up.3.block.2.time_stack.in_layers.0.bias": "blocks.10.norm1.bias",
560
+ "first_stage_model.decoder.up.3.block.2.time_stack.in_layers.0.weight": "blocks.10.norm1.weight",
561
+ "first_stage_model.decoder.up.3.block.2.time_stack.in_layers.2.bias": "blocks.10.conv1.bias",
562
+ "first_stage_model.decoder.up.3.block.2.time_stack.in_layers.2.weight": "blocks.10.conv1.weight",
563
+ "first_stage_model.decoder.up.3.block.2.time_stack.out_layers.0.bias": "blocks.10.norm2.bias",
564
+ "first_stage_model.decoder.up.3.block.2.time_stack.out_layers.0.weight": "blocks.10.norm2.weight",
565
+ "first_stage_model.decoder.up.3.block.2.time_stack.out_layers.3.bias": "blocks.10.conv2.bias",
566
+ "first_stage_model.decoder.up.3.block.2.time_stack.out_layers.3.weight": "blocks.10.conv2.weight",
567
+ "first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.11.conv.bias",
568
+ "first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.11.conv.weight",
569
+ }
570
+ state_dict_ = {}
571
+ for name in state_dict:
572
+ if name in rename_dict:
573
+ param = state_dict[name]
574
+ if "blocks.2.transformer_blocks.0" in rename_dict[name]:
575
+ param = param.squeeze()
576
+ state_dict_[rename_dict[name]] = param
577
+ return state_dict_
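Note on `decode_video` above: the decoder runs on overlapping spatio-temporal tiles, and `build_mask` produces a weight that ramps up from each tile's border (unless that border coincides with the video boundary), so overlapping tiles cross-fade rather than leaving visible seams; the accumulated `values` are finally divided by the accumulated `weight`. A minimal 1-D sketch of this weighted-overlap blending, assuming illustrative sizes and a constant stand-in for the decoder output (none of these names come from the repository):

import torch

length, tile, stride, border = 16, 8, 4, 2
values = torch.zeros(length)
weight = torch.zeros(length)

def ramp_mask(size, border):
    # Distance to the nearest tile edge, clipped to the border width, normalized to (0, 1].
    idx = torch.arange(size)
    return torch.minimum(idx + 1, size - idx).clip(1, border) / border

for start in range(0, length - tile + stride, stride):
    end = min(start + tile, length)
    tile_out = torch.full((end - start,), float(start))  # stand-in for self.forward(sample_batch)
    mask = ramp_mask(end - start, border)
    values[start:end] += tile_out * mask
    weight[start:end] += mask

blended = values / weight  # overlapping tiles are cross-faded by their masks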
diffsynth/models/svd_vae_encoder.py ADDED
@@ -0,0 +1,138 @@
1
+ from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
2
+
3
+
4
+ class SVDVAEEncoder(SDVAEEncoder):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.scaling_factor = 0.13025
8
+
9
+ def state_dict_converter(self):
10
+ return SVDVAEEncoderStateDictConverter()
11
+
12
+
13
+ class SVDVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ def from_diffusers(self, state_dict):
18
+ return super().from_diffusers(state_dict)
19
+
20
+ def from_civitai(self, state_dict):
21
+ rename_dict = {
22
+ "conditioner.embedders.3.encoder.encoder.conv_in.bias": "conv_in.bias",
23
+ "conditioner.embedders.3.encoder.encoder.conv_in.weight": "conv_in.weight",
24
+ "conditioner.embedders.3.encoder.encoder.conv_out.bias": "conv_out.bias",
25
+ "conditioner.embedders.3.encoder.encoder.conv_out.weight": "conv_out.weight",
26
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
27
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
28
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
29
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
30
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
31
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
32
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
33
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
34
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
35
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
36
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
37
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
38
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
39
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
40
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
41
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
42
+ "conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
43
+ "conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
44
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
45
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
46
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
47
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
48
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
49
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
50
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
51
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
52
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
53
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
54
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
55
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
56
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
57
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
58
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
59
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
60
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
61
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
62
+ "conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
63
+ "conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
64
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
65
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
66
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
67
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
68
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
69
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
70
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
71
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
72
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
73
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
74
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
75
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
76
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
77
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
78
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
79
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
80
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
81
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
82
+ "conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
83
+ "conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
84
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
85
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
86
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
87
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
88
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
89
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
90
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
91
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
92
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
93
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
94
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
95
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
96
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
97
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
98
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
99
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
100
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
101
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
102
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
103
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
104
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
105
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
106
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
107
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
108
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
109
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
110
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
111
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
112
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
113
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
114
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
115
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
116
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
117
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
118
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
119
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
120
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
121
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
122
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
123
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
124
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
125
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
126
+ "conditioner.embedders.3.encoder.encoder.norm_out.bias": "conv_norm_out.bias",
127
+ "conditioner.embedders.3.encoder.encoder.norm_out.weight": "conv_norm_out.weight",
128
+ "conditioner.embedders.3.encoder.quant_conv.bias": "quant_conv.bias",
129
+ "conditioner.embedders.3.encoder.quant_conv.weight": "quant_conv.weight",
130
+ }
131
+ state_dict_ = {}
132
+ for name in state_dict:
133
+ if name in rename_dict:
134
+ param = state_dict[name]
135
+ if "transformer_blocks" in rename_dict[name]:
136
+ param = param.squeeze()
137
+ state_dict_[rename_dict[name]] = param
138
+ return state_dict_
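
A toy illustration of the rename-and-squeeze pattern used by the conversion routine above; the two keys are taken from the mapping shown, while the tensor shapes are made up for demonstration and are not read from any real checkpoint.

import torch

# Illustrative only: apply the same rename-and-squeeze logic as the loop above
# to a two-entry dictionary standing in for a full checkpoint state dict.
rename_map = {
    "conditioner.embedders.3.encoder.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
    "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
}
raw = {
    "conditioner.embedders.3.encoder.encoder.mid.attn_1.q.weight": torch.randn(512, 512, 1, 1),
    "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.weight": torch.randn(128, 128, 3, 3),
}
converted = {}
for name, param in raw.items():
    if name in rename_map:
        if "transformer_blocks" in rename_map[name]:
            param = param.squeeze()  # 1x1-conv attention weights become plain linear weights
        converted[rename_map[name]] = param
print({k: tuple(v.shape) for k, v in converted.items()})
# {'blocks.12.transformer_blocks.0.to_q.weight': (512, 512), 'blocks.1.conv1.weight': (128, 128, 3, 3)}
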
diffsynth/models/tiler.py ADDED
@@ -0,0 +1,106 @@
1
+ import torch
2
+ from einops import rearrange, repeat
3
+
4
+
5
+ class TileWorker:
6
+ def __init__(self):
7
+ pass
8
+
9
+
10
+ def mask(self, height, width, border_width):
11
+ # Create a mask with shape (height, width).
12
+ # The centre area is filled with 1, and the border ramps up linearly with values in (0, 1].
13
+ x = torch.arange(height).repeat(width, 1).T
14
+ y = torch.arange(width).repeat(height, 1)
15
+ mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
16
+ mask = (mask / border_width).clip(0, 1)
17
+ return mask
18
+
19
+
20
+ def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
21
+ # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
22
+ batch_size, channel, _, _ = model_input.shape
23
+ model_input = model_input.to(device=tile_device, dtype=tile_dtype)
24
+ unfold_operator = torch.nn.Unfold(
25
+ kernel_size=(tile_size, tile_size),
26
+ stride=(tile_stride, tile_stride)
27
+ )
28
+ model_input = unfold_operator(model_input)
29
+ model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
30
+
31
+ return model_input
32
+
33
+
34
+ def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
35
+ # Call y=forward_fn(x) for each tile
36
+ tile_num = model_input.shape[-1]
37
+ model_output_stack = []
38
+
39
+ for tile_id in range(0, tile_num, tile_batch_size):
40
+
41
+ # process input
42
+ tile_id_ = min(tile_id + tile_batch_size, tile_num)
43
+ x = model_input[:, :, :, :, tile_id: tile_id_]
44
+ x = x.to(device=inference_device, dtype=inference_dtype)
45
+ x = rearrange(x, "b c h w n -> (n b) c h w")
46
+
47
+ # process output
48
+ y = forward_fn(x)
49
+ y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
50
+ y = y.to(device=tile_device, dtype=tile_dtype)
51
+ model_output_stack.append(y)
52
+
53
+ model_output = torch.concat(model_output_stack, dim=-1)
54
+ return model_output
55
+
56
+
57
+ def io_scale(self, model_output, tile_size):
58
+ # Determine the size modification that happened in forward_fn.
59
+ # We assume the same scale factor on height and width.
60
+ io_scale = model_output.shape[2] / tile_size
61
+ return io_scale
62
+
63
+
64
+ def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
65
+ # The inverse operation of tile
66
+ mask = self.mask(tile_size, tile_size, border_width)
67
+ mask = mask.to(device=tile_device, dtype=tile_dtype)
68
+ mask = rearrange(mask, "h w -> 1 1 h w 1")
69
+ model_output = model_output * mask
70
+
71
+ fold_operator = torch.nn.Fold(
72
+ output_size=(height, width),
73
+ kernel_size=(tile_size, tile_size),
74
+ stride=(tile_stride, tile_stride)
75
+ )
76
+ mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
77
+ model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
78
+ model_output = fold_operator(model_output) / fold_operator(mask)
79
+
80
+ return model_output
81
+
82
+
83
+ def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
84
+ # Prepare
85
+ inference_device, inference_dtype = model_input.device, model_input.dtype
86
+ height, width = model_input.shape[2], model_input.shape[3]
87
+ border_width = int(tile_stride*0.5) if border_width is None else border_width
88
+
89
+ # tile
90
+ model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
91
+
92
+ # inference
93
+ model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
94
+
95
+ # resize
96
+ io_scale = self.io_scale(model_output, tile_size)
97
+ height, width = int(height*io_scale), int(width*io_scale)
98
+ tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
99
+ border_width = int(border_width*io_scale)
100
+
101
+ # untile
102
+ model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
103
+
104
+ # Done!
105
+ model_output = model_output.to(device=inference_device, dtype=inference_dtype)
106
+ return model_output
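
A minimal usage sketch of the TileWorker above. The dummy convolution and the tensor sizes are assumptions chosen only to show the tile / inference / untile round trip, not values used elsewhere in the repository.

import torch
from diffsynth.models.tiler import TileWorker

conv = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)  # stands in for forward_fn
x = torch.randn(1, 4, 256, 256)

worker = TileWorker()
# 128x128 tiles with stride 64; overlapping borders are weighted by the soft mask
# before the fold, so tile seams are blended rather than hard-cut.
y = worker.tiled_forward(conv, x, tile_size=128, tile_stride=64)
print(y.shape)  # torch.Size([1, 4, 256, 256])
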
diffsynth/pipelines/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .stable_diffusion import SDImagePipeline
2
+ from .stable_diffusion_xl import SDXLImagePipeline
3
+ from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner
4
+ from .stable_diffusion_xl_video import SDXLVideoPipeline
5
+ from .stable_video_diffusion import SVDVideoPipeline
6
+ from .hunyuan_dit import HunyuanDiTImagePipeline
diffsynth/pipelines/dancer.py ADDED
@@ -0,0 +1,174 @@
1
+ import torch
2
+ from ..models import SDUNet, SDMotionModel, SDXLUNet, SDXLMotionModel
3
+ from ..models.sd_unet import PushBlock, PopBlock
4
+ from ..controlnets import MultiControlNetManager
5
+
6
+
7
+ def lets_dance(
8
+ unet: SDUNet,
9
+ motion_modules: SDMotionModel = None,
10
+ controlnet: MultiControlNetManager = None,
11
+ sample = None,
12
+ timestep = None,
13
+ encoder_hidden_states = None,
14
+ ipadapter_kwargs_list = {},
15
+ controlnet_frames = None,
16
+ unet_batch_size = 1,
17
+ controlnet_batch_size = 1,
18
+ cross_frame_attention = False,
19
+ tiled=False,
20
+ tile_size=64,
21
+ tile_stride=32,
22
+ device = "cuda",
23
+ vram_limit_level = 0,
24
+ ):
25
+ # 1. ControlNet
26
+ # This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
27
+ # I leave it here because I intend to do something interesting with the ControlNets.
28
+ controlnet_insert_block_id = 30
29
+ if controlnet is not None and controlnet_frames is not None:
30
+ res_stacks = []
31
+ # process controlnet frames with batch
32
+ for batch_id in range(0, sample.shape[0], controlnet_batch_size):
33
+ batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
34
+ res_stack = controlnet(
35
+ sample[batch_id: batch_id_],
36
+ timestep,
37
+ encoder_hidden_states[batch_id: batch_id_],
38
+ controlnet_frames[:, batch_id: batch_id_],
39
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
40
+ )
41
+ if vram_limit_level >= 1:
42
+ res_stack = [res.cpu() for res in res_stack]
43
+ res_stacks.append(res_stack)
44
+ # concat the residual
45
+ additional_res_stack = []
46
+ for i in range(len(res_stacks[0])):
47
+ res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
48
+ additional_res_stack.append(res)
49
+ else:
50
+ additional_res_stack = None
51
+
52
+ # 2. time
53
+ time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
54
+ time_emb = unet.time_embedding(time_emb)
55
+
56
+ # 3. pre-process
57
+ height, width = sample.shape[2], sample.shape[3]
58
+ hidden_states = unet.conv_in(sample)
59
+ text_emb = encoder_hidden_states
60
+ res_stack = [hidden_states.cpu() if vram_limit_level>=1 else hidden_states]
61
+
62
+ # 4. blocks
63
+ for block_id, block in enumerate(unet.blocks):
64
+ # 4.1 UNet
65
+ if isinstance(block, PushBlock):
66
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
67
+ if vram_limit_level>=1:
68
+ res_stack[-1] = res_stack[-1].cpu()
69
+ elif isinstance(block, PopBlock):
70
+ if vram_limit_level>=1:
71
+ res_stack[-1] = res_stack[-1].to(device)
72
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
73
+ else:
74
+ hidden_states_input = hidden_states
75
+ hidden_states_output = []
76
+ for batch_id in range(0, sample.shape[0], unet_batch_size):
77
+ batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
78
+ hidden_states, _, _, _ = block(
79
+ hidden_states_input[batch_id: batch_id_],
80
+ time_emb,
81
+ text_emb[batch_id: batch_id_],
82
+ res_stack,
83
+ cross_frame_attention=cross_frame_attention,
84
+ ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
85
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
86
+ )
87
+ hidden_states_output.append(hidden_states)
88
+ hidden_states = torch.concat(hidden_states_output, dim=0)
89
+ # 4.2 AnimateDiff
90
+ if motion_modules is not None:
91
+ if block_id in motion_modules.call_block_id:
92
+ motion_module_id = motion_modules.call_block_id[block_id]
93
+ hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
94
+ hidden_states, time_emb, text_emb, res_stack,
95
+ batch_size=1
96
+ )
97
+ # 4.3 ControlNet
98
+ if block_id == controlnet_insert_block_id and additional_res_stack is not None:
99
+ hidden_states += additional_res_stack.pop().to(device)
100
+ if vram_limit_level>=1:
101
+ res_stack = [(res.to(device) + additional_res.to(device)).cpu() for res, additional_res in zip(res_stack, additional_res_stack)]
102
+ else:
103
+ res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
104
+
105
+ # 5. output
106
+ hidden_states = unet.conv_norm_out(hidden_states)
107
+ hidden_states = unet.conv_act(hidden_states)
108
+ hidden_states = unet.conv_out(hidden_states)
109
+
110
+ return hidden_states
111
+
112
+
113
+
114
+
115
+ def lets_dance_xl(
116
+ unet: SDXLUNet,
117
+ motion_modules: SDXLMotionModel = None,
118
+ controlnet: MultiControlNetManager = None,
119
+ sample = None,
120
+ add_time_id = None,
121
+ add_text_embeds = None,
122
+ timestep = None,
123
+ encoder_hidden_states = None,
124
+ ipadapter_kwargs_list = {},
125
+ controlnet_frames = None,
126
+ unet_batch_size = 1,
127
+ controlnet_batch_size = 1,
128
+ cross_frame_attention = False,
129
+ tiled=False,
130
+ tile_size=64,
131
+ tile_stride=32,
132
+ device = "cuda",
133
+ vram_limit_level = 0,
134
+ ):
135
+ # 2. time
136
+ t_emb = unet.time_proj(timestep[None]).to(sample.dtype)
137
+ t_emb = unet.time_embedding(t_emb)
138
+
139
+ time_embeds = unet.add_time_proj(add_time_id)
140
+ time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
141
+ add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
142
+ add_embeds = add_embeds.to(sample.dtype)
143
+ add_embeds = unet.add_time_embedding(add_embeds)
144
+
145
+ time_emb = t_emb + add_embeds
146
+
147
+ # 3. pre-process
148
+ height, width = sample.shape[2], sample.shape[3]
149
+ hidden_states = unet.conv_in(sample)
150
+ text_emb = encoder_hidden_states
151
+ res_stack = [hidden_states]
152
+
153
+ # 4. blocks
154
+ for block_id, block in enumerate(unet.blocks):
155
+ hidden_states, time_emb, text_emb, res_stack = block(
156
+ hidden_states, time_emb, text_emb, res_stack,
157
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
158
+ ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {})
159
+ )
160
+ # 4.2 AnimateDiff
161
+ if motion_modules is not None:
162
+ if block_id in motion_modules.call_block_id:
163
+ motion_module_id = motion_modules.call_block_id[block_id]
164
+ hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
165
+ hidden_states, time_emb, text_emb, res_stack,
166
+ batch_size=1
167
+ )
168
+
169
+ # 5. output
170
+ hidden_states = unet.conv_norm_out(hidden_states)
171
+ hidden_states = unet.conv_act(hidden_states)
172
+ hidden_states = unet.conv_out(hidden_states)
173
+
174
+ return hidden_states
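
The frame sub-batching that both functions above rely on (running unet_batch_size frames at a time through a block and concatenating the chunk outputs) can be shown in isolation. The convolution and the shapes below are stand-ins, not objects defined in this file.

import torch

def run_in_chunks(block, frames, chunk_size=1):
    # Mirrors the unet_batch_size loop: process `chunk_size` frames at a time,
    # then concatenate the per-chunk outputs along the frame axis.
    outputs = []
    for batch_id in range(0, frames.shape[0], chunk_size):
        batch_id_ = min(batch_id + chunk_size, frames.shape[0])
        outputs.append(block(frames[batch_id: batch_id_]))
    return torch.concat(outputs, dim=0)

frames = torch.randn(16, 4, 64, 64)                      # 16 latent frames
block = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)  # stands in for a UNet block
print(run_in_chunks(block, frames, chunk_size=4).shape)  # torch.Size([16, 4, 64, 64])
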
diffsynth/pipelines/hunyuan_dit.py ADDED
@@ -0,0 +1,298 @@
1
+ from ..models.hunyuan_dit import HunyuanDiT
2
+ from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
3
+ from ..models.sdxl_vae_encoder import SDXLVAEEncoder
4
+ from ..models.sdxl_vae_decoder import SDXLVAEDecoder
5
+ from ..models import ModelManager
6
+ from ..prompts import HunyuanDiTPrompter
7
+ from ..schedulers import EnhancedDDIMScheduler
8
+ import torch
9
+ from tqdm import tqdm
10
+ from PIL import Image
11
+ import numpy as np
12
+
13
+
14
+
15
+ class ImageSizeManager:
16
+ def __init__(self):
17
+ pass
18
+
19
+
20
+ def _to_tuple(self, x):
21
+ if isinstance(x, int):
22
+ return x, x
23
+ else:
24
+ return x
25
+
26
+
27
+ def get_fill_resize_and_crop(self, src, tgt):
28
+ th, tw = self._to_tuple(tgt)
29
+ h, w = self._to_tuple(src)
30
+
31
+ tr = th / tw # aspect ratio of the base resolution
32
+ r = h / w # aspect ratio of the target resolution
33
+
34
+ # resize
35
+ if r > tr:
36
+ resize_height = th
37
+ resize_width = int(round(th / h * w))
38
+ else:
39
+ resize_width = tw
40
+ resize_height = int(round(tw / w * h)) # resize the target resolution down according to the base resolution
41
+
42
+ crop_top = int(round((th - resize_height) / 2.0))
43
+ crop_left = int(round((tw - resize_width) / 2.0))
44
+
45
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
46
+
47
+
48
+ def get_meshgrid(self, start, *args):
49
+ if len(args) == 0:
50
+ # start is grid_size
51
+ num = self._to_tuple(start)
52
+ start = (0, 0)
53
+ stop = num
54
+ elif len(args) == 1:
55
+ # start is start, args[0] is stop, step is 1
56
+ start = self._to_tuple(start)
57
+ stop = self._to_tuple(args[0])
58
+ num = (stop[0] - start[0], stop[1] - start[1])
59
+ elif len(args) == 2:
60
+ # start is start, args[0] is stop, args[1] is num
61
+ start = self._to_tuple(start) # top-left corner, e.g. 12,0
62
+ stop = self._to_tuple(args[0]) # bottom-right corner, e.g. 20,32
63
+ num = self._to_tuple(args[1]) # target size, e.g. 32,124
64
+ else:
65
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
66
+
67
+ grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32) # e.g. 32 evenly spaced values in [12, 20), or 124 values in [0, 32)
68
+ grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
69
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
70
+ grid = np.stack(grid, axis=0) # [2, W, H]
71
+ return grid
72
+
73
+
74
+ def get_2d_rotary_pos_embed(self, embed_dim, start, *args, use_real=True):
75
+ grid = self.get_meshgrid(start, *args) # [2, H, w]
76
+ grid = grid.reshape([2, 1, *grid.shape[1:]]) # returns a sampling grid whose resolution matches the target resolution
77
+ pos_embed = self.get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
78
+ return pos_embed
79
+
80
+
81
+ def get_2d_rotary_pos_embed_from_grid(self, embed_dim, grid, use_real=False):
82
+ assert embed_dim % 4 == 0
83
+
84
+ # use half of dimensions to encode grid_h
85
+ emb_h = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
86
+ emb_w = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
87
+
88
+ if use_real:
89
+ cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
90
+ sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
91
+ return cos, sin
92
+ else:
93
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
94
+ return emb
95
+
96
+
97
+ def get_1d_rotary_pos_embed(self, dim: int, pos, theta: float = 10000.0, use_real=False):
98
+ if isinstance(pos, int):
99
+ pos = np.arange(pos)
100
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
101
+ t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
102
+ freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
103
+ if use_real:
104
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
105
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
106
+ return freqs_cos, freqs_sin
107
+ else:
108
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
109
+ return freqs_cis
110
+
111
+
112
+ def calc_rope(self, height, width):
113
+ patch_size = 2
114
+ head_size = 88
115
+ th = height // 8 // patch_size
116
+ tw = width // 8 // patch_size
117
+ base_size = 512 // 8 // patch_size
118
+ start, stop = self.get_fill_resize_and_crop((th, tw), base_size)
119
+ sub_args = [start, stop, (th, tw)]
120
+ rope = self.get_2d_rotary_pos_embed(head_size, *sub_args)
121
+ return rope
122
+
123
+
124
+
125
+ class HunyuanDiTImagePipeline(torch.nn.Module):
126
+
127
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
128
+ super().__init__()
129
+ self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
130
+ self.prompter = HunyuanDiTPrompter()
131
+ self.device = device
132
+ self.torch_dtype = torch_dtype
133
+ self.image_size_manager = ImageSizeManager()
134
+ # models
135
+ self.text_encoder: HunyuanDiTCLIPTextEncoder = None
136
+ self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
137
+ self.dit: HunyuanDiT = None
138
+ self.vae_decoder: SDXLVAEDecoder = None
139
+ self.vae_encoder: SDXLVAEEncoder = None
140
+
141
+
142
+ def fetch_main_models(self, model_manager: ModelManager):
143
+ self.text_encoder = model_manager.hunyuan_dit_clip_text_encoder
144
+ self.text_encoder_t5 = model_manager.hunyuan_dit_t5_text_encoder
145
+ self.dit = model_manager.hunyuan_dit
146
+ self.vae_decoder = model_manager.vae_decoder
147
+ self.vae_encoder = model_manager.vae_encoder
148
+
149
+
150
+ def fetch_prompter(self, model_manager: ModelManager):
151
+ self.prompter.load_from_model_manager(model_manager)
152
+
153
+
154
+ @staticmethod
155
+ def from_model_manager(model_manager: ModelManager):
156
+ pipe = HunyuanDiTImagePipeline(
157
+ device=model_manager.device,
158
+ torch_dtype=model_manager.torch_dtype,
159
+ )
160
+ pipe.fetch_main_models(model_manager)
161
+ pipe.fetch_prompter(model_manager)
162
+ return pipe
163
+
164
+
165
+ def preprocess_image(self, image):
166
+ image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
167
+ return image
168
+
169
+
170
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
171
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
172
+ image = image.cpu().permute(1, 2, 0).numpy()
173
+ image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
174
+ return image
175
+
176
+
177
+ def prepare_extra_input(self, height=1024, width=1024, tiled=False, tile_size=64, tile_stride=32, batch_size=1):
178
+ if tiled:
179
+ height, width = tile_size * 16, tile_size * 16
180
+ image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
181
+ freqs_cis_img = self.image_size_manager.calc_rope(height, width)
182
+ image_meta_size = torch.stack([image_meta_size] * batch_size)
183
+ return {
184
+ "size_emb": image_meta_size,
185
+ "freq_cis_img": (freqs_cis_img[0].to(dtype=self.torch_dtype, device=self.device), freqs_cis_img[1].to(dtype=self.torch_dtype, device=self.device)),
186
+ "tiled": tiled,
187
+ "tile_size": tile_size,
188
+ "tile_stride": tile_stride
189
+ }
190
+
191
+
192
+ @torch.no_grad()
193
+ def __call__(
194
+ self,
195
+ prompt,
196
+ negative_prompt="",
197
+ cfg_scale=7.5,
198
+ clip_skip=1,
199
+ clip_skip_2=1,
200
+ input_image=None,
201
+ reference_images=[],
202
+ reference_strengths=[0.4],
203
+ denoising_strength=1.0,
204
+ height=1024,
205
+ width=1024,
206
+ num_inference_steps=20,
207
+ tiled=False,
208
+ tile_size=64,
209
+ tile_stride=32,
210
+ progress_bar_cmd=tqdm,
211
+ progress_bar_st=None,
212
+ ):
213
+ # Prepare scheduler
214
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
215
+
216
+ # Prepare latent tensors
217
+ noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
218
+ if input_image is not None:
219
+ image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
220
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
221
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
222
+ else:
223
+ latents = noise.clone()
224
+
225
+ # Prepare reference latents
226
+ reference_latents = []
227
+ for reference_image in reference_images:
228
+ reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
229
+ reference_latents.append(self.vae_encoder(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype))
230
+
231
+ # Encode prompts
232
+ prompt_emb_posi, attention_mask_posi, prompt_emb_t5_posi, attention_mask_t5_posi = self.prompter.encode_prompt(
233
+ self.text_encoder,
234
+ self.text_encoder_t5,
235
+ prompt,
236
+ clip_skip=clip_skip,
237
+ clip_skip_2=clip_skip_2,
238
+ positive=True,
239
+ device=self.device
240
+ )
241
+ if cfg_scale != 1.0:
242
+ prompt_emb_nega, attention_mask_nega, prompt_emb_t5_nega, attention_mask_t5_nega = self.prompter.encode_prompt(
243
+ self.text_encoder,
244
+ self.text_encoder_t5,
245
+ negative_prompt,
246
+ clip_skip=clip_skip,
247
+ clip_skip_2=clip_skip_2,
248
+ positive=False,
249
+ device=self.device
250
+ )
251
+
252
+ # Prepare positional id
253
+ extra_input = self.prepare_extra_input(height, width, tiled, tile_size, tile_stride)
254
+
255
+ # Denoise
256
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
257
+ timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
258
+
259
+ # In-context reference
260
+ for reference_latents_, reference_strength in zip(reference_latents, reference_strengths):
261
+ if progress_id < num_inference_steps * reference_strength:
262
+ noisy_reference_latents = self.scheduler.add_noise(reference_latents_, noise, self.scheduler.timesteps[progress_id])
263
+ self.dit(
264
+ noisy_reference_latents,
265
+ prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
266
+ timestep,
267
+ **extra_input,
268
+ to_cache=True
269
+ )
270
+ # Positive side
271
+ noise_pred_posi = self.dit(
272
+ latents,
273
+ prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
274
+ timestep,
275
+ **extra_input,
276
+ )
277
+ if cfg_scale != 1.0:
278
+ # Negative side
279
+ noise_pred_nega = self.dit(
280
+ latents,
281
+ prompt_emb_nega, prompt_emb_t5_nega, attention_mask_nega, attention_mask_t5_nega,
282
+ timestep,
283
+ **extra_input
284
+ )
285
+ # Classifier-free guidance
286
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
287
+ else:
288
+ noise_pred = noise_pred_posi
289
+
290
+ latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
291
+
292
+ if progress_bar_st is not None:
293
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
294
+
295
+ # Decode image
296
+ image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
297
+
298
+ return image
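
A hedged end-to-end sketch of this pipeline. The checkpoint paths are placeholders for locally downloaded HunyuanDiT and SDXL-VAE weights, and calling load_models with a bare list assumes the lora_alphas argument seen later in this diff has a default value.

import torch
from diffsynth.models import ModelManager
from diffsynth.pipelines import HunyuanDiTImagePipeline

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models([
    "models/HunyuanDiT/clip_text_encoder.bin",  # placeholder paths, replace with
    "models/HunyuanDiT/mt5_text_encoder.bin",   # your local checkpoint files
    "models/HunyuanDiT/hunyuan_dit.bin",
    "models/HunyuanDiT/sdxl_vae.bin",
])
pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)

torch.manual_seed(0)
image = pipe(
    prompt="a cute cat in a garden, highly detailed",
    negative_prompt="lowres, blurry",
    num_inference_steps=50, height=1024, width=1024,
)
image.save("hunyuan_dit_image.png")
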
diffsynth/pipelines/stable_diffusion.py ADDED
@@ -0,0 +1,167 @@
1
+ from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
2
+ from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
3
+ from ..prompts import SDPrompter
4
+ from ..schedulers import EnhancedDDIMScheduler
5
+ from .dancer import lets_dance
6
+ from typing import List
7
+ import torch
8
+ from tqdm import tqdm
9
+ from PIL import Image
10
+ import numpy as np
11
+
12
+
13
+ class SDImagePipeline(torch.nn.Module):
14
+
15
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
16
+ super().__init__()
17
+ self.scheduler = EnhancedDDIMScheduler()
18
+ self.prompter = SDPrompter()
19
+ self.device = device
20
+ self.torch_dtype = torch_dtype
21
+ # models
22
+ self.text_encoder: SDTextEncoder = None
23
+ self.unet: SDUNet = None
24
+ self.vae_decoder: SDVAEDecoder = None
25
+ self.vae_encoder: SDVAEEncoder = None
26
+ self.controlnet: MultiControlNetManager = None
27
+ self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
28
+ self.ipadapter: SDIpAdapter = None
29
+
30
+
31
+ def fetch_main_models(self, model_manager: ModelManager):
32
+ self.text_encoder = model_manager.text_encoder
33
+ self.unet = model_manager.unet
34
+ self.vae_decoder = model_manager.vae_decoder
35
+ self.vae_encoder = model_manager.vae_encoder
36
+
37
+
38
+ def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
39
+ controlnet_units = []
40
+ for config in controlnet_config_units:
41
+ controlnet_unit = ControlNetUnit(
42
+ Annotator(config.processor_id),
43
+ model_manager.get_model_with_model_path(config.model_path),
44
+ config.scale
45
+ )
46
+ controlnet_units.append(controlnet_unit)
47
+ self.controlnet = MultiControlNetManager(controlnet_units)
48
+
49
+
50
+ def fetch_ipadapter(self, model_manager: ModelManager):
51
+ if "ipadapter" in model_manager.model:
52
+ self.ipadapter = model_manager.ipadapter
53
+ if "ipadapter_image_encoder" in model_manager.model:
54
+ self.ipadapter_image_encoder = model_manager.ipadapter_image_encoder
55
+
56
+
57
+ def fetch_prompter(self, model_manager: ModelManager):
58
+ self.prompter.load_from_model_manager(model_manager)
59
+
60
+
61
+ @staticmethod
62
+ def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
63
+ pipe = SDImagePipeline(
64
+ device=model_manager.device,
65
+ torch_dtype=model_manager.torch_dtype,
66
+ )
67
+ pipe.fetch_main_models(model_manager)
68
+ pipe.fetch_prompter(model_manager)
69
+ pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
70
+ pipe.fetch_ipadapter(model_manager)
71
+ return pipe
72
+
73
+
74
+ def preprocess_image(self, image):
75
+ image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
76
+ return image
77
+
78
+
79
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
80
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
81
+ image = image.cpu().permute(1, 2, 0).numpy()
82
+ image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
83
+ return image
84
+
85
+
86
+ @torch.no_grad()
87
+ def __call__(
88
+ self,
89
+ prompt,
90
+ negative_prompt="",
91
+ cfg_scale=7.5,
92
+ clip_skip=1,
93
+ input_image=None,
94
+ ipadapter_images=None,
95
+ ipadapter_scale=1.0,
96
+ controlnet_image=None,
97
+ denoising_strength=1.0,
98
+ height=512,
99
+ width=512,
100
+ num_inference_steps=20,
101
+ tiled=False,
102
+ tile_size=64,
103
+ tile_stride=32,
104
+ progress_bar_cmd=tqdm,
105
+ progress_bar_st=None,
106
+ ):
107
+ # Prepare scheduler
108
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
109
+
110
+ # Prepare latent tensors
111
+ if input_image is not None:
112
+ image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
113
+ latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
114
+ noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
115
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
116
+ else:
117
+ latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
118
+
119
+ # Encode prompts
120
+ prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True)
121
+ prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False)
122
+
123
+ # IP-Adapter
124
+ if ipadapter_images is not None:
125
+ ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
126
+ ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
127
+ ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
128
+ else:
129
+ ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
130
+
131
+ # Prepare ControlNets
132
+ if controlnet_image is not None:
133
+ controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
134
+ controlnet_image = controlnet_image.unsqueeze(1)
135
+
136
+ # Denoise
137
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
138
+ timestep = torch.IntTensor((timestep,))[0].to(self.device)
139
+
140
+ # Classifier-free guidance
141
+ noise_pred_posi = lets_dance(
142
+ self.unet, motion_modules=None, controlnet=self.controlnet,
143
+ sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image,
144
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
145
+ ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
146
+ device=self.device, vram_limit_level=0
147
+ )
148
+ noise_pred_nega = lets_dance(
149
+ self.unet, motion_modules=None, controlnet=self.controlnet,
150
+ sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image,
151
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
152
+ ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
153
+ device=self.device, vram_limit_level=0
154
+ )
155
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
156
+
157
+ # DDIM
158
+ latents = self.scheduler.step(noise_pred, timestep, latents)
159
+
160
+ # UI
161
+ if progress_bar_st is not None:
162
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
163
+
164
+ # Decode image
165
+ image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
166
+
167
+ return image
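
An illustrative sketch of the image pipeline with one ControlNet unit. The checkpoint paths and the "canny" processor id are assumptions about locally available files and the Annotator's supported processors; they are not defined in this file.

import torch
from PIL import Image
from diffsynth.models import ModelManager
from diffsynth.controlnets import ControlNetConfigUnit
from diffsynth.pipelines import SDImagePipeline

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models([
    "models/stable_diffusion/v1-5-pruned-emaonly.safetensors",  # placeholder paths
    "models/ControlNet/control_v11p_sd15_canny.pth",
])
pipe = SDImagePipeline.from_model_manager(model_manager, [
    ControlNetConfigUnit(
        processor_id="canny",  # assumed processor id
        model_path="models/ControlNet/control_v11p_sd15_canny.pth",
        scale=0.8,
    ),
])

torch.manual_seed(0)
image = pipe(
    prompt="a photo of a cat, masterpiece",
    negative_prompt="lowres, blurry",
    controlnet_image=Image.open("reference.png").resize((512, 512)),
    height=512, width=512, num_inference_steps=20, cfg_scale=7.5,
)
image.save("sd_image.png")
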
diffsynth/pipelines/stable_diffusion_video.py ADDED
@@ -0,0 +1,356 @@
1
+ from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDMotionModel
2
+ from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
3
+ from ..prompts import SDPrompter
4
+ from ..schedulers import EnhancedDDIMScheduler
5
+ from ..data import VideoData, save_frames, save_video
6
+ from .dancer import lets_dance
7
+ from ..processors.sequencial_processor import SequencialProcessor
8
+ from typing import List
9
+ import torch, os, json
10
+ from tqdm import tqdm
11
+ from PIL import Image
12
+ import numpy as np
13
+
14
+
15
+ def lets_dance_with_long_video(
16
+ unet: SDUNet,
17
+ motion_modules: SDMotionModel = None,
18
+ controlnet: MultiControlNetManager = None,
19
+ sample = None,
20
+ timestep = None,
21
+ encoder_hidden_states = None,
22
+ controlnet_frames = None,
23
+ animatediff_batch_size = 16,
24
+ animatediff_stride = 8,
25
+ unet_batch_size = 1,
26
+ controlnet_batch_size = 1,
27
+ cross_frame_attention = False,
28
+ device = "cuda",
29
+ vram_limit_level = 0,
30
+ ):
31
+ num_frames = sample.shape[0]
32
+ hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)]
33
+
34
+ for batch_id in range(0, num_frames, animatediff_stride):
35
+ batch_id_ = min(batch_id + animatediff_batch_size, num_frames)
36
+
37
+ # process this batch
38
+ hidden_states_batch = lets_dance(
39
+ unet, motion_modules, controlnet,
40
+ sample[batch_id: batch_id_].to(device),
41
+ timestep,
42
+ encoder_hidden_states[batch_id: batch_id_].to(device),
43
+ controlnet_frames=controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
44
+ unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
45
+ cross_frame_attention=cross_frame_attention,
46
+ device=device, vram_limit_level=vram_limit_level
47
+ ).cpu()
48
+
49
+ # update hidden_states
50
+ for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
51
+ bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1 + 1e-2) / 2), 1e-2)
52
+ hidden_states, num = hidden_states_output[i]
53
+ hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
54
+ hidden_states_output[i] = (hidden_states, num + bias)
55
+
56
+ if batch_id_ == num_frames:
57
+ break
58
+
59
+ # output
60
+ hidden_states = torch.stack([h for h, _ in hidden_states_output])
61
+ return hidden_states
62
+
63
+
64
+ class SDVideoPipeline(torch.nn.Module):
65
+
66
+ def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
67
+ super().__init__()
68
+ self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
69
+ self.prompter = SDPrompter()
70
+ self.device = device
71
+ self.torch_dtype = torch_dtype
72
+ # models
73
+ self.text_encoder: SDTextEncoder = None
74
+ self.unet: SDUNet = None
75
+ self.vae_decoder: SDVAEDecoder = None
76
+ self.vae_encoder: SDVAEEncoder = None
77
+ self.controlnet: MultiControlNetManager = None
78
+ self.motion_modules: SDMotionModel = None
79
+
80
+
81
+ def fetch_main_models(self, model_manager: ModelManager):
82
+ self.text_encoder = model_manager.text_encoder
83
+ self.unet = model_manager.unet
84
+ self.vae_decoder = model_manager.vae_decoder
85
+ self.vae_encoder = model_manager.vae_encoder
86
+
87
+
88
+ def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
89
+ controlnet_units = []
90
+ for config in controlnet_config_units:
91
+ controlnet_unit = ControlNetUnit(
92
+ Annotator(config.processor_id),
93
+ model_manager.get_model_with_model_path(config.model_path),
94
+ config.scale
95
+ )
96
+ controlnet_units.append(controlnet_unit)
97
+ self.controlnet = MultiControlNetManager(controlnet_units)
98
+
99
+
100
+ def fetch_motion_modules(self, model_manager: ModelManager):
101
+ if "motion_modules" in model_manager.model:
102
+ self.motion_modules = model_manager.motion_modules
103
+
104
+
105
+ def fetch_prompter(self, model_manager: ModelManager):
106
+ self.prompter.load_from_model_manager(model_manager)
107
+
108
+
109
+ @staticmethod
110
+ def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
111
+ pipe = SDVideoPipeline(
112
+ device=model_manager.device,
113
+ torch_dtype=model_manager.torch_dtype,
114
+ use_animatediff="motion_modules" in model_manager.model
115
+ )
116
+ pipe.fetch_main_models(model_manager)
117
+ pipe.fetch_motion_modules(model_manager)
118
+ pipe.fetch_prompter(model_manager)
119
+ pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
120
+ return pipe
121
+
122
+
123
+ def preprocess_image(self, image):
124
+ image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
125
+ return image
126
+
127
+
128
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
129
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
130
+ image = image.cpu().permute(1, 2, 0).numpy()
131
+ image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
132
+ return image
133
+
134
+
135
+ def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
136
+ images = [
137
+ self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
138
+ for frame_id in range(latents.shape[0])
139
+ ]
140
+ return images
141
+
142
+
143
+ def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
144
+ latents = []
145
+ for image in processed_images:
146
+ image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
147
+ latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
148
+ latents.append(latent)
149
+ latents = torch.concat(latents, dim=0)
150
+ return latents
151
+
152
+
153
+ @torch.no_grad()
154
+ def __call__(
155
+ self,
156
+ prompt,
157
+ negative_prompt="",
158
+ cfg_scale=7.5,
159
+ clip_skip=1,
160
+ num_frames=None,
161
+ input_frames=None,
162
+ controlnet_frames=None,
163
+ denoising_strength=1.0,
164
+ height=512,
165
+ width=512,
166
+ num_inference_steps=20,
167
+ animatediff_batch_size = 16,
168
+ animatediff_stride = 8,
169
+ unet_batch_size = 1,
170
+ controlnet_batch_size = 1,
171
+ cross_frame_attention = False,
172
+ smoother=None,
173
+ smoother_progress_ids=[],
174
+ vram_limit_level=0,
175
+ progress_bar_cmd=tqdm,
176
+ progress_bar_st=None,
177
+ ):
178
+ # Prepare scheduler
179
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
180
+
181
+ # Prepare latent tensors
182
+ if self.motion_modules is None:
183
+ noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
184
+ else:
185
+ noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
186
+ if input_frames is None or denoising_strength == 1.0:
187
+ latents = noise
188
+ else:
189
+ latents = self.encode_images(input_frames)
190
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
191
+
192
+ # Encode prompts
193
+ prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True).cpu()
194
+ prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False).cpu()
195
+ prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1)
196
+ prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1)
197
+
198
+ # Prepare ControlNets
199
+ if controlnet_frames is not None:
200
+ if isinstance(controlnet_frames[0], list):
201
+ controlnet_frames_ = []
202
+ for processor_id in range(len(controlnet_frames)):
203
+ controlnet_frames_.append(
204
+ torch.stack([
205
+ self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
206
+ for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
207
+ ], dim=1)
208
+ )
209
+ controlnet_frames = torch.concat(controlnet_frames_, dim=0)
210
+ else:
211
+ controlnet_frames = torch.stack([
212
+ self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
213
+ for controlnet_frame in progress_bar_cmd(controlnet_frames)
214
+ ], dim=1)
215
+
216
+ # Denoise
217
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
218
+ timestep = torch.IntTensor((timestep,))[0].to(self.device)
219
+
220
+ # Classifier-free guidance
221
+ noise_pred_posi = lets_dance_with_long_video(
222
+ self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
223
+ sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
224
+ animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
225
+ unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
226
+ cross_frame_attention=cross_frame_attention,
227
+ device=self.device, vram_limit_level=vram_limit_level
228
+ )
229
+ noise_pred_nega = lets_dance_with_long_video(
230
+ self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
231
+ sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
232
+ animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
233
+ unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
234
+ cross_frame_attention=cross_frame_attention,
235
+ device=self.device, vram_limit_level=vram_limit_level
236
+ )
237
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
238
+
239
+ # DDIM and smoother
240
+ if smoother is not None and progress_id in smoother_progress_ids:
241
+ rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
242
+ rendered_frames = self.decode_images(rendered_frames)
243
+ rendered_frames = smoother(rendered_frames, original_frames=input_frames)
244
+ target_latents = self.encode_images(rendered_frames)
245
+ noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
246
+ latents = self.scheduler.step(noise_pred, timestep, latents)
247
+
248
+ # UI
249
+ if progress_bar_st is not None:
250
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
251
+
252
+ # Decode image
253
+ output_frames = self.decode_images(latents)
254
+
255
+ # Post-process
256
+ if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
257
+ output_frames = smoother(output_frames, original_frames=input_frames)
258
+
259
+ return output_frames
260
+
261
+
262
+
263
+ class SDVideoPipelineRunner:
264
+ def __init__(self, in_streamlit=False):
265
+ self.in_streamlit = in_streamlit
266
+
267
+
268
+ def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
269
+ # Load models
270
+ model_manager = ModelManager(torch_dtype=torch.float16, device=device)
271
+ model_manager.load_textual_inversions(textual_inversion_folder)
272
+ model_manager.load_models(model_list, lora_alphas=lora_alphas)
273
+ pipe = SDVideoPipeline.from_model_manager(
274
+ model_manager,
275
+ [
276
+ ControlNetConfigUnit(
277
+ processor_id=unit["processor_id"],
278
+ model_path=unit["model_path"],
279
+ scale=unit["scale"]
280
+ ) for unit in controlnet_units
281
+ ]
282
+ )
283
+ return model_manager, pipe
284
+
285
+
286
+ def load_smoother(self, model_manager, smoother_configs):
287
+ smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
288
+ return smoother
289
+
290
+
291
+ def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
292
+ torch.manual_seed(seed)
293
+ if self.in_streamlit:
294
+ import streamlit as st
295
+ progress_bar_st = st.progress(0.0)
296
+ output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
297
+ progress_bar_st.progress(1.0)
298
+ else:
299
+ output_video = pipe(**pipeline_inputs, smoother=smoother)
300
+ model_manager.to("cpu")
301
+ return output_video
302
+
303
+
304
+ def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
305
+ video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
306
+ if start_frame_id is None:
307
+ start_frame_id = 0
308
+ if end_frame_id is None:
309
+ end_frame_id = len(video)
310
+ frames = [video[i] for i in range(start_frame_id, end_frame_id)]
311
+ return frames
312
+
313
+
314
+ def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
315
+ pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
316
+ pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
317
+ pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
318
+ if len(data["controlnet_frames"]) > 0:
319
+ pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
320
+ return pipeline_inputs
321
+
322
+
323
+ def save_output(self, video, output_folder, fps, config):
324
+ os.makedirs(output_folder, exist_ok=True)
325
+ save_frames(video, os.path.join(output_folder, "frames"))
326
+ save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
327
+ config["pipeline"]["pipeline_inputs"]["input_frames"] = []
328
+ config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
329
+ with open(os.path.join(output_folder, "config.json"), 'w') as file:
330
+ json.dump(config, file, indent=4)
331
+
332
+
333
+ def run(self, config):
334
+ if self.in_streamlit:
335
+ import streamlit as st
336
+ if self.in_streamlit: st.markdown("Loading videos ...")
337
+ config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
338
+ if self.in_streamlit: st.markdown("Loading videos ... done!")
339
+ if self.in_streamlit: st.markdown("Loading models ...")
340
+ model_manager, pipe = self.load_pipeline(**config["models"])
341
+ if self.in_streamlit: st.markdown("Loading models ... done!")
342
+ if "smoother_configs" in config:
343
+ if self.in_streamlit: st.markdown("Loading smoother ...")
344
+ smoother = self.load_smoother(model_manager, config["smoother_configs"])
345
+ if self.in_streamlit: st.markdown("Loading smoother ... done!")
346
+ else:
347
+ smoother = None
348
+ if self.in_streamlit: st.markdown("Synthesizing videos ...")
349
+ output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
350
+ if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
351
+ if self.in_streamlit: st.markdown("Saving videos ...")
352
+ self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
353
+ if self.in_streamlit: st.markdown("Saving videos ... done!")
354
+ if self.in_streamlit: st.markdown("Finished!")
355
+ video_file = open(os.path.join(config["data"]["output_folder"], "video.mp4"), 'rb')
356
+ if self.in_streamlit: st.video(video_file.read())
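
The run() method above is driven by a nested config dict; the sketch below shows the expected layout. All paths, checkpoint names and the "tile" processor id are placeholders, and pipeline_inputs may contain any keyword accepted by SDVideoPipeline.__call__.

from diffsynth.pipelines import SDVideoPipelineRunner

config = {
    "models": {
        "model_list": [
            "models/stable_diffusion/v1-5-pruned-emaonly.safetensors",  # placeholder
            "models/AnimateDiff/mm_sd_v15_v2.ckpt",                     # placeholder
        ],
        "textual_inversion_folder": "models/textual_inversion",
        "device": "cuda",
        "lora_alphas": [],
        "controlnet_units": [
            {
                "processor_id": "tile",  # assumed processor id
                "model_path": "models/ControlNet/control_v11f1e_sd15_tile.pth",
                "scale": 0.5,
            },
        ],
    },
    "data": {
        "input_frames": {
            "video_file": "input.mp4", "image_folder": None,
            "height": 512, "width": 512,
            "start_frame_id": 0, "end_frame_id": 16,
        },
        "controlnet_frames": [
            {
                "video_file": "input.mp4", "image_folder": None,
                "height": 512, "width": 512,
                "start_frame_id": 0, "end_frame_id": 16,
            },
        ],
        "output_folder": "output",
        "fps": 16,
    },
    "pipeline": {
        "seed": 0,
        "pipeline_inputs": {
            "prompt": "best quality, a girl is dancing",
            "negative_prompt": "lowres, blurry",
            "cfg_scale": 7.0,
            "num_inference_steps": 10,
            "animatediff_batch_size": 16,
            "animatediff_stride": 8,
        },
    },
}

SDVideoPipelineRunner(in_streamlit=False).run(config)
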
diffsynth/pipelines/stable_diffusion_xl.py ADDED
@@ -0,0 +1,175 @@
1
+ from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
2
+ # TODO: SDXL ControlNet
3
+ from ..prompts import SDXLPrompter
4
+ from ..schedulers import EnhancedDDIMScheduler
5
+ from .dancer import lets_dance_xl
6
+ import torch
7
+ from tqdm import tqdm
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
+
12
+ class SDXLImagePipeline(torch.nn.Module):
13
+
14
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
15
+ super().__init__()
16
+ self.scheduler = EnhancedDDIMScheduler()
17
+ self.prompter = SDXLPrompter()
18
+ self.device = device
19
+ self.torch_dtype = torch_dtype
20
+ # models
21
+ self.text_encoder: SDXLTextEncoder = None
22
+ self.text_encoder_2: SDXLTextEncoder2 = None
23
+ self.unet: SDXLUNet = None
24
+ self.vae_decoder: SDXLVAEDecoder = None
25
+ self.vae_encoder: SDXLVAEEncoder = None
26
+ self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
27
+ self.ipadapter: SDXLIpAdapter = None
28
+ # TODO: SDXL ControlNet
29
+
30
+ def fetch_main_models(self, model_manager: ModelManager):
31
+ self.text_encoder = model_manager.text_encoder
32
+ self.text_encoder_2 = model_manager.text_encoder_2
33
+ self.unet = model_manager.unet
34
+ self.vae_decoder = model_manager.vae_decoder
35
+ self.vae_encoder = model_manager.vae_encoder
36
+
37
+
38
+ def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
39
+ # TODO: SDXL ControlNet
40
+ pass
41
+
42
+
43
+ def fetch_ipadapter(self, model_manager: ModelManager):
44
+ if "ipadapter_xl" in model_manager.model:
45
+ self.ipadapter = model_manager.ipadapter_xl
46
+ if "ipadapter_xl_image_encoder" in model_manager.model:
47
+ self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder
48
+
49
+
50
+ def fetch_prompter(self, model_manager: ModelManager):
51
+ self.prompter.load_from_model_manager(model_manager)
52
+
53
+
54
+ @staticmethod
55
+ def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
56
+ pipe = SDXLImagePipeline(
57
+ device=model_manager.device,
58
+ torch_dtype=model_manager.torch_dtype,
59
+ )
60
+ pipe.fetch_main_models(model_manager)
61
+ pipe.fetch_prompter(model_manager)
62
+ pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
63
+ pipe.fetch_ipadapter(model_manager)
64
+ return pipe
65
+
66
+
67
+ def preprocess_image(self, image):
68
+ image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
69
+ return image
70
+
71
+
72
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
73
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
74
+ image = image.cpu().permute(1, 2, 0).numpy()
75
+ image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
76
+ return image
77
+
78
+
79
+ @torch.no_grad()
80
+ def __call__(
81
+ self,
82
+ prompt,
83
+ negative_prompt="",
84
+ cfg_scale=7.5,
85
+ clip_skip=1,
86
+ clip_skip_2=2,
87
+ input_image=None,
88
+ ipadapter_images=None,
89
+ ipadapter_scale=1.0,
90
+ controlnet_image=None,
91
+ denoising_strength=1.0,
92
+ height=1024,
93
+ width=1024,
94
+ num_inference_steps=20,
95
+ tiled=False,
96
+ tile_size=64,
97
+ tile_stride=32,
98
+ progress_bar_cmd=tqdm,
99
+ progress_bar_st=None,
100
+ ):
101
+ # Prepare scheduler
102
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
103
+
104
+ # Prepare latent tensors
105
+ if input_image is not None:
106
+ image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
107
+ latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
108
+ noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
109
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
110
+ else:
111
+ latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
112
+
113
+ # Encode prompts
114
+ add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
115
+ self.text_encoder,
116
+ self.text_encoder_2,
117
+ prompt,
118
+ clip_skip=clip_skip, clip_skip_2=clip_skip_2,
119
+ device=self.device,
120
+ positive=True,
121
+ )
122
+ if cfg_scale != 1.0:
123
+ add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
124
+ self.text_encoder,
125
+ self.text_encoder_2,
126
+ negative_prompt,
127
+ clip_skip=clip_skip, clip_skip_2=clip_skip_2,
128
+ device=self.device,
129
+ positive=False,
130
+ )
131
+
132
+ # Prepare positional id
133
+ add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
134
+
135
+ # IP-Adapter
136
+ if ipadapter_images is not None:
137
+ ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
138
+ ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
139
+ ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
140
+ else:
141
+ ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
142
+
143
+ # Denoise
144
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
145
+ timestep = torch.IntTensor((timestep,))[0].to(self.device)
146
+
147
+ # Classifier-free guidance
148
+ noise_pred_posi = lets_dance_xl(
149
+ self.unet,
150
+ sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi,
151
+ add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
152
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
153
+ ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
154
+ )
155
+ if cfg_scale != 1.0:
156
+ noise_pred_nega = lets_dance_xl(
157
+ self.unet,
158
+ sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega,
159
+ add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
160
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
161
+ ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
162
+ )
163
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
164
+ else:
165
+ noise_pred = noise_pred_posi
166
+
167
+ latents = self.scheduler.step(noise_pred, timestep, latents)
168
+
169
+ if progress_bar_st is not None:
170
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
171
+
172
+ # Decode image
173
+ image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
174
+
175
+ return image
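For reference, a minimal usage sketch of the SDXLImagePipeline added above. The from_model_manager factory and the __call__ parameters come from this file; the ModelManager constructor signature, the load_models call, and the checkpoint path are assumptions and may differ from the ModelManager API in this commit.

# Hypothetical usage sketch; ModelManager loading details are assumptions.
import torch
from diffsynth.models import ModelManager
from diffsynth.pipelines.stable_diffusion_xl import SDXLImagePipeline

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")  # assumed constructor
model_manager.load_models(["models/stable_diffusion_xl/sd_xl_base_1.0.safetensors"])  # assumed loader and path

pipe = SDXLImagePipeline.from_model_manager(model_manager)
image = pipe(
    prompt="a photo of a cat, highly detailed, 8k",
    negative_prompt="blurry, low quality",
    cfg_scale=7.5, num_inference_steps=30,
    height=1024, width=1024,
)
image.save("image.png")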
diffsynth/pipelines/stable_diffusion_xl_video.py ADDED
@@ -0,0 +1,190 @@
1
+ from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLMotionModel
2
+ from .dancer import lets_dance_xl
3
+ # TODO: SDXL ControlNet
4
+ from ..prompts import SDXLPrompter
5
+ from ..schedulers import EnhancedDDIMScheduler
6
+ import torch
7
+ from tqdm import tqdm
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
+
12
+ class SDXLVideoPipeline(torch.nn.Module):
13
+
14
+ def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
15
+ super().__init__()
16
+ self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
17
+ self.prompter = SDXLPrompter()
18
+ self.device = device
19
+ self.torch_dtype = torch_dtype
20
+ # models
21
+ self.text_encoder: SDXLTextEncoder = None
22
+ self.text_encoder_2: SDXLTextEncoder2 = None
23
+ self.unet: SDXLUNet = None
24
+ self.vae_decoder: SDXLVAEDecoder = None
25
+ self.vae_encoder: SDXLVAEEncoder = None
26
+ # TODO: SDXL ControlNet
27
+ self.motion_modules: SDXLMotionModel = None
28
+
29
+
30
+ def fetch_main_models(self, model_manager: ModelManager):
31
+ self.text_encoder = model_manager.text_encoder
32
+ self.text_encoder_2 = model_manager.text_encoder_2
33
+ self.unet = model_manager.unet
34
+ self.vae_decoder = model_manager.vae_decoder
35
+ self.vae_encoder = model_manager.vae_encoder
36
+
37
+
38
+ def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
39
+ # TODO: SDXL ControlNet
40
+ pass
41
+
42
+
43
+ def fetch_motion_modules(self, model_manager: ModelManager):
44
+ if "motion_modules_xl" in model_manager.model:
45
+ self.motion_modules = model_manager.motion_modules_xl
46
+
47
+
48
+ def fetch_prompter(self, model_manager: ModelManager):
49
+ self.prompter.load_from_model_manager(model_manager)
50
+
51
+
52
+ @staticmethod
53
+ def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
54
+ pipe = SDXLVideoPipeline(
55
+ device=model_manager.device,
56
+ torch_dtype=model_manager.torch_dtype,
57
+ use_animatediff="motion_modules_xl" in model_manager.model
58
+ )
59
+ pipe.fetch_main_models(model_manager)
60
+ pipe.fetch_motion_modules(model_manager)
61
+ pipe.fetch_prompter(model_manager)
62
+ pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
63
+ return pipe
64
+
65
+
66
+ def preprocess_image(self, image):
67
+ image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
68
+ return image
69
+
70
+
71
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
72
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
73
+ image = image.cpu().permute(1, 2, 0).numpy()
74
+ image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
75
+ return image
76
+
77
+
78
+ def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
79
+ images = [
80
+ self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
81
+ for frame_id in range(latents.shape[0])
82
+ ]
83
+ return images
84
+
85
+
86
+ def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
87
+ latents = []
88
+ for image in processed_images:
89
+ image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
90
+ latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
91
+ latents.append(latent)
92
+ latents = torch.concat(latents, dim=0)
93
+ return latents
94
+
95
+
96
+ @torch.no_grad()
97
+ def __call__(
98
+ self,
99
+ prompt,
100
+ negative_prompt="",
101
+ cfg_scale=7.5,
102
+ clip_skip=1,
103
+ clip_skip_2=2,
104
+ num_frames=None,
105
+ input_frames=None,
106
+ controlnet_frames=None,
107
+ denoising_strength=1.0,
108
+ height=512,
109
+ width=512,
110
+ num_inference_steps=20,
111
+ animatediff_batch_size = 16,
112
+ animatediff_stride = 8,
113
+ unet_batch_size = 1,
114
+ controlnet_batch_size = 1,
115
+ cross_frame_attention = False,
116
+ smoother=None,
117
+ smoother_progress_ids=[],
118
+ vram_limit_level=0,
119
+ progress_bar_cmd=tqdm,
120
+ progress_bar_st=None,
121
+ ):
122
+ # Prepare scheduler
123
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
124
+
125
+ # Prepare latent tensors
126
+ if self.motion_modules is None:
127
+ noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
128
+ else:
129
+ noise = torch.randn((num_frames, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
130
+ if input_frames is None or denoising_strength == 1.0:
131
+ latents = noise
132
+ else:
133
+ latents = self.encode_images(input_frames)
134
+ latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
135
+
136
+ # Encode prompts
137
+ add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
138
+ self.text_encoder,
139
+ self.text_encoder_2,
140
+ prompt,
141
+ clip_skip=clip_skip, clip_skip_2=clip_skip_2,
142
+ device=self.device,
143
+ positive=True,
144
+ )
145
+ if cfg_scale != 1.0:
146
+ add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
147
+ self.text_encoder,
148
+ self.text_encoder_2,
149
+ negative_prompt,
150
+ clip_skip=clip_skip, clip_skip_2=clip_skip_2,
151
+ device=self.device,
152
+ positive=False,
153
+ )
154
+
155
+ # Prepare positional id
156
+ add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
157
+
158
+ # Denoise
159
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
160
+ timestep = torch.IntTensor((timestep,))[0].to(self.device)
161
+
162
+ # Classifier-free guidance
163
+ noise_pred_posi = lets_dance_xl(
164
+ self.unet, motion_modules=self.motion_modules, controlnet=None,
165
+ sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
166
+ timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
167
+ cross_frame_attention=cross_frame_attention,
168
+ device=self.device, vram_limit_level=vram_limit_level
169
+ )
170
+ if cfg_scale != 1.0:
171
+ noise_pred_nega = lets_dance_xl(
172
+ self.unet, motion_modules=self.motion_modules, controlnet=None,
173
+ sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
174
+ timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
175
+ cross_frame_attention=cross_frame_attention,
176
+ device=self.device, vram_limit_level=vram_limit_level
177
+ )
178
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
179
+ else:
180
+ noise_pred = noise_pred_posi
181
+
182
+ latents = self.scheduler.step(noise_pred, timestep, latents)
183
+
184
+ if progress_bar_st is not None:
185
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
186
+
187
+ # Decode image
188
+ image = self.decode_images(latents.to(torch.float32))
189
+
190
+ return image
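Analogously, a minimal sketch of driving the SDXLVideoPipeline above. The factory and call signature are taken from this file; the ModelManager loading calls and the motion-module checkpoint path are assumptions (the pipeline only enables the AnimateDiff code path when "motion_modules_xl" is registered in the model manager).

# Hypothetical usage sketch; model loading details are assumptions.
import torch
from diffsynth.models import ModelManager
from diffsynth.pipelines.stable_diffusion_xl_video import SDXLVideoPipeline

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")  # assumed constructor
model_manager.load_models([
    "models/stable_diffusion_xl/sd_xl_base_1.0.safetensors",  # assumed path
    "models/AnimateDiff/mm_sdxl_v10_beta.ckpt",               # assumed path; should register motion_modules_xl
])

pipe = SDXLVideoPipeline.from_model_manager(model_manager)
frames = pipe(
    prompt="a corgi running on the beach at sunset",
    negative_prompt="static, blurry, low quality",
    num_frames=16, height=1024, width=1024,
    cfg_scale=7.5, num_inference_steps=30,
)
for i, frame in enumerate(frames):  # the pipeline returns a list of PIL images
    frame.save(f"frame_{i:03d}.png")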
diffsynth/pipelines/stable_video_diffusion.py ADDED
@@ -0,0 +1,307 @@
1
+ from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder
2
+ from ..schedulers import ContinuousODEScheduler
3
+ import torch
4
+ from tqdm import tqdm
5
+ from PIL import Image
6
+ import numpy as np
7
+ from einops import rearrange, repeat
8
+
9
+
10
+
11
+ class SVDVideoPipeline(torch.nn.Module):
12
+
13
+ def __init__(self, device="cuda", torch_dtype=torch.float16):
14
+ super().__init__()
15
+ self.scheduler = ContinuousODEScheduler()
16
+ self.device = device
17
+ self.torch_dtype = torch_dtype
18
+ # models
19
+ self.image_encoder: SVDImageEncoder = None
20
+ self.unet: SVDUNet = None
21
+ self.vae_encoder: SVDVAEEncoder = None
22
+ self.vae_decoder: SVDVAEDecoder = None
23
+
24
+
25
+ def fetch_main_models(self, model_manager: ModelManager):
26
+ self.image_encoder = model_manager.image_encoder
27
+ self.unet = model_manager.unet
28
+ self.vae_encoder = model_manager.vae_encoder
29
+ self.vae_decoder = model_manager.vae_decoder
30
+
31
+
32
+ @staticmethod
33
+ def from_model_manager(model_manager: ModelManager, **kwargs):
34
+ pipe = SVDVideoPipeline(device=model_manager.device, torch_dtype=model_manager.torch_dtype)
35
+ pipe.fetch_main_models(model_manager)
36
+ return pipe
37
+
38
+
39
+ def preprocess_image(self, image):
40
+ image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
41
+ return image
42
+
43
+
44
+ def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
45
+ image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
46
+ image = image.cpu().permute(1, 2, 0).numpy()
47
+ image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
48
+ return image
49
+
50
+
51
+ def encode_image_with_clip(self, image):
52
+ image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
53
+ image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224))
54
+ image = (image + 1.0) / 2.0
55
+ mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype)
56
+ std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1).to(device=self.device, dtype=self.torch_dtype)
57
+ image = (image - mean) / std
58
+ image_emb = self.image_encoder(image)
59
+ return image_emb
60
+
61
+
62
+ def encode_image_with_vae(self, image, noise_aug_strength):
63
+ image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
64
+ noise = torch.randn(image.shape, device="cpu", dtype=self.torch_dtype).to(self.device)
65
+ image = image + noise_aug_strength * noise
66
+ image_emb = self.vae_encoder(image) / self.vae_encoder.scaling_factor
67
+ return image_emb
68
+
69
+
70
+ def encode_video_with_vae(self, video):
71
+ video = torch.concat([self.preprocess_image(frame) for frame in video], dim=0)
72
+ video = rearrange(video, "T C H W -> 1 C T H W")
73
+ video = video.to(device=self.device, dtype=self.torch_dtype)
74
+ latents = self.vae_encoder.encode_video(video)
75
+ latents = rearrange(latents[0], "C T H W -> T C H W")
76
+ return latents
77
+
78
+
79
+ def tensor2video(self, frames):
80
+ frames = rearrange(frames, "C T H W -> T H W C")
81
+ frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
82
+ frames = [Image.fromarray(frame) for frame in frames]
83
+ return frames
84
+
85
+
86
+ def calculate_noise_pred(
87
+ self,
88
+ latents,
89
+ timestep,
90
+ add_time_id,
91
+ cfg_scales,
92
+ image_emb_vae_posi, image_emb_clip_posi,
93
+ image_emb_vae_nega, image_emb_clip_nega
94
+ ):
95
+ # Positive side
96
+ noise_pred_posi = self.unet(
97
+ torch.cat([latents, image_emb_vae_posi], dim=1),
98
+ timestep, image_emb_clip_posi, add_time_id
99
+ )
100
+ # Negative side
101
+ noise_pred_nega = self.unet(
102
+ torch.cat([latents, image_emb_vae_nega], dim=1),
103
+ timestep, image_emb_clip_nega, add_time_id
104
+ )
105
+
106
+ # Classifier-free guidance
107
+ noise_pred = noise_pred_nega + cfg_scales * (noise_pred_posi - noise_pred_nega)
108
+
109
+ return noise_pred
110
+
111
+
112
+ def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
113
+ if post_normalize:
114
+ mean, std = latents.mean(), latents.std()
115
+ latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
116
+ latents = latents * contrast_enhance_scale
117
+ return latents
118
+
119
+
120
+ @torch.no_grad()
121
+ def __call__(
122
+ self,
123
+ input_image=None,
124
+ input_video=None,
125
+ mask_frames=[],
126
+ mask_frame_ids=[],
127
+ min_cfg_scale=1.0,
128
+ max_cfg_scale=3.0,
129
+ denoising_strength=1.0,
130
+ num_frames=25,
131
+ height=576,
132
+ width=1024,
133
+ fps=7,
134
+ motion_bucket_id=127,
135
+ noise_aug_strength=0.02,
136
+ num_inference_steps=20,
137
+ post_normalize=True,
138
+ contrast_enhance_scale=1.2,
139
+ progress_bar_cmd=tqdm,
140
+ progress_bar_st=None,
141
+ ):
142
+ # Prepare scheduler
143
+ self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
144
+
145
+ # Prepare latent tensors
146
+ noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).to(self.device)
147
+ if denoising_strength == 1.0:
148
+ latents = noise.clone()
149
+ else:
150
+ latents = self.encode_video_with_vae(input_video)
151
+ latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
152
+
153
+ # Prepare mask frames
154
+ if len(mask_frames) > 0:
155
+ mask_latents = self.encode_video_with_vae(mask_frames)
156
+
157
+ # Encode image
158
+ image_emb_clip_posi = self.encode_image_with_clip(input_image)
159
+ image_emb_clip_nega = torch.zeros_like(image_emb_clip_posi)
160
+ image_emb_vae_posi = repeat(self.encode_image_with_vae(input_image, noise_aug_strength), "B C H W -> (B T) C H W", T=num_frames)
161
+ image_emb_vae_nega = torch.zeros_like(image_emb_vae_posi)
162
+
163
+ # Prepare classifier-free guidance
164
+ cfg_scales = torch.linspace(min_cfg_scale, max_cfg_scale, num_frames)
165
+ cfg_scales = cfg_scales.reshape(num_frames, 1, 1, 1).to(device=self.device, dtype=self.torch_dtype)
166
+
167
+ # Prepare positional id
168
+ add_time_id = torch.tensor([[fps-1, motion_bucket_id, noise_aug_strength]], device=self.device)
169
+
170
+ # Denoise
171
+ for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
172
+
173
+ # Mask frames
174
+ for frame_id, mask_frame_id in enumerate(mask_frame_ids):
175
+ latents[mask_frame_id] = self.scheduler.add_noise(mask_latents[frame_id], noise[mask_frame_id], timestep)
176
+
177
+ # Fetch model output
178
+ noise_pred = self.calculate_noise_pred(
179
+ latents, timestep, add_time_id, cfg_scales,
180
+ image_emb_vae_posi, image_emb_clip_posi, image_emb_vae_nega, image_emb_clip_nega
181
+ )
182
+
183
+ # Forward Euler
184
+ latents = self.scheduler.step(noise_pred, timestep, latents)
185
+
186
+ # Update progress bar
187
+ if progress_bar_st is not None:
188
+ progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
189
+
190
+ # Decode image
191
+ latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
192
+ video = self.vae_decoder.decode_video(latents, progress_bar=progress_bar_cmd)
193
+ video = self.tensor2video(video)
194
+
195
+ return video
196
+
197
+
198
+
199
+ class SVDCLIPImageProcessor:
200
+ def __init__(self):
201
+ pass
202
+
203
+ def resize_with_antialiasing(self, input, size, interpolation="bicubic", align_corners=True):
204
+ h, w = input.shape[-2:]
205
+ factors = (h / size[0], w / size[1])
206
+
207
+ # First, we have to determine sigma
208
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
209
+ sigmas = (
210
+ max((factors[0] - 1.0) / 2.0, 0.001),
211
+ max((factors[1] - 1.0) / 2.0, 0.001),
212
+ )
213
+
214
+ # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
215
+ # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
216
+ # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
217
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
218
+
219
+ # Make sure it is odd
220
+ if (ks[0] % 2) == 0:
221
+ ks = ks[0] + 1, ks[1]
222
+
223
+ if (ks[1] % 2) == 0:
224
+ ks = ks[0], ks[1] + 1
225
+
226
+ input = self._gaussian_blur2d(input, ks, sigmas)
227
+
228
+ output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
229
+ return output
230
+
231
+
232
+ def _compute_padding(self, kernel_size):
233
+ """Compute padding tuple."""
234
+ # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
235
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
236
+ if len(kernel_size) < 2:
237
+ raise AssertionError(kernel_size)
238
+ computed = [k - 1 for k in kernel_size]
239
+
240
+ # for even kernels we need to do asymmetric padding :(
241
+ out_padding = 2 * len(kernel_size) * [0]
242
+
243
+ for i in range(len(kernel_size)):
244
+ computed_tmp = computed[-(i + 1)]
245
+
246
+ pad_front = computed_tmp // 2
247
+ pad_rear = computed_tmp - pad_front
248
+
249
+ out_padding[2 * i + 0] = pad_front
250
+ out_padding[2 * i + 1] = pad_rear
251
+
252
+ return out_padding
253
+
254
+
255
+ def _filter2d(self, input, kernel):
256
+ # prepare kernel
257
+ b, c, h, w = input.shape
258
+ tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
259
+
260
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
261
+
262
+ height, width = tmp_kernel.shape[-2:]
263
+
264
+ padding_shape: list[int] = self._compute_padding([height, width])
265
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
266
+
267
+ # kernel and input tensor reshape to align element-wise or batch-wise params
268
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
269
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
270
+
271
+ # convolve the tensor with the kernel.
272
+ output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
273
+
274
+ out = output.view(b, c, h, w)
275
+ return out
276
+
277
+
278
+ def _gaussian(self, window_size: int, sigma):
279
+ if isinstance(sigma, float):
280
+ sigma = torch.tensor([[sigma]])
281
+
282
+ batch_size = sigma.shape[0]
283
+
284
+ x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
285
+
286
+ if window_size % 2 == 0:
287
+ x = x + 0.5
288
+
289
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
290
+
291
+ return gauss / gauss.sum(-1, keepdim=True)
292
+
293
+
294
+ def _gaussian_blur2d(self, input, kernel_size, sigma):
295
+ if isinstance(sigma, tuple):
296
+ sigma = torch.tensor([sigma], dtype=input.dtype)
297
+ else:
298
+ sigma = sigma.to(dtype=input.dtype)
299
+
300
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
301
+ bs = sigma.shape[0]
302
+ kernel_x = self._gaussian(kx, sigma[:, 1].view(bs, 1))
303
+ kernel_y = self._gaussian(ky, sigma[:, 0].view(bs, 1))
304
+ out_x = self._filter2d(input, kernel_x[..., None, :])
305
+ out = self._filter2d(out_x, kernel_y[..., None])
306
+
307
+ return out
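A minimal image-to-video sketch for the SVDVideoPipeline above. The call parameters (num_frames, fps, motion_bucket_id, noise_aug_strength, and the per-frame guidance ramp from min_cfg_scale to max_cfg_scale) come from this file; the ModelManager loading call and checkpoint path are assumptions.

# Hypothetical usage sketch; model loading details are assumptions.
import torch
from PIL import Image
from diffsynth.models import ModelManager
from diffsynth.pipelines.stable_video_diffusion import SVDVideoPipeline

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")  # assumed constructor
model_manager.load_models(["models/stable_video_diffusion/svd_xt.safetensors"])  # assumed loader and path

pipe = SVDVideoPipeline.from_model_manager(model_manager)
video = pipe(
    input_image=Image.open("input.png").resize((1024, 576)),
    num_frames=25, height=576, width=1024,
    fps=7, motion_bucket_id=127, noise_aug_strength=0.02,
    min_cfg_scale=1.0, max_cfg_scale=3.0,  # guidance strength ramps linearly across the frames
    num_inference_steps=25,
)
for i, frame in enumerate(video):
    frame.save(f"svd_{i:03d}.png")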
diffsynth/processors/FastBlend.py ADDED
@@ -0,0 +1,142 @@
1
+ from PIL import Image
2
+ import cupy as cp
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ from ..extensions.FastBlend.patch_match import PyramidPatchMatcher
6
+ from ..extensions.FastBlend.runners.fast import TableManager
7
+ from .base import VideoProcessor
8
+
9
+
10
+ class FastBlendSmoother(VideoProcessor):
11
+ def __init__(
12
+ self,
13
+ inference_mode="fast", batch_size=8, window_size=60,
14
+ minimum_patch_size=5, threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0, initialize="identity", tracking_window_size=0
15
+ ):
16
+ self.inference_mode = inference_mode
17
+ self.batch_size = batch_size
18
+ self.window_size = window_size
19
+ self.ebsynth_config = {
20
+ "minimum_patch_size": minimum_patch_size,
21
+ "threads_per_block": threads_per_block,
22
+ "num_iter": num_iter,
23
+ "gpu_id": gpu_id,
24
+ "guide_weight": guide_weight,
25
+ "initialize": initialize,
26
+ "tracking_window_size": tracking_window_size
27
+ }
28
+
29
+ @staticmethod
30
+ def from_model_manager(model_manager, **kwargs):
31
+ # TODO: fetch GPU ID from model_manager
32
+ return FastBlendSmoother(**kwargs)
33
+
34
+ def inference_fast(self, frames_guide, frames_style):
35
+ table_manager = TableManager()
36
+ patch_match_engine = PyramidPatchMatcher(
37
+ image_height=frames_style[0].shape[0],
38
+ image_width=frames_style[0].shape[1],
39
+ channel=3,
40
+ **self.ebsynth_config
41
+ )
42
+ # left part
43
+ table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, self.batch_size, desc="Fast Mode Step 1/4")
44
+ table_l = table_manager.remapping_table_to_blending_table(table_l)
45
+ table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 2/4")
46
+ # right part
47
+ table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, self.batch_size, desc="Fast Mode Step 3/4")
48
+ table_r = table_manager.remapping_table_to_blending_table(table_r)
49
+ table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 4/4")[::-1]
50
+ # merge
51
+ frames = []
52
+ for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
53
+ weight_m = -1
54
+ weight = weight_l + weight_m + weight_r
55
+ frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
56
+ frames.append(frame)
57
+ frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
58
+ frames = [Image.fromarray(frame) for frame in frames]
59
+ return frames
60
+
61
+ def inference_balanced(self, frames_guide, frames_style):
62
+ patch_match_engine = PyramidPatchMatcher(
63
+ image_height=frames_style[0].shape[0],
64
+ image_width=frames_style[0].shape[1],
65
+ channel=3,
66
+ **self.ebsynth_config
67
+ )
68
+ output_frames = []
69
+ # tasks
70
+ n = len(frames_style)
71
+ tasks = []
72
+ for target in range(n):
73
+ for source in range(target - self.window_size, target + self.window_size + 1):
74
+ if source >= 0 and source < n and source != target:
75
+ tasks.append((source, target))
76
+ # run
77
+ frames = [(None, 1) for i in range(n)]
78
+ for batch_id in tqdm(range(0, len(tasks), self.batch_size), desc="Balanced Mode"):
79
+ tasks_batch = tasks[batch_id: min(batch_id+self.batch_size, len(tasks))]
80
+ source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
81
+ target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
82
+ source_style = np.stack([frames_style[source] for source, target in tasks_batch])
83
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
84
+ for (source, target), result in zip(tasks_batch, target_style):
85
+ frame, weight = frames[target]
86
+ if frame is None:
87
+ frame = frames_style[target]
88
+ frames[target] = (
89
+ frame * (weight / (weight + 1)) + result / (weight + 1),
90
+ weight + 1
91
+ )
92
+ if weight + 1 == min(n, target + self.window_size + 1) - max(0, target - self.window_size):
93
+ frame = frame.clip(0, 255).astype("uint8")
94
+ output_frames.append(Image.fromarray(frame))
95
+ frames[target] = (None, 1)
96
+ return output_frames
97
+
98
+ def inference_accurate(self, frames_guide, frames_style):
99
+ patch_match_engine = PyramidPatchMatcher(
100
+ image_height=frames_style[0].shape[0],
101
+ image_width=frames_style[0].shape[1],
102
+ channel=3,
103
+ use_mean_target_style=True,
104
+ **self.ebsynth_config
105
+ )
106
+ output_frames = []
107
+ # run
108
+ n = len(frames_style)
109
+ for target in tqdm(range(n), desc="Accurate Mode"):
110
+ l, r = max(target - self.window_size, 0), min(target + self.window_size + 1, n)
111
+ remapped_frames = []
112
+ for i in range(l, r, self.batch_size):
113
+ j = min(i + self.batch_size, r)
114
+ source_guide = np.stack([frames_guide[source] for source in range(i, j)])
115
+ target_guide = np.stack([frames_guide[target]] * (j - i))
116
+ source_style = np.stack([frames_style[source] for source in range(i, j)])
117
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
118
+ remapped_frames.append(target_style)
119
+ frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
120
+ frame = frame.clip(0, 255).astype("uint8")
121
+ output_frames.append(Image.fromarray(frame))
122
+ return output_frames
123
+
124
+ def release_vram(self):
125
+ mempool = cp.get_default_memory_pool()
126
+ pinned_mempool = cp.get_default_pinned_memory_pool()
127
+ mempool.free_all_blocks()
128
+ pinned_mempool.free_all_blocks()
129
+
130
+ def __call__(self, rendered_frames, original_frames=None, **kwargs):
131
+ rendered_frames = [np.array(frame) for frame in rendered_frames]
132
+ original_frames = [np.array(frame) for frame in original_frames]
133
+ if self.inference_mode == "fast":
134
+ output_frames = self.inference_fast(original_frames, rendered_frames)
135
+ elif self.inference_mode == "balanced":
136
+ output_frames = self.inference_balanced(original_frames, rendered_frames)
137
+ elif self.inference_mode == "accurate":
138
+ output_frames = self.inference_accurate(original_frames, rendered_frames)
139
+ else:
140
+ raise ValueError("inference_mode must be fast, balanced or accurate")
141
+ self.release_vram()
142
+ return output_frames
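Finally, a minimal sketch of using FastBlendSmoother as a stand-alone deflickering pass: the original (guide) frames and the stylized (rendered) frames are blended patch-wise on the GPU, so cupy and a CUDA device are required. The constructor and __call__ signature come from this file; the frame paths and frame count are placeholders.

# Hypothetical usage sketch; frame paths are placeholders.
from PIL import Image
from diffsynth.processors.FastBlend import FastBlendSmoother

original_frames = [Image.open(f"input/{i:03d}.png") for i in range(32)]     # guide video
rendered_frames = [Image.open(f"rendered/{i:03d}.png") for i in range(32)]  # stylized, flickering video

smoother = FastBlendSmoother(inference_mode="fast", batch_size=8, window_size=60)
smoothed_frames = smoother(rendered_frames, original_frames=original_frames)
for i, frame in enumerate(smoothed_frames):
    frame.save(f"smoothed/{i:03d}.png")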