File size: 5,248 Bytes
ddc8a59
 
3304f7d
 
 
 
89a6b3b
ddc8a59
7081a39
ddc8a59
 
 
 
 
6c04d23
ddc8a59
3304f7d
ddc8a59
 
 
 
3304f7d
ddc8a59
89a6b3b
ddc8a59
 
 
 
 
 
 
 
 
 
89a6b3b
ddc8a59
3304f7d
3bd4a93
 
 
 
 
3304f7d
 
 
3bd4a93
6c04d23
 
 
 
 
 
3bd4a93
6c04d23
 
 
 
 
 
3bd4a93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddc8a59
 
89a6b3b
 
 
 
 
 
 
 
 
 
 
 
3bd4a93
 
30552b4
ddc8a59
3304f7d
3bd4a93
 
 
 
 
 
 
 
 
 
 
 
ddc8a59
 
6c04d23
ddc8a59
7081a39
 
 
89a6b3b
ddc8a59
 
6c04d23
ddc8a59
7081a39
 
89a6b3b
 
7081a39
ddc8a59
 
 
 
89a6b3b
ddc8a59
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import keras_cv
import tensorflow as tf
from diffusers import (AutoencoderKL, StableDiffusionPipeline,
                       UNet2DConditionModel)
from diffusers.pipelines.stable_diffusion.safety_checker import \
    StableDiffusionSafetyChecker
from transformers import CLIPTextModel, CLIPTokenizer

from conversion_utils import populate_text_encoder, populate_unet

# Checkpoint identifier passed to every `from_pretrained` call in this file.
PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
# Values forwarded as `revision=` when loading the PyTorch components;
# NON_EMA_REVISION is the one used for the UNet.
REVISION = None
NON_EMA_REVISION = None
# Image resolution handed to the KerasCV Stable Diffusion models.
IMG_HEIGHT = IMG_WIDTH = 512
# Sequence length handed to the KerasCV text encoder and diffusion model.
MAX_SEQ_LENGTH = 77


def initialize_pt_models():
    """Load the individual PyTorch components of Stable Diffusion.

    Every component is fetched from ``PRETRAINED_CKPT`` via ``from_pretrained``.
    The UNet is loaded at ``NON_EMA_REVISION``; all other components
    (text encoder, tokenizer, VAE, safety checker) are loaded at ``REVISION``
    so that they stay in sync with each other.

    Returns:
        A 5-tuple ``(text_encoder, tokenizer, vae, unet, safety_checker)``.
    """
    pt_text_encoder = CLIPTextModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="text_encoder", revision=REVISION
    )
    # The tokenizer must come from the same revision as the text encoder.
    pt_tokenizer = CLIPTokenizer.from_pretrained(
        PRETRAINED_CKPT, subfolder="tokenizer", revision=REVISION
    )
    pt_vae = AutoencoderKL.from_pretrained(
        PRETRAINED_CKPT, subfolder="vae", revision=REVISION
    )
    # Only the UNet is tied to the non-EMA revision.
    pt_unet = UNet2DConditionModel.from_pretrained(
        PRETRAINED_CKPT, subfolder="unet", revision=NON_EMA_REVISION
    )
    pt_safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        PRETRAINED_CKPT, subfolder="safety_checker", revision=REVISION
    )

    return pt_text_encoder, pt_tokenizer, pt_vae, pt_unet, pt_safety_checker


def initialize_tf_models(
    text_encoder_weights: str, unet_weights: str, placeholder_token: str = None
):
    """Build the KerasCV text encoder, diffusion model and tokenizer.

    For each of the text encoder and the UNet: when no custom weights are
    supplied, the sub-model of a ``keras_cv.models.StableDiffusion`` instance
    is reused; otherwise a bare model is constructed with
    ``download_weights=False`` so the caller can load its own weights into it.

    Args:
        text_encoder_weights: Only checked against ``None`` here.
        unet_weights: Only checked against ``None`` here.
        placeholder_token: If given, it is added to the KerasCV tokenizer.

    Returns:
        A 3-tuple ``(text_encoder, unet, tokenizer)``.
    """
    sd_model = keras_cv.models.StableDiffusion(
        img_height=IMG_HEIGHT, img_width=IMG_WIDTH
    )

    # Text encoder: bare model for custom weights, bundled one otherwise.
    if text_encoder_weights is not None:
        text_encoder = keras_cv.models.stable_diffusion.TextEncoder(
            MAX_SEQ_LENGTH, download_weights=False
        )
    else:
        text_encoder = sd_model.text_encoder

    # UNet: same selection rule as the text encoder.
    if unet_weights is not None:
        unet = keras_cv.models.stable_diffusion.DiffusionModel(
            IMG_HEIGHT, IMG_WIDTH, MAX_SEQ_LENGTH, download_weights=False
        )
    else:
        unet = sd_model.diffusion_model

    tokenizer = sd_model.tokenizer
    if placeholder_token is not None:
        tokenizer.add_tokens(placeholder_token)

    return text_encoder, unet, tokenizer


def create_new_text_encoder(tf_text_encoder, tf_tokenizer):
    """Build a text encoder sized for the tokenizer's current vocabulary.

    Used when the weights come from Textual Inversion. The new encoder is
    created without downloading weights; only the position-embedding weights
    are carried over from ``tf_text_encoder``.

    Reference: https://keras.io/examples/generative/fine_tune_via_textual_inversion/
    """
    vocab_size = len(tf_tokenizer.vocab)
    fresh_encoder = keras_cv.models.stable_diffusion.TextEncoder(
        MAX_SEQ_LENGTH,
        vocab_size=vocab_size,
        download_weights=False,
    )

    # Layer index 2 is the layer exposing `position_embedding` in both encoders.
    source_embedding = tf_text_encoder.layers[2].position_embedding
    target_embedding = fresh_encoder.layers[2].position_embedding
    target_embedding.set_weights(source_embedding.get_weights())

    return fresh_encoder


def run_conversion(
    text_encoder_weights: str = None,
    unet_weights: str = None,
    placeholder_token: str = None,
):
    """Port fine-tuned KerasCV Stable Diffusion weights into a diffusers pipeline.

    Args:
        text_encoder_weights: Origin handed to ``tf.keras.utils.get_file`` for
            the fine-tuned text encoder weights. ``None`` keeps the pre-trained
            PyTorch text encoder.
        unet_weights: Origin handed to ``tf.keras.utils.get_file`` for the
            fine-tuned UNet weights. ``None`` keeps the pre-trained PyTorch UNet.
        placeholder_token: Extra token (e.g. from Textual Inversion) to add to
            both the KerasCV and the PyTorch tokenizers.

    Returns:
        A ``StableDiffusionPipeline`` assembled from the (possibly updated)
        PyTorch components.

    Raises:
        ValueError: If ``placeholder_token`` is already present in the PyTorch
            tokenizer.
    """
    (
        pt_text_encoder,
        pt_tokenizer,
        pt_vae,
        pt_unet,
        pt_safety_checker,
    ) = initialize_pt_models()
    tf_text_encoder, tf_unet, tf_tokenizer = initialize_tf_models(
        text_encoder_weights, unet_weights, placeholder_token
    )
    print("Pre-trained model weights downloaded.")

    if placeholder_token is not None:
        print("Initializing a new text encoder with the placeholder token...")
        tf_text_encoder = create_new_text_encoder(tf_text_encoder, tf_tokenizer)

        print("Adding the placeholder token to PT CLIPTokenizer...")
        num_added_tokens = pt_tokenizer.add_tokens(placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )

        # The tokenizer vocabulary just grew, so the PT token-embedding matrix
        # must grow with it. Otherwise the state dict produced from the enlarged
        # TF text encoder cannot be loaded, and the new token id would index
        # past the end of the embedding table.
        pt_text_encoder.resize_token_embeddings(len(pt_tokenizer))

    if text_encoder_weights is not None:
        print("Loading fine-tuned text encoder weights.")
        text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
        tf_text_encoder.load_weights(text_encoder_weights_path)
        text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
        pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
        print("Populated PT text encoder from TF weights.")

    if unet_weights is not None:
        print("Loading fine-tuned UNet weights.")
        unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
        tf_unet.load_weights(unet_weights_path)
        unet_state_dict_from_tf = populate_unet(tf_unet)
        pt_unet.load_state_dict(unet_state_dict_from_tf)
        print("Populated PT UNet from TF weights.")

    print("Weights ported, preparing StableDiffusionPipeline...")
    # Explicitly passed components override the ones stored in the checkpoint.
    pipeline = StableDiffusionPipeline.from_pretrained(
        PRETRAINED_CKPT,
        unet=pt_unet,
        text_encoder=pt_text_encoder,
        tokenizer=pt_tokenizer,
        vae=pt_vae,
        safety_checker=pt_safety_checker,
        revision=REVISION,
    )
    return pipeline