File size: 4,171 Bytes
ddc8a59
 
3304f7d
 
 
 
 
ddc8a59
3304f7d
 
ddc8a59
 
 
 
 
6c04d23
ddc8a59
3304f7d
ddc8a59
 
 
 
3304f7d
ddc8a59
 
 
 
 
 
 
 
 
 
 
 
 
3304f7d
6c04d23
ddc8a59
 
3304f7d
 
 
6c04d23
 
 
 
 
 
3304f7d
6c04d23
 
 
 
 
 
ddc8a59
 
 
 
 
30552b4
 
 
ddc8a59
3304f7d
ddc8a59
 
6c04d23
ddc8a59
 
 
6c04d23
ddc8a59
 
 
 
94913a9
ddc8a59
 
 
 
 
 
9435d99
ddc8a59
 
94913a9
 
 
 
ddc8a59
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from typing import Optional

import keras_cv
import tensorflow as tf
from diffusers import (AutoencoderKL, StableDiffusionPipeline,
                       UNet2DConditionModel)
from diffusers.pipelines.stable_diffusion.safety_checker import \
    StableDiffusionSafetyChecker
from transformers import CLIPTextModel

from conversion_utils import (populate_text_encoder, populate_unet,
                              run_assertion)

# Checkpoint identifier handed to every diffusers ``from_pretrained`` call below.
PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"

# Values forwarded as ``revision=`` when loading the diffusers components.
REVISION = None
NON_EMA_REVISION = None

# Image resolution used to build the KerasCV models.
IMG_HEIGHT = 512
IMG_WIDTH = 512

# Sequence length passed to the KerasCV text encoder and diffusion model.
MAX_SEQ_LENGTH = 77


def initialize_pt_models():
    """Load the diffusers-side Stable Diffusion components from ``PRETRAINED_CKPT``.

    Each component is read from its own subfolder of the checkpoint via
    ``from_pretrained``.

    Returns:
        A 4-tuple ``(text_encoder, vae, unet, safety_checker)``.
    """
    # (model class, checkpoint subfolder, revision) in the order they are returned.
    # NOTE(review): the safety checker is loaded with NON_EMA_REVISION rather
    # than REVISION -- confirm this is intentional.
    component_specs = (
        (CLIPTextModel, "text_encoder", REVISION),
        (AutoencoderKL, "vae", REVISION),
        (UNet2DConditionModel, "unet", NON_EMA_REVISION),
        (StableDiffusionSafetyChecker, "safety_checker", NON_EMA_REVISION),
    )
    loaded = [
        model_cls.from_pretrained(
            PRETRAINED_CKPT, subfolder=subfolder, revision=revision
        )
        for model_cls, subfolder, revision in component_specs
    ]
    return tuple(loaded)


def initialize_tf_models(
    text_encoder_weights: Optional[str], unet_weights: Optional[str]
):
    """Build the KerasCV Stable Diffusion model and its individual components.

    The two arguments are only inspected for ``None`` here; no fine-tuned
    weights are loaded by this function.

    Args:
        text_encoder_weights: When ``None``, the text encoder is taken from the
            KerasCV ``StableDiffusion`` model. Otherwise a fresh ``TextEncoder``
            is created with ``download_weights=False`` so the caller can load
            its own weights into it.
        unet_weights: When ``None``, the diffusion model is taken from the
            KerasCV ``StableDiffusion`` model. Otherwise a fresh
            ``DiffusionModel`` is created with ``download_weights=False``.

    Returns:
        A 4-tuple ``(tf_sd_model, tf_text_encoder, tf_vae, tf_unet)``, where
        ``tf_vae`` is the ``image_encoder`` of ``tf_sd_model``.
    """
    tf_sd_model = keras_cv.models.StableDiffusion(
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH
    )

    # Components are resolved in the same order as they are returned.
    tf_text_encoder = (
        tf_sd_model.text_encoder
        if text_encoder_weights is None
        else keras_cv.models.stable_diffusion.TextEncoder(
            MAX_SEQ_LENGTH, download_weights=False
        )
    )
    tf_vae = tf_sd_model.image_encoder
    tf_unet = (
        tf_sd_model.diffusion_model
        if unet_weights is None
        else keras_cv.models.stable_diffusion.DiffusionModel(
            IMG_HEIGHT, IMG_WIDTH, MAX_SEQ_LENGTH, download_weights=False
        )
    )
    return tf_sd_model, tf_text_encoder, tf_vae, tf_unet


def run_conversion(
    text_encoder_weights: Optional[str] = None, unet_weights: Optional[str] = None
):
    """Port KerasCV Stable Diffusion parameters into a diffusers pipeline.

    The KerasCV text encoder and UNet are converted to state dicts, loaded
    into their diffusers counterparts, and assembled into a
    ``StableDiffusionPipeline`` together with the VAE and safety checker
    loaded from ``PRETRAINED_CKPT``.

    Args:
        text_encoder_weights: Origin of fine-tuned text encoder weights, passed
            to ``tf.keras.utils.get_file``. ``None`` converts the text encoder
            of the KerasCV ``StableDiffusion`` model instead.
        unet_weights: Origin of fine-tuned UNet weights, passed to
            ``tf.keras.utils.get_file``. ``None`` converts the diffusion model
            of the KerasCV ``StableDiffusion`` model instead.

    Returns:
        The ``StableDiffusionPipeline`` holding the converted text encoder and
        UNet.
    """
    pt_text_encoder, pt_vae, pt_unet, pt_safety_checker = initialize_pt_models()
    # Only the text encoder and UNet are converted; the other KerasCV objects
    # returned by the initializer are not needed here.
    _, tf_text_encoder, _, tf_unet = initialize_tf_models(
        text_encoder_weights, unet_weights
    )
    print("Pre-trained model weights downloaded.")

    if text_encoder_weights is not None:
        print("Loading fine-tuned text encoder weights.")
        text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
        tf_text_encoder.load_weights(text_encoder_weights_path)
    if unet_weights is not None:
        print("Loading fine-tuned UNet weights.")
        unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
        tf_unet.load_weights(unet_weights_path)

    text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
    unet_state_dict_from_tf = populate_unet(tf_unet)
    print("Conversion done, now running optional assertions...")

    # Fine-tuned weights have no diffusers reference to compare against, so a
    # component is only checked when it was not replaced by fine-tuned weights.
    if text_encoder_weights is None:
        text_encoder_state_dict_from_pt = pt_text_encoder.state_dict()
        run_assertion(text_encoder_state_dict_from_pt, text_encoder_state_dict_from_tf)
    if unet_weights is None:
        unet_state_dict_from_pt = pt_unet.state_dict()
        run_assertion(unet_state_dict_from_pt, unet_state_dict_from_tf)

    if text_encoder_weights is None or unet_weights is None:
        print(
            "Assertions successful, populating the converted parameters into the diffusers models..."
        )
    pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
    pt_unet.load_state_dict(unet_state_dict_from_tf)

    print("Parameters ported, preparing StableDiffusionPipeline...")
    pipeline = StableDiffusionPipeline.from_pretrained(
        PRETRAINED_CKPT,
        unet=pt_unet,
        text_encoder=pt_text_encoder,
        vae=pt_vae,
        safety_checker=pt_safety_checker,
        # Use the shared constant so the pipeline follows the same revision
        # setting as the text encoder and VAE loaded with REVISION.
        revision=REVISION,
    )
    return pipeline