import flax.linen as nn |
|
import jax.numpy as jnp |
|
|
|
from ..attention_flax import FlaxTransformer2DModel |
|
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D |
|
|
|
|
|
class FlaxCrossAttnDownBlock2D(nn.Module): |
|
r""" |
|
    Cross Attention 2D Downsampling block - original architecture from Unet transformers:
|
https://arxiv.org/abs/2103.06104 |
|
|
|
Parameters: |
|
in_channels (:obj:`int`): |
|
Input channels |
|
out_channels (:obj:`int`): |
|
Output channels |
|
dropout (:obj:`float`, *optional*, defaults to 0.0): |
|
Dropout rate |
|
num_layers (:obj:`int`, *optional*, defaults to 1): |
|
            Number of attention block layers
|
num_attention_heads (:obj:`int`, *optional*, defaults to 1): |
|
Number of attention heads of each spatial transformer block |
|
add_downsample (:obj:`bool`, *optional*, defaults to `True`): |
|
            Whether to add a downsampling layer at the end of the block
|
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): |
|
            Whether to enable memory-efficient attention (https://arxiv.org/abs/2112.05682)
|
split_head_dim (`bool`, *optional*, defaults to `False`): |
|
Whether to split the head dimension into a new axis for the self-attention computation. In most cases, |
|
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. |
|
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): |
|
            The `dtype` of the parameters
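
    Example (a minimal usage sketch; the channel counts, spatial size, and text-embedding
    dimension below are illustrative assumptions, not required values):

        >>> import jax
        >>> import jax.numpy as jnp

        >>> block = FlaxCrossAttnDownBlock2D(in_channels=320, out_channels=320, num_attention_heads=8)
        >>> sample = jnp.zeros((1, 64, 64, 320))  # NHWC feature map
        >>> temb = jnp.zeros((1, 1280))  # time embedding
        >>> encoder_hidden_states = jnp.zeros((1, 77, 768))  # text-encoder hidden states
        >>> variables = block.init(jax.random.PRNGKey(0), sample, temb, encoder_hidden_states)
        >>> hidden_states, output_states = block.apply(variables, sample, temb, encoder_hidden_states)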
|
""" |
|
|
|
in_channels: int |
|
out_channels: int |
|
dropout: float = 0.0 |
|
num_layers: int = 1 |
|
num_attention_heads: int = 1 |
|
add_downsample: bool = True |
|
use_linear_projection: bool = False |
|
only_cross_attention: bool = False |
|
use_memory_efficient_attention: bool = False |
|
split_head_dim: bool = False |
|
dtype: jnp.dtype = jnp.float32 |
|
transformer_layers_per_block: int = 1 |
|
|
|
def setup(self): |
|
resnets = [] |
|
attentions = [] |
|
|
|
for i in range(self.num_layers): |
|
in_channels = self.in_channels if i == 0 else self.out_channels |
|
|
|
res_block = FlaxResnetBlock2D( |
|
in_channels=in_channels, |
|
out_channels=self.out_channels, |
|
dropout_prob=self.dropout, |
|
dtype=self.dtype, |
|
) |
|
resnets.append(res_block) |
|
|
|
attn_block = FlaxTransformer2DModel( |
|
in_channels=self.out_channels, |
|
n_heads=self.num_attention_heads, |
|
d_head=self.out_channels // self.num_attention_heads, |
|
depth=self.transformer_layers_per_block, |
|
use_linear_projection=self.use_linear_projection, |
|
only_cross_attention=self.only_cross_attention, |
|
use_memory_efficient_attention=self.use_memory_efficient_attention, |
|
split_head_dim=self.split_head_dim, |
|
dtype=self.dtype, |
|
) |
|
attentions.append(attn_block) |
|
|
|
self.resnets = resnets |
|
self.attentions = attentions |
|
|
|
if self.add_downsample: |
|
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) |
|
|
|
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): |
|
output_states = () |
|
|
|
for resnet, attn in zip(self.resnets, self.attentions): |
|
hidden_states = resnet(hidden_states, temb, deterministic=deterministic) |
|
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) |
|
output_states += (hidden_states,) |
|
|
|
if self.add_downsample: |
|
hidden_states = self.downsamplers_0(hidden_states) |
|
output_states += (hidden_states,) |
|
|
|
return hidden_states, output_states |
|
|
|
|
|
class FlaxDownBlock2D(nn.Module): |
|
r""" |
|
    Flax 2D downsampling block
|
|
|
Parameters: |
|
in_channels (:obj:`int`): |
|
Input channels |
|
out_channels (:obj:`int`): |
|
Output channels |
|
dropout (:obj:`float`, *optional*, defaults to 0.0): |
|
Dropout rate |
|
num_layers (:obj:`int`, *optional*, defaults to 1): |
|
            Number of ResNet block layers
|
add_downsample (:obj:`bool`, *optional*, defaults to `True`): |
|
            Whether to add a downsampling layer at the end of the block
|
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): |
|
            The `dtype` of the parameters
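
    Example (a minimal usage sketch; the channel counts and spatial size below are
    illustrative assumptions, not required values):

        >>> import jax
        >>> import jax.numpy as jnp

        >>> block = FlaxDownBlock2D(in_channels=320, out_channels=640)
        >>> sample = jnp.zeros((1, 32, 32, 320))  # NHWC feature map
        >>> temb = jnp.zeros((1, 1280))  # time embedding
        >>> variables = block.init(jax.random.PRNGKey(0), sample, temb)
        >>> hidden_states, output_states = block.apply(variables, sample, temb)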
|
""" |
|
|
|
in_channels: int |
|
out_channels: int |
|
dropout: float = 0.0 |
|
num_layers: int = 1 |
|
add_downsample: bool = True |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
resnets = [] |
|
|
|
for i in range(self.num_layers): |
|
in_channels = self.in_channels if i == 0 else self.out_channels |
|
|
|
res_block = FlaxResnetBlock2D( |
|
in_channels=in_channels, |
|
out_channels=self.out_channels, |
|
dropout_prob=self.dropout, |
|
dtype=self.dtype, |
|
) |
|
resnets.append(res_block) |
|
self.resnets = resnets |
|
|
|
if self.add_downsample: |
|
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) |
|
|
|
def __call__(self, hidden_states, temb, deterministic=True): |
|
output_states = () |
|
|
|
for resnet in self.resnets: |
|
hidden_states = resnet(hidden_states, temb, deterministic=deterministic) |
|
output_states += (hidden_states,) |
|
|
|
if self.add_downsample: |
|
hidden_states = self.downsamplers_0(hidden_states) |
|
output_states += (hidden_states,) |
|
|
|
return hidden_states, output_states |
|
|
|
|
|
class FlaxCrossAttnUpBlock2D(nn.Module): |
|
r""" |
|
Cross Attention 2D Upsampling block - original architecture from Unet transformers: |
|
https://arxiv.org/abs/2103.06104 |
|
|
|
Parameters: |
|
in_channels (:obj:`int`): |
|
Input channels |
|
        out_channels (:obj:`int`):

            Output channels

        prev_output_channel (:obj:`int`):

            Output channels from the previous block
|
dropout (:obj:`float`, *optional*, defaults to 0.0): |
|
Dropout rate |
|
num_layers (:obj:`int`, *optional*, defaults to 1): |
|
            Number of attention block layers
|
num_attention_heads (:obj:`int`, *optional*, defaults to 1): |
|
Number of attention heads of each spatial transformer block |
|
add_upsample (:obj:`bool`, *optional*, defaults to `True`): |
|
            Whether to add an upsampling layer at the end of the block
|
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): |
|
            Whether to enable memory-efficient attention (https://arxiv.org/abs/2112.05682)
|
split_head_dim (`bool`, *optional*, defaults to `False`): |
|
Whether to split the head dimension into a new axis for the self-attention computation. In most cases, |
|
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. |
|
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): |
|
            The `dtype` of the parameters
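
    Example (a minimal usage sketch; the channel counts, spatial size, and text-embedding
    dimension below are illustrative assumptions, not required values):

        >>> import jax
        >>> import jax.numpy as jnp

        >>> block = FlaxCrossAttnUpBlock2D(
        ...     in_channels=320, out_channels=320, prev_output_channel=640, num_attention_heads=8
        ... )
        >>> sample = jnp.zeros((1, 32, 32, 640))  # output of the previous (deeper) block
        >>> res_sample = jnp.zeros((1, 32, 32, 320))  # skip connection from the matching down block
        >>> temb = jnp.zeros((1, 1280))  # time embedding
        >>> encoder_hidden_states = jnp.zeros((1, 77, 768))  # text-encoder hidden states
        >>> variables = block.init(jax.random.PRNGKey(0), sample, (res_sample,), temb, encoder_hidden_states)
        >>> hidden_states = block.apply(variables, sample, (res_sample,), temb, encoder_hidden_states)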
|
""" |
|
|
|
in_channels: int |
|
out_channels: int |
|
prev_output_channel: int |
|
dropout: float = 0.0 |
|
num_layers: int = 1 |
|
num_attention_heads: int = 1 |
|
add_upsample: bool = True |
|
use_linear_projection: bool = False |
|
only_cross_attention: bool = False |
|
use_memory_efficient_attention: bool = False |
|
split_head_dim: bool = False |
|
dtype: jnp.dtype = jnp.float32 |
|
transformer_layers_per_block: int = 1 |
|
|
|
def setup(self): |
|
resnets = [] |
|
attentions = [] |
|
|
|
for i in range(self.num_layers): |
|
res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels |
|
resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels |
|
|
|
res_block = FlaxResnetBlock2D( |
|
in_channels=resnet_in_channels + res_skip_channels, |
|
out_channels=self.out_channels, |
|
dropout_prob=self.dropout, |
|
dtype=self.dtype, |
|
) |
|
resnets.append(res_block) |
|
|
|
attn_block = FlaxTransformer2DModel( |
|
in_channels=self.out_channels, |
|
n_heads=self.num_attention_heads, |
|
d_head=self.out_channels // self.num_attention_heads, |
|
depth=self.transformer_layers_per_block, |
|
use_linear_projection=self.use_linear_projection, |
|
only_cross_attention=self.only_cross_attention, |
|
use_memory_efficient_attention=self.use_memory_efficient_attention, |
|
split_head_dim=self.split_head_dim, |
|
dtype=self.dtype, |
|
) |
|
attentions.append(attn_block) |
|
|
|
self.resnets = resnets |
|
self.attentions = attentions |
|
|
|
if self.add_upsample: |
|
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) |
|
|
|
def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True): |
|
for resnet, attn in zip(self.resnets, self.attentions): |
|
|
|
res_hidden_states = res_hidden_states_tuple[-1] |
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1] |
|
hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) |
|
|
|
hidden_states = resnet(hidden_states, temb, deterministic=deterministic) |
|
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) |
|
|
|
if self.add_upsample: |
|
hidden_states = self.upsamplers_0(hidden_states) |
|
|
|
return hidden_states |
|
|
|
|
|
class FlaxUpBlock2D(nn.Module): |
|
r""" |
|
Flax 2D upsampling block |
|
|
|
Parameters: |
|
in_channels (:obj:`int`): |
|
Input channels |
|
out_channels (:obj:`int`): |
|
Output channels |
|
prev_output_channel (:obj:`int`): |
|
Output channels from the previous block |
|
dropout (:obj:`float`, *optional*, defaults to 0.0): |
|
Dropout rate |
|
num_layers (:obj:`int`, *optional*, defaults to 1): |
|
            Number of ResNet block layers
|
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):

            Whether to add an upsampling layer at the end of the block
|
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): |
|
            The `dtype` of the parameters
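
    Example (a minimal usage sketch; the channel counts and spatial size below are
    illustrative assumptions, not required values):

        >>> import jax
        >>> import jax.numpy as jnp

        >>> block = FlaxUpBlock2D(in_channels=320, out_channels=320, prev_output_channel=640)
        >>> sample = jnp.zeros((1, 32, 32, 640))  # output of the previous (deeper) block
        >>> res_sample = jnp.zeros((1, 32, 32, 320))  # skip connection from the matching down block
        >>> temb = jnp.zeros((1, 1280))  # time embedding
        >>> variables = block.init(jax.random.PRNGKey(0), sample, (res_sample,), temb)
        >>> hidden_states = block.apply(variables, sample, (res_sample,), temb)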
|
""" |
|
|
|
in_channels: int |
|
out_channels: int |
|
prev_output_channel: int |
|
dropout: float = 0.0 |
|
num_layers: int = 1 |
|
add_upsample: bool = True |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
resnets = [] |
|
|
|
for i in range(self.num_layers): |
|
res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels |
|
resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels |
|
|
|
res_block = FlaxResnetBlock2D( |
|
in_channels=resnet_in_channels + res_skip_channels, |
|
out_channels=self.out_channels, |
|
dropout_prob=self.dropout, |
|
dtype=self.dtype, |
|
) |
|
resnets.append(res_block) |
|
|
|
self.resnets = resnets |
|
|
|
if self.add_upsample: |
|
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) |
|
|
|
def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True): |
|
for resnet in self.resnets: |
|
|
|
res_hidden_states = res_hidden_states_tuple[-1] |
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1] |
|
hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) |
|
|
|
hidden_states = resnet(hidden_states, temb, deterministic=deterministic) |
|
|
|
if self.add_upsample: |
|
hidden_states = self.upsamplers_0(hidden_states) |
|
|
|
return hidden_states |
|
|
|
|
|
class FlaxUNetMidBlock2DCrossAttn(nn.Module): |
|
r""" |
|
Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104 |
|
|
|
Parameters: |
|
in_channels (:obj:`int`): |
|
Input channels |
|
dropout (:obj:`float`, *optional*, defaults to 0.0): |
|
Dropout rate |
|
num_layers (:obj:`int`, *optional*, defaults to 1): |
|
            Number of attention block layers
|
num_attention_heads (:obj:`int`, *optional*, defaults to 1): |
|
Number of attention heads of each spatial transformer block |
|
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): |
|
            Whether to enable memory-efficient attention (https://arxiv.org/abs/2112.05682)
|
split_head_dim (`bool`, *optional*, defaults to `False`): |
|
Whether to split the head dimension into a new axis for the self-attention computation. In most cases, |
|
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. |
|
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): |
|
            The `dtype` of the parameters
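
    Example (a minimal usage sketch; the channel counts, spatial size, and text-embedding
    dimension below are illustrative assumptions, not required values):

        >>> import jax
        >>> import jax.numpy as jnp

        >>> block = FlaxUNetMidBlock2DCrossAttn(in_channels=1280, num_attention_heads=8)
        >>> sample = jnp.zeros((1, 8, 8, 1280))  # NHWC feature map at the UNet bottleneck
        >>> temb = jnp.zeros((1, 1280))  # time embedding
        >>> encoder_hidden_states = jnp.zeros((1, 77, 768))  # text-encoder hidden states
        >>> variables = block.init(jax.random.PRNGKey(0), sample, temb, encoder_hidden_states)
        >>> hidden_states = block.apply(variables, sample, temb, encoder_hidden_states)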
|
""" |
|
|
|
in_channels: int |
|
dropout: float = 0.0 |
|
num_layers: int = 1 |
|
num_attention_heads: int = 1 |
|
use_linear_projection: bool = False |
|
use_memory_efficient_attention: bool = False |
|
split_head_dim: bool = False |
|
dtype: jnp.dtype = jnp.float32 |
|
transformer_layers_per_block: int = 1 |
|
|
|
def setup(self): |
|
|
|
resnets = [ |
|
FlaxResnetBlock2D( |
|
in_channels=self.in_channels, |
|
out_channels=self.in_channels, |
|
dropout_prob=self.dropout, |
|
dtype=self.dtype, |
|
) |
|
] |
|
|
|
attentions = [] |
|
|
|
for _ in range(self.num_layers): |
|
attn_block = FlaxTransformer2DModel( |
|
in_channels=self.in_channels, |
|
n_heads=self.num_attention_heads, |
|
d_head=self.in_channels // self.num_attention_heads, |
|
depth=self.transformer_layers_per_block, |
|
use_linear_projection=self.use_linear_projection, |
|
use_memory_efficient_attention=self.use_memory_efficient_attention, |
|
split_head_dim=self.split_head_dim, |
|
dtype=self.dtype, |
|
) |
|
attentions.append(attn_block) |
|
|
|
res_block = FlaxResnetBlock2D( |
|
in_channels=self.in_channels, |
|
out_channels=self.in_channels, |
|
dropout_prob=self.dropout, |
|
dtype=self.dtype, |
|
) |
|
resnets.append(res_block) |
|
|
|
self.resnets = resnets |
|
self.attentions = attentions |
|
|
|
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): |
|
        hidden_states = self.resnets[0](hidden_states, temb, deterministic=deterministic)
|
for attn, resnet in zip(self.attentions, self.resnets[1:]): |
|
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) |
|
hidden_states = resnet(hidden_states, temb, deterministic=deterministic) |
|
|
|
return hidden_states |
|
|