MKFMIKU commited on
Commit
5239732
1 Parent(s): 6440329

Upload 5 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ checkpoint_72001/checkpoint filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import jax
3
+ import numpy as np
4
+ import jax.numpy as jnp
5
+ from flax.training import checkpoints
6
+ from diffusers import FlaxControlNetModel, FlaxUNet2DConditionModel, FlaxAutoencoderKL, FlaxDDIMScheduler
7
+ from codi.controlnet_flax import FlaxControlNetModel
8
+ from codi.pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
9
+ from transformers import CLIPTokenizer, FlaxCLIPTextModel
10
+ from flax.training.common_utils import shard
11
+ from flax.jax_utils import replicate
12
+
13
+
14
+ MODEL_NAME = "CompVis/stable-diffusion-v1-4"
15
+
16
+ unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
17
+ MODEL_NAME,
18
+ subfolder="unet",
19
+ revision="flax",
20
+ dtype=jnp.float32,
21
+ )
22
+ vae, vae_params = FlaxAutoencoderKL.from_pretrained(
23
+ MODEL_NAME,
24
+ subfolder="vae",
25
+ revision="flax",
26
+ dtype=jnp.float32,
27
+ )
28
+ text_encoder = FlaxCLIPTextModel.from_pretrained(
29
+ MODEL_NAME,
30
+ subfolder="text_encoder",
31
+ revision="flax",
32
+ dtype=jnp.float32,
33
+ )
34
+ tokenizer = CLIPTokenizer.from_pretrained(
35
+ MODEL_NAME,
36
+ subfolder="tokenizer",
37
+ revision="flax",
38
+ dtype=jnp.float32,
39
+ )
40
+
41
+ controlnet = FlaxControlNetModel(
42
+ in_channels=unet.config.in_channels,
43
+ down_block_types=unet.config.down_block_types,
44
+ only_cross_attention=unet.config.only_cross_attention,
45
+ block_out_channels=unet.config.block_out_channels,
46
+ layers_per_block=unet.config.layers_per_block,
47
+ attention_head_dim=unet.config.attention_head_dim,
48
+ cross_attention_dim=unet.config.cross_attention_dim,
49
+ use_linear_projection=unet.config.use_linear_projection,
50
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
51
+ freq_shift=unet.config.freq_shift,
52
+ )
53
+ scheduler = FlaxDDIMScheduler(
54
+ num_train_timesteps=1000,
55
+ beta_start=0.00085,
56
+ beta_end=0.012,
57
+ beta_schedule="scaled_linear",
58
+ trained_betas=None,
59
+ set_alpha_to_one=True,
60
+ steps_offset=0,
61
+ )
62
+ scheduler_state = scheduler.create_state()
63
+
64
+ pipeline = FlaxStableDiffusionControlNetPipeline(
65
+ vae,
66
+ text_encoder,
67
+ tokenizer,
68
+ unet,
69
+ controlnet,
70
+ scheduler,
71
+ None,
72
+ None,
73
+ dtype=jnp.float32,
74
+ )
75
+ controlnet_params = checkpoints.restore_checkpoint("experiments/checkpoint_72001", target=None)
76
+
77
+ pipeline_params = {
78
+ "vae": vae_params,
79
+ "unet": unet_params,
80
+ "text_encoder": text_encoder.params,
81
+ "scheduler": scheduler_state,
82
+ "controlnet": controlnet_params,
83
+ }
84
+ pipeline_params = replicate(pipeline_params)
85
+
86
+ def infer(seed, prompt, negative_prompt, steps, cfgr):
87
+ rng = jax.random.PRNGKey(int(seed))
88
+
89
+ num_samples = jax.device_count()
90
+ rng = jax.random.split(rng, num_samples)
91
+
92
+ prompt_ids = pipeline.prepare_text_inputs([prompt] * num_samples)
93
+ negative_prompt_ids = pipeline.prepare_text_inputs([negative_prompt] * num_samples)
94
+
95
+ prompt_ids = shard(prompt_ids)
96
+ negative_prompt_ids = shard(negative_prompt_ids)
97
+
98
+ output = pipeline(
99
+ prompt_ids=prompt_ids,
100
+ image=None,
101
+ params=pipeline_params,
102
+ prng_seed=rng,
103
+ num_inference_steps=int(steps),
104
+ guidance_scale=float(cfgr),
105
+ neg_prompt_ids=negative_prompt_ids,
106
+ jit=True,
107
+ ).images
108
+
109
+ output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
110
+ return output_images
111
+
112
+ with gr.Blocks(theme='gradio/soft') as demo:
113
+ gr.Markdown("## Parameter-efficient text-to-image distillation")
114
+ gr.Markdown("[\[Paper\]](https://arxiv.org/abs/2310.01407) [\[Project Page\]](https://fast-codi.github.io)")
115
+
116
+ with gr.Tab("CoDi on Text-to-Image"):
117
+
118
+ with gr.Row():
119
+ with gr.Column():
120
+ prompt_input = gr.Textbox(label="Prompt")
121
+ negative_prompt = gr.Textbox(label="Negative Prompt", value="monochrome, lowres, bad anatomy, worst quality, low quality")
122
+ seed = gr.Number(label="Seed", value=0)
123
+ output = gr.Gallery(label="Output Images")
124
+
125
+ with gr.Row():
126
+ num_inference_steps = gr.Slider(2, 50, value=4, step=1, label="Steps")
127
+ guidance_scale = gr.Slider(2.0, 14.0, value=7.5, step=0.5, label='Guidance Scale')
128
+ submit_btn = gr.Button(value = "Submit")
129
+ inputs = [
130
+ seed,
131
+ prompt_input,
132
+ negative_prompt,
133
+ num_inference_steps,
134
+ guidance_scale
135
+ ]
136
+ submit_btn.click(fn=infer, inputs=inputs, outputs=[output])
137
+
138
+ with gr.Row():
139
+ gr.Examples(
140
+ examples=["oranges", "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"],
141
+ inputs=prompt_input,
142
+ fn=infer
143
+ )
144
+
145
+ demo.launch(max_threads=1, share=True)
checkpoint_72001/_METADATA ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint_72001/checkpoint ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5bf57d86d67db04cfa568f7c0399ea91d0ee5a9a4d35bda9d07dae370b04b89
3
+ size 1445128798
codi/controlnet_flax.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional, Tuple, Union
15
+
16
+ import flax
17
+ import flax.linen as nn
18
+ import jax
19
+ import jax.numpy as jnp
20
+ from flax.core.frozen_dict import FrozenDict
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, flax_register_to_config
23
+ from diffusers.utils import BaseOutput
24
+ from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
25
+ from diffusers.models.modeling_flax_utils import FlaxModelMixin
26
+ from diffusers.models.unets.unet_2d_blocks_flax import (
27
+ FlaxCrossAttnDownBlock2D,
28
+ FlaxDownBlock2D,
29
+ FlaxUNetMidBlock2DCrossAttn,
30
+ )
31
+
32
+
33
+ @flax.struct.dataclass
34
+ class FlaxControlNetOutput(BaseOutput):
35
+ """
36
+ The output of [`FlaxControlNetModel`].
37
+
38
+ Args:
39
+ down_block_res_samples (`jnp.ndarray`):
40
+ mid_block_res_sample (`jnp.ndarray`):
41
+ """
42
+
43
+ down_block_res_samples: jnp.ndarray
44
+ mid_block_res_sample: jnp.ndarray
45
+
46
+
47
+ class FlaxControlNetConditioningEmbedding(nn.Module):
48
+ conditioning_embedding_channels: int
49
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
50
+ dtype: jnp.dtype = jnp.float32
51
+
52
+ def setup(self) -> None:
53
+ self.conv_in = nn.Conv(
54
+ self.block_out_channels[0],
55
+ kernel_size=(3, 3),
56
+ padding=((1, 1), (1, 1)),
57
+ dtype=self.dtype,
58
+ )
59
+
60
+ blocks = []
61
+ for i in range(len(self.block_out_channels) - 1):
62
+ channel_in = self.block_out_channels[i]
63
+ channel_out = self.block_out_channels[i + 1]
64
+ conv1 = nn.Conv(
65
+ channel_in,
66
+ kernel_size=(3, 3),
67
+ padding=((1, 1), (1, 1)),
68
+ dtype=self.dtype,
69
+ )
70
+ blocks.append(conv1)
71
+ conv2 = nn.Conv(
72
+ channel_out,
73
+ kernel_size=(3, 3),
74
+ strides=(2, 2),
75
+ padding=((1, 1), (1, 1)),
76
+ dtype=self.dtype,
77
+ )
78
+ blocks.append(conv2)
79
+ self.blocks = blocks
80
+
81
+ self.conv_out = nn.Conv(
82
+ self.conditioning_embedding_channels,
83
+ kernel_size=(3, 3),
84
+ padding=((1, 1), (1, 1)),
85
+ kernel_init=nn.initializers.zeros_init(),
86
+ bias_init=nn.initializers.zeros_init(),
87
+ dtype=self.dtype,
88
+ )
89
+
90
+ def __call__(self, conditioning: jnp.ndarray) -> jnp.ndarray:
91
+ embedding = self.conv_in(conditioning)
92
+ embedding = nn.silu(embedding)
93
+
94
+ for block in self.blocks:
95
+ embedding = block(embedding)
96
+ embedding = nn.silu(embedding)
97
+
98
+ embedding = self.conv_out(embedding)
99
+
100
+ return embedding
101
+
102
+
103
+ @flax_register_to_config
104
+ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
105
+ r"""
106
+ A ControlNet model.
107
+
108
+ This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it’s generic methods
109
+ implemented for all models (such as downloading or saving).
110
+
111
+ This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
112
+ subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
113
+ general usage and behavior.
114
+
115
+ Inherent JAX features such as the following are supported:
116
+
117
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
118
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
119
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
120
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
121
+
122
+ Parameters:
123
+ sample_size (`int`, *optional*):
124
+ The size of the input sample.
125
+ in_channels (`int`, *optional*, defaults to 4):
126
+ The number of channels in the input sample.
127
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
128
+ The tuple of downsample blocks to use.
129
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
130
+ The tuple of output channels for each block.
131
+ layers_per_block (`int`, *optional*, defaults to 2):
132
+ The number of layers per block.
133
+ attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
134
+ The dimension of the attention heads.
135
+ num_attention_heads (`int` or `Tuple[int]`, *optional*):
136
+ The number of attention heads.
137
+ cross_attention_dim (`int`, *optional*, defaults to 768):
138
+ The dimension of the cross attention features.
139
+ dropout (`float`, *optional*, defaults to 0):
140
+ Dropout probability for down, up and bottleneck blocks.
141
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
142
+ Whether to flip the sin to cos in the time embedding.
143
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
144
+ controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
145
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
146
+ conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
147
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
148
+ """
149
+
150
+ sample_size: int = 32
151
+ in_channels: int = 4
152
+ down_block_types: Tuple[str, ...] = (
153
+ "CrossAttnDownBlock2D",
154
+ "CrossAttnDownBlock2D",
155
+ "CrossAttnDownBlock2D",
156
+ "DownBlock2D",
157
+ )
158
+ only_cross_attention: Union[bool, Tuple[bool, ...]] = False
159
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
160
+ layers_per_block: int = 2
161
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8
162
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
163
+ cross_attention_dim: int = 1280
164
+ dropout: float = 0.0
165
+ use_linear_projection: bool = False
166
+ dtype: jnp.dtype = jnp.float32
167
+ flip_sin_to_cos: bool = True
168
+ freq_shift: int = 0
169
+ controlnet_conditioning_channel_order: str = "rgb"
170
+ conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
171
+
172
+ def init_weights(self, rng: jax.Array) -> FrozenDict:
173
+ # init input tensors
174
+ sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
175
+ sample = jnp.zeros(sample_shape, dtype=jnp.float32)
176
+ timesteps = jnp.ones((1,), dtype=jnp.int32)
177
+ encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
178
+ controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
179
+ controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)
180
+
181
+ params_rng, dropout_rng = jax.random.split(rng)
182
+ rngs = {"params": params_rng, "dropout": dropout_rng}
183
+
184
+ return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
185
+
186
+ def setup(self) -> None:
187
+ block_out_channels = self.block_out_channels
188
+ time_embed_dim = block_out_channels[0] * 4
189
+
190
+ # If `num_attention_heads` is not defined (which is the case for most models)
191
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
192
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
193
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
194
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
195
+ # which is why we correct for the naming here.
196
+ num_attention_heads = self.num_attention_heads or self.attention_head_dim
197
+
198
+ # input
199
+ self.conv_in = nn.Conv(
200
+ block_out_channels[0],
201
+ kernel_size=(3, 3),
202
+ strides=(1, 1),
203
+ padding=((1, 1), (1, 1)),
204
+ dtype=self.dtype,
205
+ )
206
+
207
+ # time
208
+ self.time_proj = FlaxTimesteps(
209
+ block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
210
+ )
211
+ self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
212
+
213
+ self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
214
+ conditioning_embedding_channels=block_out_channels[0],
215
+ block_out_channels=self.conditioning_embedding_out_channels,
216
+ )
217
+
218
+ only_cross_attention = self.only_cross_attention
219
+ if isinstance(only_cross_attention, bool):
220
+ only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
221
+
222
+ if isinstance(num_attention_heads, int):
223
+ num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
224
+
225
+ # down
226
+ down_blocks = []
227
+ controlnet_down_blocks = []
228
+
229
+ output_channel = block_out_channels[0]
230
+
231
+ controlnet_block = nn.Conv(
232
+ output_channel,
233
+ kernel_size=(1, 1),
234
+ padding="VALID",
235
+ kernel_init=nn.initializers.zeros_init(),
236
+ bias_init=nn.initializers.zeros_init(),
237
+ dtype=self.dtype,
238
+ )
239
+ controlnet_down_blocks.append(controlnet_block)
240
+
241
+ for i, down_block_type in enumerate(self.down_block_types):
242
+ input_channel = output_channel
243
+ output_channel = block_out_channels[i]
244
+ is_final_block = i == len(block_out_channels) - 1
245
+
246
+ if down_block_type == "CrossAttnDownBlock2D":
247
+ down_block = FlaxCrossAttnDownBlock2D(
248
+ in_channels=input_channel,
249
+ out_channels=output_channel,
250
+ dropout=self.dropout,
251
+ num_layers=self.layers_per_block,
252
+ num_attention_heads=num_attention_heads[i],
253
+ add_downsample=not is_final_block,
254
+ use_linear_projection=self.use_linear_projection,
255
+ only_cross_attention=only_cross_attention[i],
256
+ dtype=self.dtype,
257
+ )
258
+ else:
259
+ down_block = FlaxDownBlock2D(
260
+ in_channels=input_channel,
261
+ out_channels=output_channel,
262
+ dropout=self.dropout,
263
+ num_layers=self.layers_per_block,
264
+ add_downsample=not is_final_block,
265
+ dtype=self.dtype,
266
+ )
267
+
268
+ down_blocks.append(down_block)
269
+
270
+ for _ in range(self.layers_per_block):
271
+ controlnet_block = nn.Conv(
272
+ output_channel,
273
+ kernel_size=(1, 1),
274
+ padding="VALID",
275
+ kernel_init=nn.initializers.zeros_init(),
276
+ bias_init=nn.initializers.zeros_init(),
277
+ dtype=self.dtype,
278
+ )
279
+ controlnet_down_blocks.append(controlnet_block)
280
+
281
+ if not is_final_block:
282
+ controlnet_block = nn.Conv(
283
+ output_channel,
284
+ kernel_size=(1, 1),
285
+ padding="VALID",
286
+ kernel_init=nn.initializers.zeros_init(),
287
+ bias_init=nn.initializers.zeros_init(),
288
+ dtype=self.dtype,
289
+ )
290
+ controlnet_down_blocks.append(controlnet_block)
291
+
292
+ self.down_blocks = down_blocks
293
+ self.controlnet_down_blocks = controlnet_down_blocks
294
+
295
+ # mid
296
+ mid_block_channel = block_out_channels[-1]
297
+ self.mid_block = FlaxUNetMidBlock2DCrossAttn(
298
+ in_channels=mid_block_channel,
299
+ dropout=self.dropout,
300
+ num_attention_heads=num_attention_heads[-1],
301
+ use_linear_projection=self.use_linear_projection,
302
+ dtype=self.dtype,
303
+ )
304
+
305
+ self.controlnet_mid_block = nn.Conv(
306
+ mid_block_channel,
307
+ kernel_size=(1, 1),
308
+ padding="VALID",
309
+ kernel_init=nn.initializers.zeros_init(),
310
+ bias_init=nn.initializers.zeros_init(),
311
+ dtype=self.dtype,
312
+ )
313
+
314
+ def __call__(
315
+ self,
316
+ sample: jnp.ndarray,
317
+ timesteps: Union[jnp.ndarray, float, int],
318
+ encoder_hidden_states: jnp.ndarray,
319
+ controlnet_cond: jnp.ndarray,
320
+ conditioning_scale: float = 1.0,
321
+ return_dict: bool = True,
322
+ train: bool = False,
323
+ ) -> Union[FlaxControlNetOutput, Tuple[Tuple[jnp.ndarray, ...], jnp.ndarray]]:
324
+ r"""
325
+ Args:
326
+ sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
327
+ timestep (`jnp.ndarray` or `float` or `int`): timesteps
328
+ encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
329
+ controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
330
+ conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs
331
+ return_dict (`bool`, *optional*, defaults to `True`):
332
+ Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
333
+ plain tuple.
334
+ train (`bool`, *optional*, defaults to `False`):
335
+ Use deterministic functions and disable dropout when not training.
336
+
337
+ Returns:
338
+ [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
339
+ [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
340
+ `tuple`. When returning a tuple, the first element is the sample tensor.
341
+ """
342
+ channel_order = self.controlnet_conditioning_channel_order
343
+
344
+ # 1. time
345
+ if not isinstance(timesteps, jnp.ndarray):
346
+ timesteps = jnp.array([timesteps], dtype=jnp.int32)
347
+ elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
348
+ timesteps = timesteps.astype(dtype=jnp.float32)
349
+ timesteps = jnp.expand_dims(timesteps, 0)
350
+
351
+ t_emb = self.time_proj(timesteps)
352
+ t_emb = self.time_embedding(t_emb)
353
+
354
+ # 2. pre-process
355
+ sample = jnp.transpose(sample, (0, 2, 3, 1))
356
+ sample = self.conv_in(sample)
357
+
358
+ if controlnet_cond is not None:
359
+ if channel_order == "bgr":
360
+ controlnet_cond = jnp.flip(controlnet_cond, axis=1)
361
+ controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
362
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
363
+ sample += controlnet_cond
364
+
365
+ # 3. down
366
+ down_block_res_samples = (sample,)
367
+ for down_block in self.down_blocks:
368
+ if isinstance(down_block, FlaxCrossAttnDownBlock2D):
369
+ sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
370
+ else:
371
+ sample, res_samples = down_block(sample, t_emb, deterministic=not train)
372
+ down_block_res_samples += res_samples
373
+
374
+ # 4. mid
375
+ sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
376
+
377
+ # 5. contronet blocks
378
+ controlnet_down_block_res_samples = ()
379
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
380
+ down_block_res_sample = controlnet_block(down_block_res_sample)
381
+ controlnet_down_block_res_samples += (down_block_res_sample,)
382
+
383
+ down_block_res_samples = controlnet_down_block_res_samples
384
+
385
+ mid_block_res_sample = self.controlnet_mid_block(sample)
386
+
387
+ # 6. scaling
388
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
389
+ mid_block_res_sample *= conditioning_scale
390
+
391
+ if not return_dict:
392
+ return (down_block_res_samples, mid_block_res_sample)
393
+
394
+ return FlaxControlNetOutput(
395
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
396
+ )
codi/pipeline_flax_controlnet.py ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+ from functools import partial
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ import jax
20
+ import jax.numpy as jnp
21
+ import numpy as np
22
+ from flax.core.frozen_dict import FrozenDict
23
+ from flax.jax_utils import unreplicate
24
+ from flax.training.common_utils import shard
25
+ from PIL import Image
26
+ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
27
+
28
+ from diffusers.models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
29
+ from diffusers.schedulers import (
30
+ FlaxDDIMScheduler,
31
+ FlaxDPMSolverMultistepScheduler,
32
+ FlaxLMSDiscreteScheduler,
33
+ FlaxPNDMScheduler,
34
+ )
35
+ from diffusers.utils import PIL_INTERPOLATION, logging, replace_example_docstring
36
+ from diffusers.pipelines.pipeline_flax_utils import FlaxDiffusionPipeline
37
+ from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionPipelineOutput
38
+ from diffusers.pipelines.stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker
39
+ from diffusers.schedulers.scheduling_utils_flax import get_sqrt_alpha_prod
40
+ from diffusers.schedulers.scheduling_utils_flax import broadcast_to_shape_from_left
41
+
42
+ from codi.controlnet_flax import FlaxControlNetModel
43
+
44
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
+
46
+ # Set to True to use python for loop instead of jax.fori_loop for easier debugging
47
+ DEBUG = False
48
+
49
+ EXAMPLE_DOC_STRING = """
50
+ Examples:
51
+ ```py
52
+ >>> import jax
53
+ >>> import numpy as np
54
+ >>> import jax.numpy as jnp
55
+ >>> from flax.jax_utils import replicate
56
+ >>> from flax.training.common_utils import shard
57
+ >>> from diffusers.utils import load_image, make_image_grid
58
+ >>> from PIL import Image
59
+ >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
60
+
61
+
62
+ >>> def create_key(seed=0):
63
+ ... return jax.random.PRNGKey(seed)
64
+
65
+
66
+ >>> rng = create_key(0)
67
+
68
+ >>> # get canny image
69
+ >>> canny_image = load_image(
70
+ ... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"
71
+ ... )
72
+
73
+ >>> prompts = "best quality, extremely detailed"
74
+ >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality"
75
+
76
+ >>> # load control net and stable diffusion v1-5
77
+ >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
78
+ ... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
79
+ ... )
80
+ >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
81
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
82
+ ... )
83
+ >>> params["controlnet"] = controlnet_params
84
+
85
+ >>> num_samples = jax.device_count()
86
+ >>> rng = jax.random.split(rng, jax.device_count())
87
+
88
+ >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
89
+ >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
90
+ >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
91
+
92
+ >>> p_params = replicate(params)
93
+ >>> prompt_ids = shard(prompt_ids)
94
+ >>> negative_prompt_ids = shard(negative_prompt_ids)
95
+ >>> processed_image = shard(processed_image)
96
+
97
+ >>> output = pipe(
98
+ ... prompt_ids=prompt_ids,
99
+ ... image=processed_image,
100
+ ... params=p_params,
101
+ ... prng_seed=rng,
102
+ ... num_inference_steps=50,
103
+ ... neg_prompt_ids=negative_prompt_ids,
104
+ ... jit=True,
105
+ ... ).images
106
+
107
+ >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
108
+ >>> output_images = make_image_grid(output_images, num_samples // 4, 4)
109
+ >>> output_images.save("generated_image.png")
110
+ ```
111
+ """
112
+
113
+ def scalings_for_boundary_conditions(
114
+ timestep, sigma_data=0.5, timestep_scaling=10.0
115
+ ):
116
+ scaled_timestep = timestep * timestep_scaling
117
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
118
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
119
+ return c_skip, c_out
120
+
121
+ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
122
+ r"""
123
+ Flax-based pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance.
124
+
125
+ This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods
126
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
127
+
128
+ Args:
129
+ vae ([`FlaxAutoencoderKL`]):
130
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
131
+ text_encoder ([`~transformers.FlaxCLIPTextModel`]):
132
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
133
+ tokenizer ([`~transformers.CLIPTokenizer`]):
134
+ A `CLIPTokenizer` to tokenize text.
135
+ unet ([`FlaxUNet2DConditionModel`]):
136
+ A `FlaxUNet2DConditionModel` to denoise the encoded image latents.
137
+ controlnet ([`FlaxControlNetModel`]:
138
+ Provides additional conditioning to the `unet` during the denoising process.
139
+ scheduler ([`SchedulerMixin`]):
140
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
141
+ [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
142
+ [`FlaxDPMSolverMultistepScheduler`].
143
+ safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
144
+ Classification module that estimates whether generated images could be considered offensive or harmful.
145
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
146
+ about a model's potential harms.
147
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
148
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
149
+ """
150
+
151
+ def __init__(
152
+ self,
153
+ vae: FlaxAutoencoderKL,
154
+ text_encoder: FlaxCLIPTextModel,
155
+ tokenizer: CLIPTokenizer,
156
+ unet: FlaxUNet2DConditionModel,
157
+ controlnet: FlaxControlNetModel,
158
+ scheduler: Union[
159
+ FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
160
+ ],
161
+ safety_checker: FlaxStableDiffusionSafetyChecker,
162
+ feature_extractor: CLIPFeatureExtractor,
163
+ dtype: jnp.dtype = jnp.float32,
164
+ ):
165
+ super().__init__()
166
+ self.dtype = dtype
167
+
168
+ if safety_checker is None:
169
+ logger.warn(
170
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
171
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
172
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
173
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
174
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
175
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
176
+ )
177
+
178
+ self.register_modules(
179
+ vae=vae,
180
+ text_encoder=text_encoder,
181
+ tokenizer=tokenizer,
182
+ unet=unet,
183
+ controlnet=controlnet,
184
+ scheduler=scheduler,
185
+ safety_checker=safety_checker,
186
+ feature_extractor=feature_extractor,
187
+ )
188
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
189
+
190
+ def prepare_text_inputs(self, prompt: Union[str, List[str]]):
191
+ if not isinstance(prompt, (str, list)):
192
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
193
+
194
+ text_input = self.tokenizer(
195
+ prompt,
196
+ padding="max_length",
197
+ max_length=self.tokenizer.model_max_length,
198
+ truncation=True,
199
+ return_tensors="np",
200
+ )
201
+
202
+ return text_input.input_ids
203
+
204
+ def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]):
205
+ if not isinstance(image, (Image.Image, list)):
206
+ raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")
207
+
208
+ if isinstance(image, Image.Image):
209
+ image = [image]
210
+
211
+ processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])
212
+
213
+ return processed_images
214
+
215
+ def _get_has_nsfw_concepts(self, features, params):
216
+ has_nsfw_concepts = self.safety_checker(features, params)
217
+ return has_nsfw_concepts
218
+
219
+ def _run_safety_checker(self, images, safety_model_params, jit=False):
220
+ # safety_model_params should already be replicated when jit is True
221
+ pil_images = [Image.fromarray(image) for image in images]
222
+ features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
223
+
224
+ if jit:
225
+ features = shard(features)
226
+ has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
227
+ has_nsfw_concepts = unshard(has_nsfw_concepts)
228
+ safety_model_params = unreplicate(safety_model_params)
229
+ else:
230
+ has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
231
+
232
+ images_was_copied = False
233
+ for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
234
+ if has_nsfw_concept:
235
+ if not images_was_copied:
236
+ images_was_copied = True
237
+ images = images.copy()
238
+
239
+ images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image
240
+
241
+ if any(has_nsfw_concepts):
242
+ warnings.warn(
243
+ "Potential NSFW content was detected in one or more images. A black image will be returned"
244
+ " instead. Try again with a different prompt and/or seed."
245
+ )
246
+
247
+ return images, has_nsfw_concepts
248
+
249
+ def _generate(
250
+ self,
251
+ prompt_ids: jnp.ndarray,
252
+ image: jnp.ndarray,
253
+ params: Union[Dict, FrozenDict],
254
+ prng_seed: jax.Array,
255
+ num_inference_steps: int,
256
+ guidance_scale: float,
257
+ latents: Optional[jnp.ndarray] = None,
258
+ neg_prompt_ids: Optional[jnp.ndarray] = None,
259
+ controlnet_conditioning_scale: float = 1.0,
260
+ height: int = 512,
261
+ width: int = 512,
262
+ distill_timestep_scaling: int = 10,
263
+ distill_learning_steps: int = 50,
264
+ onestepode_sample_eps: str = "nprediction"
265
+ ):
266
+ if image is not None:
267
+ height, width = image.shape[-2:]
268
+
269
+ if height % 64 != 0 or width % 64 != 0:
270
+ raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
271
+
272
+ # get prompt text embeddings
273
+ prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
274
+
275
+ # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
276
+ # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
277
+ batch_size = prompt_ids.shape[0]
278
+
279
+ max_length = prompt_ids.shape[-1]
280
+
281
+ if neg_prompt_ids is None:
282
+ uncond_input = self.tokenizer(
283
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
284
+ ).input_ids
285
+ else:
286
+ uncond_input = neg_prompt_ids
287
+ negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
288
+ context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
289
+
290
+ if image is not None:
291
+ image = jnp.concatenate([image] * 2)
292
+
293
+ latents_shape = (
294
+ batch_size,
295
+ self.unet.config.in_channels,
296
+ height // self.vae_scale_factor,
297
+ width // self.vae_scale_factor,
298
+ )
299
+ if latents is None:
300
+ latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
301
+ else:
302
+ if latents.shape != latents_shape:
303
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
304
+
305
+ def loop_body(step, args):
306
+ latents, scheduler_state = args
307
+ latents_input = jnp.concatenate([latents] * 2)
308
+
309
+ t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
310
+ timestep = jnp.broadcast_to(t, latents.shape[0])
311
+ next_t = jnp.where(step < num_inference_steps -1, jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step + 1], 0)
312
+ next_timestep = jnp.broadcast_to(next_t, latents.shape[0])
313
+
314
+ c_skip, c_out = scalings_for_boundary_conditions(
315
+ timestep, timestep_scaling=distill_timestep_scaling,
316
+ )
317
+ alpha_t, sigma_t = get_sqrt_alpha_prod(
318
+ scheduler_state.common,
319
+ latents, # only used for determining shape
320
+ latents, # unused code
321
+ timestep,
322
+ )
323
+ alpha_s, sigma_s = get_sqrt_alpha_prod(
324
+ scheduler_state.common,
325
+ latents, # only used for determining shape
326
+ latents, # unused code
327
+ next_timestep,
328
+ )
329
+
330
+ # jax.debug.print("timestep {}", timestep)
331
+ # jax.debug.print("next_timestep {}", next_timestep)
332
+ # jax.debug.print("c_skip {}", c_skip)
333
+ # jax.debug.print("c_out {}", c_out)
334
+ # jax.debug.print("alpha_s {}", alpha_s.mean())
335
+ # jax.debug.print("sigma_s {}", sigma_s.mean())
336
+
337
+ c_skip = broadcast_to_shape_from_left(c_skip, latents.shape)
338
+ c_out = broadcast_to_shape_from_left(c_out, latents.shape)
339
+
340
+ latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
341
+
342
+ down_block_res_samples, mid_block_res_sample = self.controlnet.apply(
343
+ {"params": params["controlnet"]},
344
+ jnp.array(latents_input),
345
+ jnp.array(jnp.concatenate([timestep] * 2, axis=0), dtype=jnp.int32),
346
+ encoder_hidden_states=context,
347
+ controlnet_cond=image,
348
+ conditioning_scale=controlnet_conditioning_scale,
349
+ return_dict=False,
350
+ )
351
+
352
+ # predict the noise residual
353
+ model_pred = self.unet.apply(
354
+ {"params": params["unet"]},
355
+ jnp.array(latents_input),
356
+ jnp.array(timestep, dtype=jnp.int32),
357
+ encoder_hidden_states=context,
358
+ down_block_additional_residuals=down_block_res_samples,
359
+ mid_block_additional_residual=mid_block_res_sample,
360
+ ).sample
361
+
362
+ # perform guidance
363
+ mode_pred_uncond, model_prediction_text = jnp.split(model_pred, 2, axis=0)
364
+ model_pred = mode_pred_uncond + guidance_scale * (model_prediction_text - mode_pred_uncond)
365
+
366
+ if onestepode_sample_eps == 'nprediction':
367
+ target_model_pred_x = (latents - sigma_t * model_pred ) / alpha_t
368
+ target_model_pred_epsilon = model_pred
369
+ elif args.onestepode_sample_eps == 'vprediction':
370
+ target_model_pred_epsilon = (
371
+ alpha_t * model_pred + sigma_t * latents_input
372
+ )
373
+ target_model_pred_x = (
374
+ alpha_t * latents - sigma_t * model_pred
375
+ )
376
+ elif args.onestepode_sample_eps == 'xprediction':
377
+ target_model_pred_x = model_pred
378
+ target_model_pred_epsilon = (latents - alpha_t * model_pred) / sigma_t
379
+ else:
380
+ raise NotImplementedError
381
+
382
+ target_model_pred_x = (
383
+ c_skip * latents + c_out * target_model_pred_x
384
+ )
385
+
386
+ latents = alpha_s * target_model_pred_x + sigma_s * target_model_pred_epsilon
387
+ return latents, scheduler_state
388
+
389
+ scheduler_state = params["scheduler"]
390
+ skipped_schedule = self.scheduler.num_train_timesteps // distill_learning_steps
391
+ timesteps = (jnp.arange(0, distill_learning_steps) * skipped_schedule).round()[::-1]
392
+ step_ratio = (distill_learning_steps + num_inference_steps - 1) // num_inference_steps
393
+ timesteps = timesteps[::step_ratio]
394
+ scheduler_state = scheduler_state.replace(
395
+ num_inference_steps=num_inference_steps, timesteps=timesteps
396
+ )
397
+
398
+ # scale the initial noise by the standard deviation required by the scheduler
399
+ latents = latents * params["scheduler"].init_noise_sigma
400
+
401
+ if DEBUG:
402
+ # run with python for loop
403
+ for i in range(num_inference_steps):
404
+ latents, scheduler_state = loop_body(i, (latents, scheduler_state))
405
+ else:
406
+ latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
407
+
408
+ # scale and decode the image latents with vae
409
+ latents = 1 / self.vae.config.scaling_factor * latents
410
+ image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
411
+
412
+ image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
413
+ return image
414
+
415
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
416
+ def __call__(
417
+ self,
418
+ prompt_ids: jnp.ndarray,
419
+ image: jnp.ndarray,
420
+ params: Union[Dict, FrozenDict],
421
+ prng_seed: jax.Array,
422
+ num_inference_steps: int = 50,
423
+ guidance_scale: Union[float, jnp.ndarray] = 7.5,
424
+ latents: jnp.ndarray = None,
425
+ neg_prompt_ids: jnp.ndarray = None,
426
+ controlnet_conditioning_scale: Union[float, jnp.ndarray] = 1.0,
427
+ return_dict: bool = True,
428
+ jit: bool = False,
429
+ height: int = 512,
430
+ width: int = 512,
431
+ ):
432
+ r"""
433
+ The call function to the pipeline for generation.
434
+
435
+ Args:
436
+ prompt_ids (`jnp.ndarray`):
437
+ The prompt or prompts to guide the image generation.
438
+ image (`jnp.ndarray`):
439
+ Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
440
+ params (`Dict` or `FrozenDict`):
441
+ Dictionary containing the model parameters/weights.
442
+ prng_seed (`jax.Array`):
443
+ Array containing random number generator key.
444
+ num_inference_steps (`int`, *optional*, defaults to 50):
445
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
446
+ expense of slower inference.
447
+ guidance_scale (`float`, *optional*, defaults to 7.5):
448
+ A higher guidance scale value encourages the model to generate images closely linked to the text
449
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
450
+ latents (`jnp.ndarray`, *optional*):
451
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
452
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
453
+ array is generated by sampling using the supplied random `generator`.
454
+ controlnet_conditioning_scale (`float` or `jnp.ndarray`, *optional*, defaults to 1.0):
455
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
456
+ to the residual in the original `unet`.
457
+ return_dict (`bool`, *optional*, defaults to `True`):
458
+ Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
459
+ a plain tuple.
460
+ jit (`bool`, defaults to `False`):
461
+ Whether to run `pmap` versions of the generation and safety scoring functions.
462
+
463
+ <Tip warning={true}>
464
+
465
+ This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
466
+ future release.
467
+
468
+ </Tip>
469
+
470
+ Examples:
471
+
472
+ Returns:
473
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
474
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] is
475
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated images
476
+ and the second element is a list of `bool`s indicating whether the corresponding generated image
477
+ contains "not-safe-for-work" (nsfw) content.
478
+ """
479
+
480
+ if image is not None:
481
+ height, width = image.shape[-2:]
482
+
483
+ if isinstance(guidance_scale, float):
484
+ # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
485
+ # shape information, as they may be sharded (when `jit` is `True`), or not.
486
+ guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
487
+ if len(prompt_ids.shape) > 2:
488
+ # Assume sharded
489
+ guidance_scale = guidance_scale[:, None]
490
+
491
+ if isinstance(controlnet_conditioning_scale, float):
492
+ # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
493
+ # shape information, as they may be sharded (when `jit` is `True`), or not.
494
+ controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0])
495
+ if len(prompt_ids.shape) > 2:
496
+ # Assume sharded
497
+ controlnet_conditioning_scale = controlnet_conditioning_scale[:, None]
498
+
499
+ if jit:
500
+ images = _p_generate(
501
+ self,
502
+ prompt_ids,
503
+ image,
504
+ params,
505
+ prng_seed,
506
+ num_inference_steps,
507
+ guidance_scale,
508
+ latents,
509
+ neg_prompt_ids,
510
+ controlnet_conditioning_scale,
511
+ height,
512
+ width,
513
+ )
514
+ else:
515
+ images = self._generate(
516
+ prompt_ids,
517
+ image,
518
+ params,
519
+ prng_seed,
520
+ num_inference_steps,
521
+ guidance_scale,
522
+ latents,
523
+ neg_prompt_ids,
524
+ controlnet_conditioning_scale,
525
+ height,
526
+ width,
527
+ )
528
+
529
+ if self.safety_checker is not None:
530
+ safety_params = params["safety_checker"]
531
+ images_uint8_casted = (images * 255).round().astype("uint8")
532
+ num_devices, batch_size = images.shape[:2]
533
+
534
+ images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
535
+ images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
536
+ images = np.array(images)
537
+
538
+ # block images
539
+ if any(has_nsfw_concept):
540
+ for i, is_nsfw in enumerate(has_nsfw_concept):
541
+ if is_nsfw:
542
+ images[i] = np.asarray(images_uint8_casted[i])
543
+
544
+ images = images.reshape(num_devices, batch_size, height, width, 3)
545
+ else:
546
+ images = np.asarray(images)
547
+ has_nsfw_concept = False
548
+
549
+ if not return_dict:
550
+ return (images, has_nsfw_concept)
551
+
552
+ return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
553
+
554
+
555
+ # Static argnums are pipe, num_inference_steps. A change would trigger recompilation.
556
+ # Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
557
+ @partial(
558
+ jax.pmap,
559
+ in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0, 0, 0),
560
+ static_broadcasted_argnums=(0, 5, 10, 11),
561
+ )
562
+ def _p_generate(
563
+ pipe,
564
+ prompt_ids,
565
+ image,
566
+ params,
567
+ prng_seed,
568
+ num_inference_steps,
569
+ guidance_scale,
570
+ latents,
571
+ neg_prompt_ids,
572
+ controlnet_conditioning_scale,
573
+ height,
574
+ width,
575
+ ):
576
+ return pipe._generate(
577
+ prompt_ids,
578
+ image,
579
+ params,
580
+ prng_seed,
581
+ num_inference_steps,
582
+ guidance_scale,
583
+ latents,
584
+ neg_prompt_ids,
585
+ controlnet_conditioning_scale,
586
+ height,
587
+ width,
588
+ )
589
+
590
+
591
+ @partial(jax.pmap, static_broadcasted_argnums=(0,))
592
+ def _p_get_has_nsfw_concepts(pipe, features, params):
593
+ return pipe._get_has_nsfw_concepts(features, params)
594
+
595
+
596
+ def unshard(x: jnp.ndarray):
597
+ # einops.rearrange(x, 'd b ... -> (d b) ...')
598
+ num_devices, batch_size = x.shape[:2]
599
+ rest = x.shape[2:]
600
+ return x.reshape(num_devices * batch_size, *rest)
601
+
602
+
603
+ def preprocess(image, dtype):
604
+ image = image.convert("RGB")
605
+ w, h = image.size
606
+ w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64
607
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
608
+ image = jnp.array(image).astype(dtype) / 255.0
609
+ image = image[None].transpose(0, 3, 1, 2)
610
+ return image