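# Page header captured together with this file (not part of the program):
#   author: sayakpaul (HF staff)
#   commit: 3bd4a93 - "fix: textual inversion utility."
#   views:  raw / history / blame
#   size:   5.25 kB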
import keras_cv
import tensorflow as tf
from diffusers import (AutoencoderKL, StableDiffusionPipeline,
UNet2DConditionModel)
from diffusers.pipelines.stable_diffusion.safety_checker import \
StableDiffusionSafetyChecker
from transformers import CLIPTextModel, CLIPTokenizer
from conversion_utils import populate_text_encoder, populate_unet
# Checkpoint identifier handed to every `from_pretrained` call in this file.
PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
# Revision requested for the PT text encoder and VAE (None is passed through as-is).
REVISION = None
# Revision requested for the PT UNet and safety checker (None is passed through as-is).
NON_EMA_REVISION = None
# Height/width used when constructing the KerasCV Stable Diffusion models.
IMG_HEIGHT = IMG_WIDTH = 512
# Sequence length passed to the KerasCV text encoder and diffusion model constructors.
MAX_SEQ_LENGTH = 77
def initialize_pt_models():
    """Load each Stable Diffusion component from diffusers/transformers separately.

    Every component is fetched from ``PRETRAINED_CKPT`` through its own
    ``from_pretrained`` call. The text encoder and VAE are requested at
    ``REVISION``; the UNet and safety checker at ``NON_EMA_REVISION``; the
    tokenizer is requested without an explicit revision.

    Returns:
        A 5-tuple ``(text_encoder, tokenizer, vae, unet, safety_checker)``.
    """
    ckpt = PRETRAINED_CKPT

    text_encoder = CLIPTextModel.from_pretrained(
        ckpt, subfolder="text_encoder", revision=REVISION
    )
    tokenizer = CLIPTokenizer.from_pretrained(ckpt, subfolder="tokenizer")
    vae = AutoencoderKL.from_pretrained(ckpt, subfolder="vae", revision=REVISION)
    unet = UNet2DConditionModel.from_pretrained(
        ckpt, subfolder="unet", revision=NON_EMA_REVISION
    )
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        ckpt, subfolder="safety_checker", revision=NON_EMA_REVISION
    )

    return text_encoder, tokenizer, vae, unet, safety_checker
def initialize_tf_models(
    text_encoder_weights: str, unet_weights: str, placeholder_token: str = None
):
    """Build the KerasCV text encoder, diffusion model and tokenizer.

    A ``keras_cv.models.StableDiffusion`` instance is always created. For each
    of the text encoder and the diffusion model:

    * if the corresponding ``*_weights`` argument is ``None``, the component
      exposed by that instance is used;
    * otherwise a standalone model is constructed with
      ``download_weights=False`` (the caller is expected to load weights
      into it afterwards).

    Args:
        text_encoder_weights: Only checked against ``None`` here.
        unet_weights: Only checked against ``None`` here.
        placeholder_token: If not ``None``, added to the KerasCV tokenizer via
            ``add_tokens``.

    Returns:
        A 3-tuple ``(text_encoder, diffusion_model, tokenizer)``.
    """
    sd_model = keras_cv.models.StableDiffusion(
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH
    )

    # Conditional expressions keep the original laziness: only the selected
    # branch is evaluated.
    text_encoder = (
        sd_model.text_encoder
        if text_encoder_weights is None
        else keras_cv.models.stable_diffusion.TextEncoder(
            MAX_SEQ_LENGTH, download_weights=False
        )
    )
    diffusion_model = (
        sd_model.diffusion_model
        if unet_weights is None
        else keras_cv.models.stable_diffusion.DiffusionModel(
            IMG_HEIGHT, IMG_WIDTH, MAX_SEQ_LENGTH, download_weights=False
        )
    )

    tokenizer = sd_model.tokenizer
    if placeholder_token is None:
        return text_encoder, diffusion_model, tokenizer

    tokenizer.add_tokens(placeholder_token)
    return text_encoder, diffusion_model, tokenizer
def create_new_text_encoder(tf_text_encoder, tf_tokenizer):
    """Build a text encoder sized for the tokenizer's current vocabulary.

    Used when the weights come from Textual Inversion, where the tokenizer has
    gained a placeholder token. The new encoder is created without downloading
    weights; only the position-embedding weights are copied over from
    ``tf_text_encoder``.

    Reference: https://keras.io/examples/generative/fine_tune_via_textual_inversion/

    Args:
        tf_text_encoder: Existing KerasCV text encoder to copy position
            embeddings from.
        tf_tokenizer: Tokenizer whose ``vocab`` length sets the new vocab size.

    Returns:
        The newly constructed ``TextEncoder``.
    """
    fresh_encoder = keras_cv.models.stable_diffusion.TextEncoder(
        MAX_SEQ_LENGTH,
        vocab_size=len(tf_tokenizer.vocab),
        download_weights=False,
    )

    # NOTE(review): layers[2] is presumably the CLIP embedding layer; the code
    # only relies on it exposing `position_embedding` - confirm against KerasCV.
    source_embedding = tf_text_encoder.layers[2]
    target_embedding = fresh_encoder.layers[2]
    target_embedding.position_embedding.set_weights(
        source_embedding.position_embedding.get_weights()
    )

    return fresh_encoder
def run_conversion(
    text_encoder_weights: str = None,
    unet_weights: str = None,
    placeholder_token: str = None,
):
    """Port fine-tuned KerasCV Stable Diffusion weights into a diffusers pipeline.

    Steps:
      1. Load the PT components and build the matching TF models.
      2. If ``placeholder_token`` is given, rebuild the TF text encoder for the
         enlarged vocabulary, add the token to the PT tokenizer and resize the
         PT text encoder's token embeddings to match.
      3. For each of the text encoder and UNet whose weights were supplied,
         fetch the file, load it into the TF model, convert it to a PT state
         dict and load that into the PT model.
      4. Assemble a ``StableDiffusionPipeline`` from the PT components.

    Args:
        text_encoder_weights: Origin passed to ``tf.keras.utils.get_file`` for
            the fine-tuned text encoder weights, or ``None`` to keep the
            pre-trained PT text encoder.
        unet_weights: Origin passed to ``tf.keras.utils.get_file`` for the
            fine-tuned UNet weights, or ``None`` to keep the pre-trained PT UNet.
        placeholder_token: Extra token (Textual Inversion) to add to both
            tokenizers, or ``None``.

    Returns:
        The assembled ``StableDiffusionPipeline``.

    Raises:
        ValueError: If the PT tokenizer already contains ``placeholder_token``.
    """
    (
        pt_text_encoder,
        pt_tokenizer,
        pt_vae,
        pt_unet,
        pt_safety_checker,
    ) = initialize_pt_models()
    tf_text_encoder, tf_unet, tf_tokenizer = initialize_tf_models(
        text_encoder_weights, unet_weights, placeholder_token
    )
    print("Pre-trained model weights downloaded.")

    if placeholder_token is not None:
        print("Initializing a new text encoder with the placeholder token...")
        tf_text_encoder = create_new_text_encoder(tf_text_encoder, tf_tokenizer)

        print("Adding the placeholder token to PT CLIPTokenizer...")
        num_added_tokens = pt_tokenizer.add_tokens(placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )

        # The PT tokenizer now has one more entry than the PT text encoder's
        # token-embedding table. Grow the table so that (a) the state dict
        # produced from the enlarged TF encoder has matching shapes and
        # (b) the new token id is a valid embedding index at inference time.
        pt_text_encoder.resize_token_embeddings(len(pt_tokenizer))

    if text_encoder_weights is not None:
        print("Loading fine-tuned text encoder weights.")
        text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
        tf_text_encoder.load_weights(text_encoder_weights_path)
        text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
        pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
        print("Populated PT text encoder from TF weights.")

    if unet_weights is not None:
        print("Loading fine-tuned UNet weights.")
        unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
        tf_unet.load_weights(unet_weights_path)
        unet_state_dict_from_tf = populate_unet(tf_unet)
        pt_unet.load_state_dict(unet_state_dict_from_tf)
        print("Populated PT UNet from TF weights.")

    print("Weights ported, preparing StabelDiffusionPipeline...")
    # Pass the already-populated components so the pipeline reuses them.
    pipeline = StableDiffusionPipeline.from_pretrained(
        PRETRAINED_CKPT,
        unet=pt_unet,
        text_encoder=pt_text_encoder,
        tokenizer=pt_tokenizer,
        vae=pt_vae,
        safety_checker=pt_safety_checker,
        revision=None,
    )
    return pipeline