Spaces:
Build error
Build error
import keras_cv | |
import tensorflow as tf | |
from diffusers import (AutoencoderKL, StableDiffusionPipeline, | |
UNet2DConditionModel) | |
from diffusers.pipelines.stable_diffusion.safety_checker import \ | |
StableDiffusionSafetyChecker | |
from transformers import CLIPTextModel, CLIPTokenizer | |
from conversion_utils import populate_text_encoder, populate_unet | |
# Hugging Face Hub checkpoint from which every PyTorch component is loaded.
PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"

# Hub revisions passed to `from_pretrained`; None leaves the choice to the
# library default.
REVISION = None
NON_EMA_REVISION = None

# Image resolution used when building the KerasCV models.
IMG_HEIGHT = 512
IMG_WIDTH = 512

# Prompt length (in tokens) used to build the KerasCV text encoder and UNet.
MAX_SEQ_LENGTH = 77
def initialize_pt_models():
    """Initializes the separate models of Stable Diffusion from diffusers and
    downloads their pre-trained weights.

    All components are loaded from ``PRETRAINED_CKPT``. The UNet is loaded at
    ``NON_EMA_REVISION``; every other component (including the tokenizer) is
    loaded at ``REVISION`` so that they all come from the same Hub revision.

    Returns:
        Tuple ``(pt_text_encoder, pt_tokenizer, pt_vae, pt_unet,
        pt_safety_checker)``.
    """
    pt_text_encoder = CLIPTextModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
    )
    # Pin the tokenizer to the same revision as the text encoder so a
    # non-default REVISION cannot pair a vocabulary with mismatched weights.
    pt_tokenizer = CLIPTokenizer.from_pretrained(
        PRETRAINED_CKPT, subfolder="tokenizer", revision=REVISION
    )
    pt_vae = AutoencoderKL.from_pretrained(
        PRETRAINED_CKPT, subfolder="vae", revision=REVISION
    )
    # The non-EMA revision is intended for the UNet weights only.
    pt_unet = UNet2DConditionModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="unet", revision=NON_EMA_REVISION
    )
    # The safety checker follows REVISION like the other non-UNet components.
    pt_safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        PRETRAINED_CKPT, subfolder="safety_checker", revision=REVISION
    )
    return pt_text_encoder, pt_tokenizer, pt_vae, pt_unet, pt_safety_checker
def initialize_tf_models(
    text_encoder_weights: str, unet_weights: str, placeholder_token: str = None
):
    """Build the KerasCV Stable Diffusion components that will be converted.

    For each of the text encoder and the diffusion model: when no weights URL
    is given, the pre-trained component of a KerasCV ``StableDiffusion`` model
    is reused; otherwise a model is built with ``download_weights=False`` so
    that fine-tuned weights can be loaded into it afterwards.

    Args:
        text_encoder_weights: URL of fine-tuned text encoder weights, or None.
        unet_weights: URL of fine-tuned diffusion model weights, or None.
        placeholder_token: Optional token added to the KerasCV tokenizer.

    Returns:
        Tuple ``(tf_text_encoder, tf_unet, tf_tokenizer)``.
    """
    sd_model = keras_cv.models.StableDiffusion(
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH
    )
    sd_components = keras_cv.models.stable_diffusion

    # Only the branch that is actually taken touches the corresponding
    # `sd_model` attribute.
    text_encoder = (
        sd_model.text_encoder
        if text_encoder_weights is None
        else sd_components.TextEncoder(MAX_SEQ_LENGTH, download_weights=False)
    )
    diffusion_model = (
        sd_model.diffusion_model
        if unet_weights is None
        else sd_components.DiffusionModel(
            IMG_HEIGHT, IMG_WIDTH, MAX_SEQ_LENGTH, download_weights=False
        )
    )

    tokenizer = sd_model.tokenizer
    if placeholder_token is not None:
        tokenizer.add_tokens(placeholder_token)

    return text_encoder, diffusion_model, tokenizer
def create_new_text_encoder(tf_text_encoder, tf_tokenizer):
    """Build a fresh KerasCV text encoder sized for the tokenizer's vocabulary.

    Intended for Textual Inversion weights, where the tokenizer vocabulary has
    grown. The new encoder is created with ``download_weights=False``; only
    the position-embedding weights are copied over from ``tf_text_encoder``.

    Reference: https://keras.io/examples/generative/fine_tune_via_textual_inversion/
    """
    encoder_cls = keras_cv.models.stable_diffusion.TextEncoder
    fresh_encoder = encoder_cls(
        MAX_SEQ_LENGTH,
        vocab_size=len(tf_tokenizer.vocab),
        download_weights=False,
    )

    # layers[2] is the embedding layer exposing `position_embedding`.
    source_embeddings = tf_text_encoder.layers[2]
    target_embeddings = fresh_encoder.layers[2]
    target_embeddings.position_embedding.set_weights(
        source_embeddings.position_embedding.get_weights()
    )
    return fresh_encoder
def run_conversion(
    text_encoder_weights: str = None,
    unet_weights: str = None,
    placeholder_token: str = None,
):
    """Port fine-tuned KerasCV Stable Diffusion weights into a diffusers pipeline.

    The PyTorch components are loaded from ``PRETRAINED_CKPT``. The text
    encoder and/or UNet are then overwritten with the KerasCV weights
    downloaded from the given URLs.

    Args:
        text_encoder_weights: URL of KerasCV text encoder weights (e.g. from
            Textual Inversion), or None to keep the pre-trained text encoder.
        unet_weights: URL of KerasCV diffusion model weights, or None to keep
            the pre-trained UNet.
        placeholder_token: Token learned via Textual Inversion. It is added to
            both tokenizers and must not already exist in the PT tokenizer.

    Returns:
        A ``StableDiffusionPipeline`` assembled from the PyTorch components.

    Raises:
        ValueError: If ``placeholder_token`` is already in the PT tokenizer.
    """
    (
        pt_text_encoder,
        pt_tokenizer,
        pt_vae,
        pt_unet,
        pt_safety_checker,
    ) = initialize_pt_models()
    tf_text_encoder, tf_unet, tf_tokenizer = initialize_tf_models(
        text_encoder_weights, unet_weights, placeholder_token
    )
    print("Pre-trained model weights downloaded.")

    if placeholder_token is not None:
        print("Initializing a new text encoder with the placeholder token...")
        tf_text_encoder = create_new_text_encoder(tf_text_encoder, tf_tokenizer)

        print("Adding the placeholder token to PT CLIPTokenizer...")
        num_added_tokens = pt_tokenizer.add_tokens(placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )
        # The TF text encoder is rebuilt with one embedding row per token of
        # the extended vocabulary. Grow the PT token-embedding matrix to match,
        # otherwise `load_state_dict` below fails with a size mismatch.
        # NOTE(review): without `text_encoder_weights` the new row keeps its
        # fresh initialization.
        pt_text_encoder.resize_token_embeddings(len(pt_tokenizer))

    if text_encoder_weights is not None:
        print("Loading fine-tuned text encoder weights.")
        text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
        tf_text_encoder.load_weights(text_encoder_weights_path)
        text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
        pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
        print("Populated PT text encoder from TF weights.")

    if unet_weights is not None:
        print("Loading fine-tuned UNet weights.")
        unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
        tf_unet.load_weights(unet_weights_path)
        unet_state_dict_from_tf = populate_unet(tf_unet)
        pt_unet.load_state_dict(unet_state_dict_from_tf)
        print("Populated PT UNet from TF weights.")

    print("Weights ported, preparing StableDiffusionPipeline...")
    # Components not passed explicitly are fetched from the checkpoint at the
    # same REVISION used for the other components.
    pipeline = StableDiffusionPipeline.from_pretrained(
        PRETRAINED_CKPT,
        unet=pt_unet,
        text_encoder=pt_text_encoder,
        tokenizer=pt_tokenizer,
        vae=pt_vae,
        safety_checker=pt_safety_checker,
        revision=REVISION,
    )
    return pipeline