from typing import Optional

import keras_cv
import tensorflow as tf
from diffusers import (
    AutoencoderKL,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
from transformers import CLIPTextModel, CLIPTokenizer

from conversion_utils import populate_text_encoder, populate_unet

PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
REVISION = None
NON_EMA_REVISION = None
IMG_HEIGHT = IMG_WIDTH = 512
MAX_SEQ_LENGTH = 77


def initialize_pt_models():
    """Initialize the separate diffusers/transformers Stable Diffusion components.

    Each component is loaded with ``from_pretrained`` from ``PRETRAINED_CKPT``,
    which downloads its pre-trained weights.

    Returns:
        A tuple ``(text_encoder, tokenizer, vae, unet, safety_checker)``.
    """
    pt_text_encoder = CLIPTextModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
    )
    # Use the same revision as the text encoder so the two stay in sync.
    pt_tokenizer = CLIPTokenizer.from_pretrained(
        PRETRAINED_CKPT, subfolder="tokenizer", revision=REVISION
    )
    pt_vae = AutoencoderKL.from_pretrained(
        PRETRAINED_CKPT, subfolder="vae", revision=REVISION
    )
    pt_unet = UNet2DConditionModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="unet", revision=NON_EMA_REVISION
    )
    pt_safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        PRETRAINED_CKPT, subfolder="safety_checker", revision=NON_EMA_REVISION
    )
    return pt_text_encoder, pt_tokenizer, pt_vae, pt_unet, pt_safety_checker


def initialize_tf_models(
    text_encoder_weights: Optional[str],
    unet_weights: Optional[str],
    placeholder_token: Optional[str] = None,
):
    """Initialize the separate KerasCV Stable Diffusion components.

    A component whose fine-tuned weights are *not* supplied is taken from a
    ``keras_cv.models.StableDiffusion`` instance. A component whose weights
    *are* supplied is built fresh with ``download_weights=False``, so the
    caller can load the fine-tuned weights into it.

    Args:
        text_encoder_weights: Origin of fine-tuned text-encoder weights, or
            ``None`` to use the StableDiffusion model's text encoder.
        unet_weights: Origin of fine-tuned UNet weights, or ``None`` to use
            the StableDiffusion model's diffusion model.
        placeholder_token: Optional Textual Inversion token to add to the
            KerasCV tokenizer.

    Returns:
        A tuple ``(text_encoder, unet, tokenizer)``.
    """
    tf_sd_model = keras_cv.models.StableDiffusion(
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH
    )

    if text_encoder_weights is None:
        tf_text_encoder = tf_sd_model.text_encoder
    else:
        tf_text_encoder = keras_cv.models.stable_diffusion.TextEncoder(
            MAX_SEQ_LENGTH, download_weights=False
        )

    if unet_weights is None:
        tf_unet = tf_sd_model.diffusion_model
    else:
        tf_unet = keras_cv.models.stable_diffusion.DiffusionModel(
            IMG_HEIGHT, IMG_WIDTH, MAX_SEQ_LENGTH, download_weights=False
        )

    tf_tokenizer = tf_sd_model.tokenizer
    if placeholder_token is not None:
        tf_tokenizer.add_tokens(placeholder_token)

    return tf_text_encoder, tf_unet, tf_tokenizer


def create_new_text_encoder(tf_text_encoder, tf_tokenizer):
    """Build a fresh KerasCV text encoder sized for the tokenizer's vocabulary.

    This is used when the weights come from Textual Inversion. The new
    encoder's ``vocab_size`` is ``len(tf_tokenizer.vocab)``, which includes
    any added placeholder tokens. Only the position-embedding weights are
    copied from ``tf_text_encoder``; everything else is left at its fresh
    initialization.

    Reference: https://keras.io/examples/generative/fine_tune_via_textual_inversion/

    Args:
        tf_text_encoder: Existing KerasCV text encoder to copy position
            embeddings from.
        tf_tokenizer: KerasCV tokenizer whose vocabulary size is used.

    Returns:
        The newly constructed text encoder.
    """
    new_vocab_size = len(tf_tokenizer.vocab)
    new_text_encoder = keras_cv.models.stable_diffusion.TextEncoder(
        MAX_SEQ_LENGTH,
        vocab_size=new_vocab_size,
        download_weights=False,
    )
    # NOTE(review): assumes layer index 2 is the embedding layer exposing
    # `position_embedding` in KerasCV's TextEncoder; confirm if KerasCV changes.
    old_position_weights = tf_text_encoder.layers[2].position_embedding.get_weights()
    new_text_encoder.layers[2].position_embedding.set_weights(old_position_weights)
    return new_text_encoder


def run_conversion(
    text_encoder_weights: Optional[str] = None,
    unet_weights: Optional[str] = None,
    placeholder_token: Optional[str] = None,
):
    """Port KerasCV Stable Diffusion weights into a diffusers pipeline.

    Args:
        text_encoder_weights: URL of fine-tuned KerasCV text-encoder weights,
            fetched with ``tf.keras.utils.get_file``. ``None`` keeps the
            pre-trained PT text encoder.
        unet_weights: URL of fine-tuned KerasCV UNet weights, fetched with
            ``tf.keras.utils.get_file``. ``None`` keeps the pre-trained PT UNet.
        placeholder_token: Optional Textual Inversion token. It is added to
            both tokenizers, and the text encoders are resized to match.

    Returns:
        A ``StableDiffusionPipeline`` built from the (possibly ported) PT
        components.

    Raises:
        ValueError: If ``placeholder_token`` already exists in the PT tokenizer.
    """
    (
        pt_text_encoder,
        pt_tokenizer,
        pt_vae,
        pt_unet,
        pt_safety_checker,
    ) = initialize_pt_models()
    tf_text_encoder, tf_unet, tf_tokenizer = initialize_tf_models(
        text_encoder_weights, unet_weights, placeholder_token
    )
    print("Pre-trained model weights downloaded.")

    if placeholder_token is not None:
        print("Initializing a new text encoder with the placeholder token...")
        tf_text_encoder = create_new_text_encoder(tf_text_encoder, tf_tokenizer)

        print("Adding the placeholder token to PT CLIPTokenizer...")
        num_added_tokens = pt_tokenizer.add_tokens(placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )
        # Grow the PT token-embedding table to the enlarged vocabulary.
        # The TF encoder was rebuilt with vocab_size=len(tf_tokenizer.vocab),
        # so the state dict ported below has the extra row(s). Without this
        # resize, load_state_dict fails on a shape mismatch, and the new token
        # id has no embedding row in the final pipeline.
        pt_text_encoder.resize_token_embeddings(len(pt_tokenizer))

    if text_encoder_weights is not None:
        print("Loading fine-tuned text encoder weights.")
        text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
        tf_text_encoder.load_weights(text_encoder_weights_path)
        text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
        pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
        print("Populated PT text encoder from TF weights.")

    if unet_weights is not None:
        print("Loading fine-tuned UNet weights.")
        unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
        tf_unet.load_weights(unet_weights_path)
        unet_state_dict_from_tf = populate_unet(tf_unet)
        pt_unet.load_state_dict(unet_state_dict_from_tf)
        print("Populated PT UNet from TF weights.")

    print("Weights ported, preparing StableDiffusionPipeline...")
    pipeline = StableDiffusionPipeline.from_pretrained(
        PRETRAINED_CKPT,
        unet=pt_unet,
        text_encoder=pt_text_encoder,
        tokenizer=pt_tokenizer,
        vae=pt_vae,
        safety_checker=pt_safety_checker,
        revision=REVISION,
    )
    return pipeline