Spaces:
Build error
Build error
improve runtime.
Browse files- app.py +1 -1
- convert.py +16 -7
app.py
CHANGED
@@ -27,7 +27,7 @@ def run(hf_token, text_encoder_weights, unet_weights, repo_prefix):
|
|
27 |
text_encoder_weights = None
|
28 |
if unet_weights == "":
|
29 |
unet_weights = None
|
30 |
-
|
31 |
pipeline = run_conversion(text_encoder_weights, unet_weights)
|
32 |
output_path = "kerascv_sd_diffusers_pipeline"
|
33 |
pipeline.save_pretrained(output_path)
|
|
|
27 |
text_encoder_weights = None
|
28 |
if unet_weights == "":
|
29 |
unet_weights = None
|
30 |
+
|
31 |
pipeline = run_conversion(text_encoder_weights, unet_weights)
|
32 |
output_path = "kerascv_sd_diffusers_pipeline"
|
33 |
pipeline.save_pretrained(output_path)
|
convert.py
CHANGED
@@ -13,6 +13,7 @@ PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
|
|
13 |
REVISION = None
|
14 |
NON_EMA_REVISION = None
|
15 |
IMG_HEIGHT = IMG_WIDTH = 512
|
|
|
16 |
|
17 |
|
18 |
def initialize_pt_models():
|
@@ -34,17 +35,25 @@ def initialize_pt_models():
|
|
34 |
return pt_text_encoder, pt_vae, pt_unet, pt_safety_checker
|
35 |
|
36 |
|
37 |
-
def initialize_tf_models():
|
38 |
"""Initializes the separate models of Stable Diffusion from KerasCV and downloads
|
39 |
their pre-trained weights."""
|
40 |
tf_sd_model = keras_cv.models.StableDiffusion(
|
41 |
img_height=IMG_HEIGHT, img_width=IMG_WIDTH
|
42 |
)
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
46 |
tf_vae = tf_sd_model.image_encoder
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
48 |
return tf_sd_model, tf_text_encoder, tf_vae, tf_unet
|
49 |
|
50 |
|
@@ -55,11 +64,11 @@ def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
|
|
55 |
|
56 |
if text_encoder_weights is not None:
|
57 |
print("Loading fine-tuned text encoder weights.")
|
58 |
-
text_encoder_weights_path = tf.keras.utils.get_file(text_encoder_weights)
|
59 |
tf_text_encoder.load_weights(text_encoder_weights_path)
|
60 |
if unet_weights is not None:
|
61 |
print("Loading fine-tuned UNet weights.")
|
62 |
-
unet_weights_path = tf.keras.utils.get_file(unet_weights)
|
63 |
tf_unet.load_weights(unet_weights_path)
|
64 |
|
65 |
text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
|
|
|
13 |
REVISION = None
|
14 |
NON_EMA_REVISION = None
|
15 |
IMG_HEIGHT = IMG_WIDTH = 512
|
16 |
+
MAX_SEQ_LENGTH = 77
|
17 |
|
18 |
|
19 |
def initialize_pt_models():
|
|
|
35 |
return pt_text_encoder, pt_vae, pt_unet, pt_safety_checker
|
36 |
|
37 |
|
38 |
+
def initialize_tf_models(text_encoder_weights: str, unet_weights: str):
|
39 |
"""Initializes the separate models of Stable Diffusion from KerasCV and downloads
|
40 |
their pre-trained weights."""
|
41 |
tf_sd_model = keras_cv.models.StableDiffusion(
|
42 |
img_height=IMG_HEIGHT, img_width=IMG_WIDTH
|
43 |
)
|
44 |
+
if text_encoder_weights is None:
|
45 |
+
tf_text_encoder = tf_sd_model.text_encoder
|
46 |
+
else:
|
47 |
+
tf_text_encoder = keras_cv.models.stable_diffusion.TextEncoder(
|
48 |
+
MAX_SEQ_LENGTH, download_weights=False
|
49 |
+
)
|
50 |
tf_vae = tf_sd_model.image_encoder
|
51 |
+
if unet_weights is None:
|
52 |
+
tf_unet = tf_sd_model.diffusion_model
|
53 |
+
else:
|
54 |
+
tf_unet = keras_cv.models.stable_diffusion.DiffusionModel(
|
55 |
+
IMG_HEIGHT, IMG_WIDTH, MAX_SEQ_LENGTH, download_weights=False
|
56 |
+
)
|
57 |
return tf_sd_model, tf_text_encoder, tf_vae, tf_unet
|
58 |
|
59 |
|
|
|
64 |
|
65 |
if text_encoder_weights is not None:
|
66 |
print("Loading fine-tuned text encoder weights.")
|
67 |
+
text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
|
68 |
tf_text_encoder.load_weights(text_encoder_weights_path)
|
69 |
if unet_weights is not None:
|
70 |
print("Loading fine-tuned UNet weights.")
|
71 |
+
unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
|
72 |
tf_unet.load_weights(unet_weights_path)
|
73 |
|
74 |
text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
|