diffusers-sdxl-controlnet / scripts /convert_unclip_txt2img_to_image_variation.py
svjack's picture
Upload 1392 files
43b7e92 verified
raw
history blame
1.31 kB
import argparse
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers import UnCLIPImageVariationPipeline, UnCLIPPipeline
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the conversion script.

    Args:
        argv: List of argument strings to parse. ``None`` (the default) lets
            argparse read the arguments from ``sys.argv``.

    Returns:
        Namespace with the attributes ``dump_path``, ``txt2img_unclip`` and
        ``image_encoder``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--txt2img_unclip",
        default="kakaobrain/karlo-v1-alpha",
        type=str,
        required=False,
        help="The pretrained txt2img unclip.",
    )
    # Previously hard-coded in the script body; the default keeps the old behavior.
    parser.add_argument(
        "--image_encoder",
        default="openai/clip-vit-large-patch14",
        type=str,
        required=False,
        help="The pretrained CLIP vision model (with projection) used as the image encoder.",
    )

    return parser.parse_args(argv)


def main(argv=None) -> None:
    """Assemble an UnCLIP image-variation pipeline and save it to disk.

    Loads the txt2img UnCLIP pipeline named by ``--txt2img_unclip``, reuses its
    decoder, text encoder, tokenizer, text projection, super-resolution models
    and schedulers, adds an image encoder plus feature extractor, and writes
    the resulting ``UnCLIPImageVariationPipeline`` to ``--dump_path``.

    Args:
        argv: List of argument strings forwarded to :func:`parse_args`.
            ``None`` (the default) uses the process command line.
    """
    args = parse_args(argv)

    txt2img = UnCLIPPipeline.from_pretrained(args.txt2img_unclip)

    # The processor is built with no arguments (not loaded from a checkpoint),
    # so it carries whatever defaults the installed transformers version ships.
    feature_extractor = CLIPImageProcessor()
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder)

    # Only the components listed here are taken over from the txt2img pipeline;
    # the image encoder and feature extractor are the newly created ones.
    img2img = UnCLIPImageVariationPipeline(
        decoder=txt2img.decoder,
        text_encoder=txt2img.text_encoder,
        tokenizer=txt2img.tokenizer,
        text_proj=txt2img.text_proj,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
        super_res_first=txt2img.super_res_first,
        super_res_last=txt2img.super_res_last,
        decoder_scheduler=txt2img.decoder_scheduler,
        super_res_scheduler=txt2img.super_res_scheduler,
    )

    img2img.save_pretrained(args.dump_path)


if __name__ == "__main__":
    main()