Spaces:
Build error
Build error
File size: 5,248 Bytes
ddc8a59 3304f7d 89a6b3b ddc8a59 7081a39 ddc8a59 6c04d23 ddc8a59 3304f7d ddc8a59 3304f7d ddc8a59 89a6b3b ddc8a59 89a6b3b ddc8a59 3304f7d 3bd4a93 3304f7d 3bd4a93 6c04d23 3bd4a93 6c04d23 3bd4a93 ddc8a59 89a6b3b 3bd4a93 30552b4 ddc8a59 3304f7d 3bd4a93 ddc8a59 6c04d23 ddc8a59 7081a39 89a6b3b ddc8a59 6c04d23 ddc8a59 7081a39 89a6b3b 7081a39 ddc8a59 89a6b3b ddc8a59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import keras_cv
import tensorflow as tf
from diffusers import (AutoencoderKL, StableDiffusionPipeline,
UNet2DConditionModel)
from diffusers.pipelines.stable_diffusion.safety_checker import \
StableDiffusionSafetyChecker
from transformers import CLIPTextModel, CLIPTokenizer
from conversion_utils import populate_text_encoder, populate_unet
# Hugging Face Hub checkpoint the PyTorch (diffusers) models are loaded from.
PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
# Repo revision for the text encoder / VAE; None means the default branch.
REVISION = None
# Repo revision for the UNet / safety checker (non-EMA weights); None means default.
NON_EMA_REVISION = None
# Image resolution the Stable Diffusion models operate at (square images).
IMG_HEIGHT = IMG_WIDTH = 512
# Maximum token-sequence length accepted by the CLIP text encoder.
MAX_SEQ_LENGTH = 77
def initialize_pt_models():
    """Load the individual Stable Diffusion components from diffusers.

    Downloads pre-trained weights from the ``PRETRAINED_CKPT`` repository.

    Returns:
        Tuple of (text_encoder, tokenizer, vae, unet, safety_checker) as
        PyTorch/diffusers model objects.
    """
    # Text encoder and VAE come from the default revision; UNet and safety
    # checker use the non-EMA revision of the checkpoint.
    text_encoder = CLIPTextModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
    )
    tokenizer = CLIPTokenizer.from_pretrained(PRETRAINED_CKPT, subfolder="tokenizer")
    vae = AutoencoderKL.from_pretrained(
        PRETRAINED_CKPT, subfolder="vae", revision=REVISION
    )
    unet = UNet2DConditionModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="unet", revision=NON_EMA_REVISION
    )
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        PRETRAINED_CKPT, subfolder="safety_checker", revision=NON_EMA_REVISION
    )
    return text_encoder, tokenizer, vae, unet, safety_checker
def initialize_tf_models(
    text_encoder_weights: str, unet_weights: str, placeholder_token: str = None
):
    """Load the individual Stable Diffusion components from KerasCV.

    Args:
        text_encoder_weights: If not None, a fresh (unweighted) text encoder is
            built so fine-tuned weights can be loaded into it later; otherwise
            the pre-trained KerasCV text encoder is used.
        unet_weights: Same convention as ``text_encoder_weights`` but for the
            diffusion model (UNet).
        placeholder_token: Optional token to add to the KerasCV tokenizer
            (used by Textual Inversion checkpoints).

    Returns:
        Tuple of (text_encoder, unet, tokenizer) as KerasCV objects.
    """
    sd_model = keras_cv.models.StableDiffusion(
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH
    )

    # Custom weights will be loaded later, so skip downloading in that case.
    if text_encoder_weights is not None:
        text_encoder = keras_cv.models.stable_diffusion.TextEncoder(
            MAX_SEQ_LENGTH, download_weights=False
        )
    else:
        text_encoder = sd_model.text_encoder

    if unet_weights is not None:
        unet = keras_cv.models.stable_diffusion.DiffusionModel(
            IMG_HEIGHT, IMG_WIDTH, MAX_SEQ_LENGTH, download_weights=False
        )
    else:
        unet = sd_model.diffusion_model

    tokenizer = sd_model.tokenizer
    if placeholder_token is not None:
        tokenizer.add_tokens(placeholder_token)

    return text_encoder, unet, tokenizer
def create_new_text_encoder(tf_text_encoder, tf_tokenizer):
    """Build a fresh text encoder sized for a Textual-Inversion vocabulary.

    The tokenizer may have grown (placeholder tokens added), so a new encoder
    is created with the enlarged vocab size. The position-embedding weights
    are copied over from the old encoder since they are vocab-independent.
    Reference: https://keras.io/examples/generative/fine_tune_via_textual_inversion/

    Args:
        tf_text_encoder: The existing KerasCV text encoder to copy
            position embeddings from.
        tf_tokenizer: The KerasCV tokenizer whose vocab size determines
            the new encoder's embedding table size.

    Returns:
        A new, unweighted KerasCV ``TextEncoder`` with the position
        embeddings initialized from ``tf_text_encoder``.
    """
    fresh_encoder = keras_cv.models.stable_diffusion.TextEncoder(
        MAX_SEQ_LENGTH,
        vocab_size=len(tf_tokenizer.vocab),
        download_weights=False,
    )
    # Layer index 2 holds the embedding block in KerasCV's TextEncoder;
    # transplant its position-embedding weights into the new model.
    position_weights = tf_text_encoder.layers[2].position_embedding.get_weights()
    fresh_encoder.layers[2].position_embedding.set_weights(position_weights)
    return fresh_encoder
def run_conversion(
    text_encoder_weights: str = None,
    unet_weights: str = None,
    placeholder_token: str = None,
):
    """Port KerasCV Stable Diffusion weights into a diffusers pipeline.

    Downloads the pre-trained PyTorch and TensorFlow models, optionally loads
    fine-tuned TF weights (text encoder and/or UNet), copies them into the
    corresponding PyTorch modules, and assembles a ``StableDiffusionPipeline``.

    Args:
        text_encoder_weights: Optional URL of fine-tuned KerasCV text-encoder
            weights (fetched via ``tf.keras.utils.get_file``).
        unet_weights: Optional URL of fine-tuned KerasCV diffusion-model weights.
        placeholder_token: Optional Textual-Inversion token to add to both
            tokenizers.

    Returns:
        A ``StableDiffusionPipeline`` populated with the (possibly fine-tuned)
        weights.

    Raises:
        ValueError: If ``placeholder_token`` is already present in the
            PyTorch CLIP tokenizer.
    """
    (
        pt_text_encoder,
        pt_tokenizer,
        pt_vae,
        pt_unet,
        pt_safety_checker,
    ) = initialize_pt_models()
    tf_text_encoder, tf_unet, tf_tokenizer = initialize_tf_models(
        text_encoder_weights, unet_weights, placeholder_token
    )
    print("Pre-trained model weights downloaded.")

    if placeholder_token is not None:
        # Textual Inversion: the TF tokenizer already got the token inside
        # initialize_tf_models; rebuild the TF text encoder for the larger
        # vocab and mirror the token into the PyTorch tokenizer.
        print("Initializing a new text encoder with the placeholder token...")
        tf_text_encoder = create_new_text_encoder(tf_text_encoder, tf_tokenizer)
        print("Adding the placeholder token to PT CLIPTokenizer...")
        num_added_tokens = pt_tokenizer.add_tokens(placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )

    if text_encoder_weights is not None:
        print("Loading fine-tuned text encoder weights.")
        text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
        tf_text_encoder.load_weights(text_encoder_weights_path)
    # Even without fine-tuned weights, copy TF -> PT so both frameworks agree.
    text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
    pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
    print("Populated PT text encoder from TF weights.")

    if unet_weights is not None:
        print("Loading fine-tuned UNet weights.")
        unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
        tf_unet.load_weights(unet_weights_path)
    unet_state_dict_from_tf = populate_unet(tf_unet)
    pt_unet.load_state_dict(unet_state_dict_from_tf)
    print("Populated PT UNet from TF weights.")

    # Fixed typo in the status message ("StabelDiffusionPipeline").
    print("Weights ported, preparing StableDiffusionPipeline...")
    # Use the module-level REVISION constant instead of a hard-coded None so
    # the pipeline revision stays in sync with the component downloads.
    pipeline = StableDiffusionPipeline.from_pretrained(
        PRETRAINED_CKPT,
        unet=pt_unet,
        text_encoder=pt_text_encoder,
        tokenizer=pt_tokenizer,
        vae=pt_vae,
        safety_checker=pt_safety_checker,
        revision=REVISION,
    )
    return pipeline
|