import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import AutoencoderKL, DDPMScheduler

from leffa.diffusion_model.unet_ref import (
    UNet2DConditionModel as ReferenceUNet,
)
from leffa.diffusion_model.unet_gen import (
    UNet2DConditionModel as GenerativeUNet,
)

logger: logging.Logger = logging.getLogger(__name__)


class LeffaModel(nn.Module):
    def __init__(
        self,
        pretrained_model_name_or_path: str = "",
        pretrained_model: str = "",
        new_in_channels: int = 12,  # noisy_image: 4, mask: 1, masked_image: 4, densepose: 3
        height: int = 1024,
        width: int = 768,
    ):
        super().__init__()

        self.height = height
        self.width = width

        self.build_models(
            pretrained_model_name_or_path,
            pretrained_model,
            new_in_channels,
        )

    def build_models(
        self,
        pretrained_model_name_or_path: str = "",
        pretrained_model: str = "",
        new_in_channels: int = 12,
    ):
        diffusion_model_type = ""
        if "stable-diffusion-inpainting" in pretrained_model_name_or_path:
            diffusion_model_type = "sd15"
        elif "stable-diffusion-xl-1.0-inpainting-0.1" in pretrained_model_name_or_path:
            diffusion_model_type = "sdxl"

        # Noise Scheduler
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="scheduler",
            rescale_betas_zero_snr=(diffusion_model_type != "sd15"),
        )
        # VAE
        vae_config, vae_kwargs = AutoencoderKL.load_config(
            pretrained_model_name_or_path,
            subfolder="vae",
            return_unused_kwargs=True,
        )
        self.vae = AutoencoderKL.from_config(vae_config, **vae_kwargs)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        # Reference UNet
        unet_config, unet_kwargs = ReferenceUNet.load_config(
            pretrained_model_name_or_path,
            subfolder="unet",
            return_unused_kwargs=True,
        )
        self.unet_encoder = ReferenceUNet.from_config(unet_config, **unet_kwargs)
        self.unet_encoder.config.addition_embed_type = None
        # Generative UNet
        unet_config, unet_kwargs = GenerativeUNet.load_config(
            pretrained_model_name_or_path,
            subfolder="unet",
            return_unused_kwargs=True,
        )
        self.unet = GenerativeUNet.from_config(unet_config, **unet_kwargs)
        self.unet.config.addition_embed_type = None
        # Change Generative UNet conv_in and conv_out
        unet_conv_in_channel_changed = self.unet.config.in_channels != new_in_channels
        if unet_conv_in_channel_changed:
            self.unet.conv_in = self.replace_conv_in_layer(self.unet, new_in_channels)
            self.unet.config.in_channels = new_in_channels
        unet_conv_out_channel_changed = (
            self.unet.config.out_channels != self.vae.config.latent_channels
        )
        if unet_conv_out_channel_changed:
            self.unet.conv_out = self.replace_conv_out_layer(
                self.unet, self.vae.config.latent_channels
            )
            self.unet.config.out_channels = self.vae.config.latent_channels
        unet_encoder_conv_in_channel_changed = (
            self.unet_encoder.config.in_channels != self.vae.config.latent_channels
        )
        if unet_encoder_conv_in_channel_changed:
            self.unet_encoder.conv_in = self.replace_conv_in_layer(
                self.unet_encoder, self.vae.config.latent_channels
            )
            self.unet_encoder.config.in_channels = self.vae.config.latent_channels
        unet_encoder_conv_out_channel_changed = (
            self.unet_encoder.config.out_channels != self.vae.config.latent_channels
        )
        if unet_encoder_conv_out_channel_changed:
            self.unet_encoder.conv_out = self.replace_conv_out_layer(
                self.unet_encoder, self.vae.config.latent_channels
            )
            self.unet_encoder.config.out_channels = self.vae.config.latent_channels

        # Remove Cross Attention
        remove_cross_attention(self.unet)
        remove_cross_attention(self.unet_encoder, model_type="unet_encoder")

        # Load pretrained model
        if pretrained_model is not None and pretrained_model != "":
            self.load_state_dict(torch.load(pretrained_model, map_location="cpu"))
            logger.info("Loaded pretrained model from %s", pretrained_model)
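
    # The two helpers below swap a UNet's conv_in / conv_out for versions with a
    # different channel count. The new weights are zero-initialized and the
    # pretrained weights are copied into the overlapping channels, so the extra
    # conditioning inputs (mask, masked image, densepose) start out as a no-op
    # while the original pretrained behavior is preserved.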

    def replace_conv_in_layer(self, unet_model, new_in_channels):
        original_conv_in = unet_model.conv_in

        if original_conv_in.in_channels == new_in_channels:
            return original_conv_in

        new_conv_in = torch.nn.Conv2d(
            in_channels=new_in_channels,
            out_channels=original_conv_in.out_channels,
            kernel_size=original_conv_in.kernel_size,
            padding=1,
        )
        new_conv_in.weight.data.zero_()
        new_conv_in.bias.data = original_conv_in.bias.data.clone()
        if original_conv_in.in_channels < new_in_channels:
            new_conv_in.weight.data[:, : original_conv_in.in_channels] = (
                original_conv_in.weight.data
            )
        else:
            new_conv_in.weight.data[:, :new_in_channels] = original_conv_in.weight.data[
                :, :new_in_channels
            ]
        return new_conv_in

    def replace_conv_out_layer(self, unet_model, new_out_channels):
        original_conv_out = unet_model.conv_out

        if original_conv_out.out_channels == new_out_channels:
            return original_conv_out

        new_conv_out = torch.nn.Conv2d(
            in_channels=original_conv_out.in_channels,
            out_channels=new_out_channels,
            kernel_size=original_conv_out.kernel_size,
            padding=1,
        )
        new_conv_out.weight.data.zero_()
        new_conv_out.bias.data[: original_conv_out.out_channels] = (
            original_conv_out.bias.data.clone()
        )
        if original_conv_out.out_channels < new_out_channels:
            new_conv_out.weight.data[: original_conv_out.out_channels] = (
                original_conv_out.weight.data
            )
        else:
            new_conv_out.weight.data[:new_out_channels] = original_conv_out.weight.data[
                :new_out_channels
            ]
        return new_conv_out

    def vae_encode(self, pixel_values):
        pixel_values = pixel_values.to(device=self.vae.device, dtype=self.vae.dtype)
        with torch.no_grad():
            latent = self.vae.encode(pixel_values).latent_dist.sample()
        latent = latent * self.vae.config.scaling_factor
        return latent


class SkipAttnProcessor(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        return hidden_states


def remove_cross_attention(
    unet,
    cross_attn_cls=SkipAttnProcessor,
    self_attn_cls=None,
    cross_attn_dim=None,
    **kwargs,
):
    if cross_attn_dim is None:
        cross_attn_dim = unet.config.cross_attention_dim
    attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = (
            None if name.endswith("attn1.processor") else cross_attn_dim
        )
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            if self_attn_cls is not None:
                attn_procs[name] = self_attn_cls(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    **kwargs,
                )
            else:
                # retain the original attn processor
                attn_procs[name] = AttnProcessor2_0(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    layer_name=name,
                    **kwargs,
                )
        else:
            attn_procs[name] = cross_attn_cls(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                **kwargs,
            )
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
    return adapter_modules
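
# Note on the dispatch above: diffusers names attention processors like
# "down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor". Keys ending
# in "attn1.processor" are self-attention and keep AttnProcessor2_0 (or
# self_attn_cls when one is given); the remaining keys are cross-attention and
# are replaced with cross_attn_cls, which defaults to SkipAttnProcessor and
# returns hidden_states unchanged, effectively disabling the text cross-attention.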


class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(
        self, hidden_size=None, cross_attention_dim=None, layer_name=None, **kwargs
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0 or newer; please upgrade PyTorch."
            )
        self.layer_name = layer_name
        self.model_type = kwargs.get("model_type", "none")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
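

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the checkpoint paths below are placeholders
# and the surrounding pipeline, not shown here, handles preprocessing and
# denoising):
#
#   model = LeffaModel(
#       pretrained_model_name_or_path="path/to/stable-diffusion-inpainting",  # SD 1.5 or SDXL inpainting weights
#       pretrained_model="path/to/leffa_checkpoint.pth",                       # optional Leffa state dict
#       new_in_channels=12,
#   )
#   latents = model.vae_encode(images)  # images: (B, 3, H, W), typically normalized to [-1, 1] for SD VAEs
#
# The widened 12-channel conv_in then takes the concatenation of noisy latents
# (4), the resized inpainting mask (1), masked-image latents (4) and densepose
# (3), matching the channel breakdown noted next to `new_in_channels` above.
# ---------------------------------------------------------------------------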