Spaces:
Build error
Build error
File size: 4,171 Bytes
ddc8a59 3304f7d ddc8a59 3304f7d ddc8a59 6c04d23 ddc8a59 3304f7d ddc8a59 3304f7d ddc8a59 3304f7d 6c04d23 ddc8a59 3304f7d 6c04d23 3304f7d 6c04d23 ddc8a59 30552b4 ddc8a59 3304f7d ddc8a59 6c04d23 ddc8a59 6c04d23 ddc8a59 94913a9 ddc8a59 9435d99 ddc8a59 94913a9 ddc8a59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import keras_cv
import tensorflow as tf
from diffusers import (AutoencoderKL, StableDiffusionPipeline,
UNet2DConditionModel)
from diffusers.pipelines.stable_diffusion.safety_checker import \
StableDiffusionSafetyChecker
from transformers import CLIPTextModel
from conversion_utils import (populate_text_encoder, populate_unet,
run_assertion)
# Checkpoint identifier handed to every diffusers `from_pretrained` call below.
PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"

# Revisions forwarded to `from_pretrained`: REVISION for the text encoder and
# VAE, NON_EMA_REVISION for the UNet and the safety checker.
REVISION = None
NON_EMA_REVISION = None

# Square image resolution used when building the KerasCV models.
IMG_WIDTH = 512
IMG_HEIGHT = IMG_WIDTH

# Sequence length passed to the KerasCV text encoder and diffusion model.
MAX_SEQ_LENGTH = 77
def initialize_pt_models():
    """Load the diffusers (PyTorch) components of Stable Diffusion.

    Every component is created with ``from_pretrained`` from its own
    subfolder of ``PRETRAINED_CKPT``.

    Returns:
        A ``(text_encoder, vae, unet, safety_checker)`` tuple.
    """
    # One row per component, in return order: (class, subfolder, revision).
    # NOTE(review): the safety checker is loaded with NON_EMA_REVISION, the
    # same revision as the UNet -- confirm this is intended rather than
    # REVISION.
    component_specs = (
        (CLIPTextModel, "text_encoder", REVISION),
        (AutoencoderKL, "vae", REVISION),
        (UNet2DConditionModel, "unet", NON_EMA_REVISION),
        (StableDiffusionSafetyChecker, "safety_checker", NON_EMA_REVISION),
    )
    loaded = [
        model_cls.from_pretrained(
            PRETRAINED_CKPT, subfolder=subfolder, revision=revision
        )
        for model_cls, subfolder, revision in component_specs
    ]
    text_encoder, vae, unet, safety_checker = loaded
    return text_encoder, vae, unet, safety_checker
def initialize_tf_models(text_encoder_weights: str, unet_weights: str):
    """Build the KerasCV (TensorFlow) components of Stable Diffusion.

    Only whether each argument is ``None`` is inspected here; the values
    themselves are not read.

    Args:
        text_encoder_weights: ``None`` to reuse the text encoder of the
            KerasCV ``StableDiffusion`` model; anything else yields a fresh
            ``TextEncoder`` built with ``download_weights=False``.
        unet_weights: ``None`` to reuse the diffusion model of the KerasCV
            ``StableDiffusion`` model; anything else yields a fresh
            ``DiffusionModel`` built with ``download_weights=False``.

    Returns:
        A ``(sd_model, text_encoder, image_encoder, diffusion_model)`` tuple.
    """
    sd_model = keras_cv.models.StableDiffusion(
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH
    )

    # A fresh model is created without downloading weights when custom
    # weights were requested; the caller is expected to load them afterwards.
    text_encoder = (
        keras_cv.models.stable_diffusion.TextEncoder(
            MAX_SEQ_LENGTH, download_weights=False
        )
        if text_encoder_weights is not None
        else sd_model.text_encoder
    )

    image_encoder = sd_model.image_encoder

    diffusion_model = (
        keras_cv.models.stable_diffusion.DiffusionModel(
            IMG_HEIGHT, IMG_WIDTH, MAX_SEQ_LENGTH, download_weights=False
        )
        if unet_weights is not None
        else sd_model.diffusion_model
    )

    return sd_model, text_encoder, image_encoder, diffusion_model
def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
    """Convert KerasCV Stable Diffusion weights into a diffusers pipeline.

    The KerasCV text encoder and diffusion model (UNet) are converted to
    state dicts and loaded into their diffusers counterparts. The VAE and
    safety checker are taken from the diffusers checkpoint unchanged.

    Args:
        text_encoder_weights: Location of fine-tuned KerasCV text encoder
            weights, passed to ``tf.keras.utils.get_file`` as ``origin``.
            ``None`` converts the pre-trained KerasCV text encoder instead.
        unet_weights: Location of fine-tuned KerasCV diffusion model weights,
            handled the same way. ``None`` converts the pre-trained KerasCV
            diffusion model instead.

    Returns:
        A ``StableDiffusionPipeline`` created from ``PRETRAINED_CKPT`` whose
        UNet and text encoder carry the converted parameters.
    """
    pt_text_encoder, pt_vae, pt_unet, pt_safety_checker = initialize_pt_models()
    # Only the text encoder and diffusion model are converted; the full
    # KerasCV model and its image encoder are not used here.
    _, tf_text_encoder, _, tf_unet = initialize_tf_models(
        text_encoder_weights, unet_weights
    )
    print("Pre-trained model weights downloaded.")

    if text_encoder_weights is not None:
        print("Loading fine-tuned text encoder weights.")
        text_encoder_weights_path = tf.keras.utils.get_file(
            origin=text_encoder_weights
        )
        tf_text_encoder.load_weights(text_encoder_weights_path)
    if unet_weights is not None:
        print("Loading fine-tuned UNet weights.")
        unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
        tf_unet.load_weights(unet_weights_path)

    text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
    unet_state_dict_from_tf = populate_unet(tf_unet)
    print("Conversion done, now running optional assertions...")

    # Fine-tuned weights have no diffusers reference to compare against, so
    # each assertion only runs for a component that kept its pre-trained
    # weights. NOTE(review): run_assertion comes from conversion_utils and
    # presumably checks that the two state dicts match -- confirm there.
    if text_encoder_weights is None:
        text_encoder_state_dict_from_pt = pt_text_encoder.state_dict()
        run_assertion(
            text_encoder_state_dict_from_pt, text_encoder_state_dict_from_tf
        )
    if unet_weights is None:
        unet_state_dict_from_pt = pt_unet.state_dict()
        run_assertion(unet_state_dict_from_pt, unet_state_dict_from_tf)
    if text_encoder_weights is None or unet_weights is None:
        print(
            "Assertions successful, populating the converted parameters into the diffusers models..."
        )

    pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
    pt_unet.load_state_dict(unet_state_dict_from_tf)
    print("Parameters ported, preparing StableDiffusionPipeline...")

    # Use the module-level REVISION so the remaining pipeline components come
    # from the same revision as the text encoder and VAE loaded earlier.
    pipeline = StableDiffusionPipeline.from_pretrained(
        PRETRAINED_CKPT,
        unet=pt_unet,
        text_encoder=pt_text_encoder,
        vae=pt_vae,
        safety_checker=pt_safety_checker,
        revision=REVISION,
    )
    return pipeline
|