Commit 491602a · committed by Luffuly · 1 parent: 70bbf64
feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,44 @@
+ {
+   "_valid_processor_keys": [
+     "images",
+     "do_resize",
+     "size",
+     "resample",
+     "do_center_crop",
+     "crop_size",
+     "do_rescale",
+     "rescale_factor",
+     "do_normalize",
+     "image_mean",
+     "image_std",
+     "do_convert_rgb",
+     "return_tensors",
+     "data_format",
+     "input_data_format"
+   ],
+   "crop_size": {
+     "height": 224,
+     "width": 224
+   },
+   "do_center_crop": true,
+   "do_convert_rgb": true,
+   "do_normalize": true,
+   "do_rescale": true,
+   "do_resize": true,
+   "image_mean": [
+     0.48145466,
+     0.4578275,
+     0.40821073
+   ],
+   "image_processor_type": "CLIPImageProcessor",
+   "image_std": [
+     0.26862954,
+     0.26130258,
+     0.27577711
+   ],
+   "resample": 3,
+   "rescale_factor": 0.00392156862745098,
+   "size": {
+     "shortest_edge": 224
+   }
+ }
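
This is the stock CLIP preprocessing recipe: resize the shortest edge to 224, center-crop to 224×224, rescale by 1/255, and normalize with the CLIP mean/std. A minimal sketch of loading and applying it; "your-namespace/your-repo" is a placeholder, since the actual repository id is not shown in this commit view:

    from PIL import Image
    from transformers import CLIPImageProcessor

    # Placeholder repo id for the repository this commit belongs to.
    processor = CLIPImageProcessor.from_pretrained(
        "your-namespace/your-repo", subfolder="feature_extractor"
    )
    image = Image.open("example.png").convert("RGB")
    # pixel_values: (1, 3, 224, 224), normalized with the CLIP mean/std listed above
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
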
image_encoder/config.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "_name_or_path": "lambdalabs/sd-image-variations-diffusers",
+   "architectures": [
+     "CLIPVisionModelWithProjection"
+   ],
+   "attention_dropout": 0.0,
+   "dropout": 0.0,
+   "hidden_act": "quick_gelu",
+   "hidden_size": 1024,
+   "image_size": 224,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-05,
+   "model_type": "clip_vision_model",
+   "num_attention_heads": 16,
+   "num_channels": 3,
+   "num_hidden_layers": 24,
+   "patch_size": 14,
+   "projection_dim": 768,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.39.3"
+ }
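
The image encoder is the CLIP ViT-L/14 vision tower from lambdalabs/sd-image-variations-diffusers (24 layers, hidden size 1024, patch size 14), with a 768-dimensional projection that matches the UNet's cross_attention_dim. A hedged sketch of encoding an image with it, using the placeholder repo id and the pixel_values produced by the feature extractor above:

    import torch
    from transformers import CLIPVisionModelWithProjection

    # Placeholder repo id; subfolder name matches this commit.
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        "your-namespace/your-repo", subfolder="image_encoder", torch_dtype=torch.bfloat16
    )
    with torch.no_grad():
        out = image_encoder(pixel_values=pixel_values.to(torch.bfloat16))
    image_embeds = out.image_embeds  # (batch, 768) conditioning vectors for the UNet
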
image_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e4b33d864f89a793357a768cb07d0dc18d6a14e6664f4110a0d535ca9ba78da8
+ size 607980488
model_index.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "_class_name": "StableDiffusionImageCustomPipeline",
+   "_diffusers_version": "0.27.2",
+   "feature_extractor": [
+     "transformers",
+     "CLIPImageProcessor"
+   ],
+   "image_encoder": [
+     "transformers",
+     "CLIPVisionModelWithProjection"
+   ],
+   "noisy_cond_latents": false,
+   "requires_safety_checker": true,
+   "safety_checker": [
+     null,
+     null
+   ],
+   "scheduler": [
+     "diffusers",
+     "EulerAncestralDiscreteScheduler"
+   ],
+   "unet": [
+     "mv_unet",
+     "UnifieldWrappedUNet"
+   ],
+   "vae": [
+     "diffusers",
+     "AutoencoderKL"
+   ]
+ }
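
The pipeline wires the transformers feature extractor and image encoder together with a diffusers scheduler and VAE, a custom UNet class (UnifieldWrappedUNet from the local mv_unet module), and no safety checker. The pipeline class StableDiffusionImageCustomPipeline is not among the files added here, so loading depends on custom code shipped with the repository; a hedged sketch, assuming a recent diffusers release and the placeholder repo id:

    import torch
    from diffusers import DiffusionPipeline

    # Custom components (mv_unet.UnifieldWrappedUNet, StableDiffusionImageCustomPipeline)
    # live on the Hub, so remote code must be trusted; older diffusers versions may need
    # a local clone of the repository instead.
    pipe = DiffusionPipeline.from_pretrained(
        "your-namespace/your-repo",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).to("cuda")
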
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "_class_name": "EulerAncestralDiscreteScheduler",
+   "_diffusers_version": "0.27.2",
+   "beta_end": 0.012,
+   "beta_schedule": "scaled_linear",
+   "beta_start": 0.00085,
+   "clip_sample": false,
+   "num_train_timesteps": 1000,
+   "prediction_type": "epsilon",
+   "rescale_betas_zero_snr": false,
+   "set_alpha_to_one": false,
+   "skip_prk_steps": true,
+   "steps_offset": 1,
+   "timestep_spacing": "linspace",
+   "trained_betas": null
+ }
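
These are the standard Stable Diffusion 1.x noise settings (scaled_linear betas from 0.00085 to 0.012 over 1000 training timesteps, epsilon prediction) with linspace timestep spacing. A minimal sketch of instantiating the scheduler on its own, again with the placeholder repo id:

    from diffusers import EulerAncestralDiscreteScheduler

    scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
        "your-namespace/your-repo", subfolder="scheduler"
    )
    scheduler.set_timesteps(30)     # e.g. 30 inference steps, spaced per "timestep_spacing"
    print(scheduler.timesteps[:3])  # descending noise levels used during sampling
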
unet/config.json ADDED
@@ -0,0 +1,74 @@
+ {
+   "_class_name": "UnifieldWrappedUNet",
+   "_diffusers_version": "0.27.2",
+   "act_fn": "silu",
+   "addition_embed_type": null,
+   "addition_embed_type_num_heads": 64,
+   "addition_time_embed_dim": null,
+   "attention_head_dim": 8,
+   "attention_type": "default",
+   "block_out_channels": [
+     320,
+     640,
+     1280,
+     1280
+   ],
+   "center_input_sample": false,
+   "class_embed_type": null,
+   "class_embeddings_concat": false,
+   "conv_in_kernel": 3,
+   "conv_out_kernel": 3,
+   "cross_attention_dim": 768,
+   "cross_attention_norm": null,
+   "down_block_types": [
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D",
+     "DownBlock2D"
+   ],
+   "downsample_padding": 1,
+   "dropout": 0.0,
+   "dual_cross_attention": false,
+   "encoder_hid_dim": null,
+   "encoder_hid_dim_type": null,
+   "flip_sin_to_cos": true,
+   "freq_shift": 0,
+   "in_channels": 4,
+   "layers_per_block": 2,
+   "mid_block_only_cross_attention": null,
+   "mid_block_scale_factor": 1,
+   "mid_block_type": "UNetMidBlock2DCrossAttn",
+   "norm_eps": 1e-05,
+   "norm_num_groups": 32,
+   "num_attention_heads": null,
+   "num_class_embeds": null,
+   "only_cross_attention": false,
+   "out_channels": 4,
+   "projection_class_embeddings_input_dim": null,
+   "resnet_out_scale_factor": 1.0,
+   "resnet_skip_time_act": false,
+   "resnet_time_scale_shift": "default",
+   "reverse_transformer_layers_per_block": null,
+   "sample_size": 64,
+   "time_cond_proj_dim": null,
+   "time_embedding_act_fn": null,
+   "time_embedding_dim": null,
+   "time_embedding_type": "positional",
+   "timestep_post_act": null,
+   "transformer_layers_per_block": 1,
+   "up_block_types": [
+     "UpBlock2D",
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D",
+     "CrossAttnUpBlock2D"
+   ],
+   "upcast_attention": false,
+   "use_linear_projection": false,
+
+   "init_self_attn_ref": true,
+   "self_attn_ref_position": "attn1",
+   "self_attn_ref_pixel_wise_crosspond": true,
+   "self_attn_ref_effect_on": "all",
+   "self_attn_ref_chain_pos": "parralle",
+   "use_simple3d_attn": false
+ }
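
Everything up to "use_linear_projection" is the standard SD 1.x UNet layout (4 latent channels, block widths 320/640/1280/1280, cross_attention_dim 768, sample_size 64); the trailing init_self_attn_ref / self_attn_ref_* / use_simple3d_attn keys are extra constructor arguments consumed by UnifieldWrappedUNet in unet/mv_unet.py rather than by the stock UNet2DConditionModel. A hedged loading sketch (placeholder repo id; note that the wrapper also downloads a reference UNet from lambdalabs/sd-image-variations-diffusers in its constructor):

    import torch
    from mv_unet import UnifieldWrappedUNet  # unet/mv_unet.py from this commit, on sys.path

    # Loads the config above; the self_attn_ref_* keys configure reference-attention
    # injection on top of the regular UNet2DConditionModel behavior.
    unet = UnifieldWrappedUNet.from_pretrained(
        "your-namespace/your-repo", subfolder="unet", torch_dtype=torch.bfloat16
    )
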
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f5cbaf1d56619345ce78de8cfbb20d94923b3305a364bf6a5b2a2cc422d4b701
+ size 3537503456
unet/mv_unet.py ADDED
@@ -0,0 +1,314 @@
+ import torch
+ from typing import Optional, Tuple, Union, Any
+ from diffusers import UNet2DConditionModel
+ from diffusers.models.attention_processor import Attention
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
+
+
+ def construct_pix2pix_attention(hidden_states_dim, norm_type="none"):
+     if norm_type == "layernorm":
+         norm = torch.nn.LayerNorm(hidden_states_dim)
+     else:
+         norm = torch.nn.Identity()
+     attention = Attention(
+         query_dim=hidden_states_dim,
+         heads=8,
+         dim_head=hidden_states_dim // 8,
+         bias=True,
+     )
+     # NOTE: xformers 0.22 does not support batchsize >= 4096
+     attention.xformers_not_supported = True  # hacky solution
+     return norm, attention
+
+
+ def switch_extra_processor(model, enable_filter=lambda x: True):
+     def recursive_add_processors(name: str, module: torch.nn.Module):
+         for sub_name, child in module.named_children():
+             recursive_add_processors(f"{name}.{sub_name}", child)
+
+         if isinstance(module, ExtraAttnProc):
+             module.enabled = enable_filter(name)
+
+     for name, module in model.named_children():
+         recursive_add_processors(name, module)
+
+
+ def add_extra_processor(model: torch.nn.Module, enable_filter=lambda x: True, **kwargs):
+     return_dict = torch.nn.ModuleDict()
+     proj_in_dim = kwargs.get('proj_in_dim', False)
+     kwargs.pop('proj_in_dim', None)
+
+     def recursive_add_processors(name: str, module: torch.nn.Module):
+         for sub_name, child in module.named_children():
+             if "ref_unet" not in (sub_name + name):
+                 recursive_add_processors(f"{name}.{sub_name}", child)
+
+         if isinstance(module, Attention):
+             new_processor = ExtraAttnProc(
+                 chained_proc=module.get_processor(),
+                 enabled=enable_filter(f"{name}.processor"),
+                 name=f"{name}.processor",
+                 proj_in_dim=proj_in_dim if proj_in_dim else module.cross_attention_dim,
+                 target_dim=module.cross_attention_dim,
+                 **kwargs
+             )
+             module.set_processor(new_processor)
+             return_dict[f"{name}.processor".replace(".", "__")] = new_processor
+
+     for name, module in model.named_children():
+         recursive_add_processors(name, module)
+     return return_dict
+
+
+
+ class ExtraAttnProc(torch.nn.Module):
+     def __init__(
+         self,
+         chained_proc,
+         enabled=False,
+         name=None,
+         mode='extract',
+         with_proj_in=False,
+         proj_in_dim=768,
+         target_dim=None,
+         pixel_wise_crosspond=False,
+         norm_type="none",  # none or layernorm
+         crosspond_effect_on="all",  # all or first
+         crosspond_chain_pos="parralle",  # before or parralle or after
+         simple_3d=False,
+         views=4,
+     ) -> None:
+         super().__init__()
+         self.enabled = enabled
+         self.chained_proc = chained_proc
+         self.name = name
+         self.mode = mode
+         self.with_proj_in = with_proj_in
+         self.proj_in_dim = proj_in_dim
+         self.target_dim = target_dim or proj_in_dim
+         self.hidden_states_dim = self.target_dim
+         self.pixel_wise_crosspond = pixel_wise_crosspond
+         self.crosspond_effect_on = crosspond_effect_on
+         self.crosspond_chain_pos = crosspond_chain_pos
+         self.views = views
+         self.simple_3d = simple_3d
+         if self.with_proj_in and self.enabled:
+             self.in_linear = torch.nn.Linear(self.proj_in_dim, self.target_dim, bias=False)
+             if self.target_dim == self.proj_in_dim:
+                 self.in_linear.weight.data = torch.eye(proj_in_dim)
+         else:
+             self.in_linear = None
+         if self.pixel_wise_crosspond and self.enabled:
+             self.crosspond_norm, self.crosspond_attention = construct_pix2pix_attention(self.hidden_states_dim, norm_type=norm_type)
+
+     def do_crosspond_attention(self, hidden_states: torch.FloatTensor, other_states: torch.FloatTensor):
+         hidden_states = self.crosspond_norm(hidden_states)
+
+         batch, L, D = hidden_states.shape
+         assert hidden_states.shape == other_states.shape, f"got {hidden_states.shape} and {other_states.shape}"
+         # to -> batch * L, 1, D
+         hidden_states = hidden_states.reshape(batch * L, 1, D)
+         other_states = other_states.reshape(batch * L, 1, D)
+         hidden_states_catted = other_states
+         hidden_states = self.crosspond_attention(
+             hidden_states,
+             encoder_hidden_states=hidden_states_catted,
+         )
+         return hidden_states.reshape(batch, L, D)
+
+     def __call__(
+         self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
+         ref_dict: dict = None, mode=None, **kwargs
+     ) -> Any:
+         if not self.enabled:
+             return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         assert ref_dict is not None
+         if (mode or self.mode) == 'extract':
+             ref_dict[self.name] = hidden_states
+             hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+             if self.pixel_wise_crosspond and self.crosspond_chain_pos == "after":
+                 ref_dict[self.name] = hidden_states1
+             return hidden_states1
+         elif (mode or self.mode) == 'inject':
+             ref_state = ref_dict.pop(self.name)
+             if self.with_proj_in:
+                 ref_state = self.in_linear(ref_state)
+
+             B, L, D = ref_state.shape
+             if hidden_states.shape[0] == B:
+                 modalities = 1
+                 views = 1
+             else:
+                 modalities = hidden_states.shape[0] // B // self.views
+                 views = self.views
+             if self.pixel_wise_crosspond:
+                 if self.crosspond_effect_on == "all":
+                     ref_state = ref_state[:, None].expand(-1, modalities * views, -1, -1).reshape(-1, *ref_state.shape[-2:])
+
+                     if self.crosspond_chain_pos == "before":
+                         hidden_states = hidden_states + self.do_crosspond_attention(hidden_states, ref_state)
+
+                     hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+
+                     if self.crosspond_chain_pos == "parralle":
+                         hidden_states1 = hidden_states1 + self.do_crosspond_attention(hidden_states, ref_state)
+
+                     if self.crosspond_chain_pos == "after":
+                         hidden_states1 = hidden_states1 + self.do_crosspond_attention(hidden_states1, ref_state)
+                     return hidden_states1
+                 else:
+                     assert self.crosspond_effect_on == "first"
+                     # hidden_states [B * modalities * views, L, D]
+                     # ref_state [B, L, D]
+                     ref_state = ref_state[:, None].expand(-1, modalities, -1, -1).reshape(-1, ref_state.shape[-2], ref_state.shape[-1])  # [B * modalities, L, D]
+
+                     def do_paritial_crosspond(hidden_states, ref_state):
+                         first_view_hidden_states = hidden_states.view(-1, views, hidden_states.shape[1], hidden_states.shape[2])[:, 0]  # [B * modalities, L, D]
+                         hidden_states2 = self.do_crosspond_attention(first_view_hidden_states, ref_state)  # [B * modalities, L, D]
+                         hidden_states2_padded = torch.zeros_like(hidden_states).reshape(-1, views, hidden_states.shape[1], hidden_states.shape[2])
+                         hidden_states2_padded[:, 0] = hidden_states2
+                         hidden_states2_padded = hidden_states2_padded.reshape(-1, hidden_states.shape[1], hidden_states.shape[2])
+                         return hidden_states2_padded
+
+                     if self.crosspond_chain_pos == "before":
+                         hidden_states = hidden_states + do_paritial_crosspond(hidden_states, ref_state)
+
+                     hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)  # [B * modalities * views, L, D]
+                     if self.crosspond_chain_pos == "parralle":
+                         hidden_states1 = hidden_states1 + do_paritial_crosspond(hidden_states, ref_state)
+                     if self.crosspond_chain_pos == "after":
+                         hidden_states1 = hidden_states1 + do_paritial_crosspond(hidden_states1, ref_state)
+                     return hidden_states1
+             elif self.simple_3d:
+                 B, L, C = encoder_hidden_states.shape
+                 mv = self.views
+                 encoder_hidden_states = encoder_hidden_states.reshape(B // mv, mv, L, C)
+                 ref_state = ref_state[:, None]
+                 encoder_hidden_states = torch.cat([encoder_hidden_states, ref_state], dim=1)
+                 encoder_hidden_states = encoder_hidden_states.reshape(B // mv, 1, (mv+1) * L, C)
+                 encoder_hidden_states = encoder_hidden_states.repeat(1, mv, 1, 1).reshape(-1, (mv+1) * L, C)
+                 return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+             else:
+                 ref_state = ref_state[:, None].expand(-1, modalities * views, -1, -1).reshape(-1, ref_state.shape[-2], ref_state.shape[-1])
+                 encoder_hidden_states = torch.cat([encoder_hidden_states, ref_state], dim=1)
+                 return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
+         else:
+             raise NotImplementedError("mode or self.mode is required to be 'extract' or 'inject'")
+
+
+ class UnifieldWrappedUNet(UNet2DConditionModel):
+     def __init__(
+         self,
+         sample_size: Optional[int] = None,
+         in_channels: int = 4,
+         out_channels: int = 4,
+         center_input_sample: bool = False,
+         flip_sin_to_cos: bool = True,
+         freq_shift: int = 0,
+         down_block_types: Tuple[str] = (
+             "CrossAttnDownBlock2D",
+             "CrossAttnDownBlock2D",
+             "CrossAttnDownBlock2D",
+             "DownBlock2D",
+         ),
+         mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+         up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+         only_cross_attention: Union[bool, Tuple[bool]] = False,
+         block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+         layers_per_block: Union[int, Tuple[int]] = 2,
+         downsample_padding: int = 1,
+         mid_block_scale_factor: float = 1,
+         dropout: float = 0.0,
+         act_fn: str = "silu",
+         norm_num_groups: Optional[int] = 32,
+         norm_eps: float = 1e-5,
+         cross_attention_dim: Union[int, Tuple[int]] = 1280,
+         transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+         reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+         encoder_hid_dim: Optional[int] = None,
+         encoder_hid_dim_type: Optional[str] = None,
+         attention_head_dim: Union[int, Tuple[int]] = 8,
+         num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+         dual_cross_attention: bool = False,
+         use_linear_projection: bool = False,
+         class_embed_type: Optional[str] = None,
+         addition_embed_type: Optional[str] = None,
+         addition_time_embed_dim: Optional[int] = None,
+         num_class_embeds: Optional[int] = None,
+         upcast_attention: bool = False,
+         resnet_time_scale_shift: str = "default",
+         resnet_skip_time_act: bool = False,
+         resnet_out_scale_factor: float = 1.0,
+         time_embedding_type: str = "positional",
+         time_embedding_dim: Optional[int] = None,
+         time_embedding_act_fn: Optional[str] = None,
+         timestep_post_act: Optional[str] = None,
+         time_cond_proj_dim: Optional[int] = None,
+         conv_in_kernel: int = 3,
+         conv_out_kernel: int = 3,
+         projection_class_embeddings_input_dim: Optional[int] = None,
+         attention_type: str = "default",
+         class_embeddings_concat: bool = False,
+         mid_block_only_cross_attention: Optional[bool] = None,
+         cross_attention_norm: Optional[str] = None,
+         addition_embed_type_num_heads: int = 64,
+
+         init_self_attn_ref: bool = False,
+         self_attn_ref_other_model_name: str = 'lambdalabs/sd-image-variations-diffusers',
+         self_attn_ref_position: str = "attn1",
+         self_attn_ref_pixel_wise_crosspond: bool = False,
+         self_attn_ref_effect_on: str = "all",
+         self_attn_ref_chain_pos: str = "parralle",
+         use_simple3d_attn: bool = False,
+         **kwargs
+     ):
+         super().__init__(**{
+             k: v for k, v in locals().items() if k not in
+             ["self", "kwargs", "__class__",
+                 "init_self_attn_ref", "self_attn_ref_other_model_name", "self_attn_ref_position", "self_attn_ref_pixel_wise_crosspond",
+                 "self_attn_ref_effect_on", "self_attn_ref_chain_pos", "use_simple3d_attn"
+             ]
+         })
+
+         self.ref_unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
+             self_attn_ref_other_model_name, subfolder="unet", torch_dtype=self.dtype
+         )
+         add_extra_processor(
+             model=self.ref_unet,
+             enable_filter=lambda name: name.endswith(f"{self_attn_ref_position}.processor"),
+             mode='extract',
+             with_proj_in=False,
+             pixel_wise_crosspond=False,
+         )
+         add_extra_processor(
+             model=self,
+             enable_filter=lambda name: name.endswith(f"{self_attn_ref_position}.processor"),
+             mode='inject',
+             with_proj_in=False,
+             pixel_wise_crosspond=self_attn_ref_pixel_wise_crosspond,
+             crosspond_effect_on=self_attn_ref_effect_on,
+             crosspond_chain_pos=self_attn_ref_chain_pos,
+             simple_3d=use_simple3d_attn,
+         )
+         switch_extra_processor(self, enable_filter=lambda name: name.endswith(f"{self_attn_ref_position}.processor"))
+
+     def __call__(
+         self,
+         sample: torch.Tensor,
+         timestep: Union[torch.Tensor, float, int],
+         encoder_hidden_states: torch.Tensor,
+         condition_latens: torch.Tensor = None,
+         class_labels: Optional[torch.Tensor] = None,
+     ) -> Union[UNet2DConditionOutput, Tuple]:
+
+         ref_dict = {}
+         self.ref_unet(condition_latens, timestep, encoder_hidden_states, cross_attention_kwargs=dict(ref_dict=ref_dict))
+         return self.forward(
+             sample, timestep, encoder_hidden_states,
+             class_labels=class_labels,
+             cross_attention_kwargs=dict(ref_dict=ref_dict, mode='inject'),
+         )
+
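
The wrapper implements reference-conditioned attention: a second UNet (ref_unet, loaded from lambdalabs/sd-image-variations-diffusers) runs on the condition latents with its attn1 processors in 'extract' mode and stashes their hidden states into ref_dict; the multiview UNet then runs in 'inject' mode, where ExtraAttnProc either concatenates the reference states into the attention context or, with pixel_wise_crosspond, adds a per-token cross attention placed 'before', 'parralle', or 'after' the chained processor. A hedged standalone sketch of the extract half on a stock UNet, illustrating how ref_dict gets filled (shapes and the downloaded model are illustrative only, and mv_unet must be importable from unet/):

    import torch
    from diffusers import UNet2DConditionModel
    from mv_unet import add_extra_processor

    # Wrap only the self-attention (attn1) processors in 'extract' mode, as the wrapper
    # does for its ref_unet; all other attention modules get disabled processors.
    unet = UNet2DConditionModel.from_pretrained(
        "lambdalabs/sd-image-variations-diffusers", subfolder="unet"
    )
    add_extra_processor(
        model=unet,
        enable_filter=lambda name: name.endswith("attn1.processor"),
        mode="extract",
    )

    ref_dict = {}
    sample = torch.randn(1, 4, 64, 64)   # one noisy latent
    context = torch.randn(1, 1, 768)     # CLIP image embedding as cross-attention context
    with torch.no_grad():
        unet(sample, torch.tensor(999), context, cross_attention_kwargs=dict(ref_dict=ref_dict))
    print(f"captured {len(ref_dict)} attn1 activations")  # later consumed in 'inject' mode
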
vae/config.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "_class_name": "AutoencoderKL",
+   "_diffusers_version": "0.27.2",
+   "_name_or_path": "lambdalabs/sd-image-variations-diffusers",
+   "act_fn": "silu",
+   "block_out_channels": [
+     128,
+     256,
+     512,
+     512
+   ],
+   "down_block_types": [
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D",
+     "DownEncoderBlock2D"
+   ],
+   "force_upcast": true,
+   "in_channels": 3,
+   "latent_channels": 4,
+   "latents_mean": null,
+   "latents_std": null,
+   "layers_per_block": 2,
+   "norm_num_groups": 32,
+   "out_channels": 3,
+   "sample_size": 256,
+   "scaling_factor": 0.18215,
+   "up_block_types": [
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D",
+     "UpDecoderBlock2D"
+   ]
+ }
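
This is the unmodified SD 1.x VAE from lambdalabs/sd-image-variations-diffusers: 8x spatial downsampling into 4 latent channels with the usual 0.18215 scaling factor. A minimal encoding sketch (placeholder repo id):

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("your-namespace/your-repo", subfolder="vae")
    image = torch.rand(1, 3, 256, 256) * 2 - 1   # RGB image scaled to [-1, 1]
    with torch.no_grad():
        latents = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor
    # latents: (1, 4, 32, 32), the 8x-downsampled representation the UNet denoises
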
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8d0c34f57abe50f323040f2366c8e22b941068dcdf53c8eb1d6fafb838afecb7
+ size 167335590