#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import typing
from typing import Optional, Union

import torch
from PIL import Image
from torchvision import transforms  # type: ignore

from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.autoencoder_kl import (
    AutoencoderKL,
    AutoencoderKLOutput,
)
from diffusers.models.autoencoders.autoencoder_tiny import (
    AutoencoderTiny,
    AutoencoderTinyOutput,
)
from diffusers.models.autoencoders.vae import DecoderOutput


SupportedAutoencoder = Union[AutoencoderKL, AutoencoderTiny]


def load_vae_model(
    *,
    device: torch.device,
    model_name_or_path: str,
    revision: Optional[str],
    variant: Optional[str],
    # NOTE: pass subfolder="vae" if the model is a full Stable Diffusion checkpoint rather than a standalone VAE.
    subfolder: Optional[str],
    use_tiny_nn: bool,
) -> SupportedAutoencoder:
    if use_tiny_nn:
        # NOTE: The two scaling factors don't have to be equal.
        down_scale = 2
        up_scale = 2
        vae = AutoencoderTiny.from_pretrained(  # type: ignore
            model_name_or_path,
            subfolder=subfolder,
            revision=revision,
            variant=variant,
            downscaling_scaling_factor=down_scale,
            upsampling_scaling_factor=up_scale,
        )
        assert isinstance(vae, AutoencoderTiny)
    else:
        vae = AutoencoderKL.from_pretrained(  # type: ignore
            model_name_or_path,
            subfolder=subfolder,
            revision=revision,
            variant=variant,
        )
        assert isinstance(vae, AutoencoderKL)
    vae = vae.to(device)
    vae.eval()  # Set the model to inference mode.
    return vae


def pil_to_nchw(
    *,
    device: torch.device,
    image: Image.Image,
) -> torch.Tensor:
    """Convert an RGB PIL image to a (1, C, H, W) float tensor in [0, 1]."""
    assert image.mode == "RGB"
    # ToTensor yields a CHW tensor in [0, 1]; unsqueeze adds the batch dimension,
    # so the result is channels-first (NCHW), matching what the VAE expects.
    transform = transforms.ToTensor()
    nchw = transform(image).unsqueeze(0).to(device)  # type: ignore
    assert isinstance(nchw, torch.Tensor)
    return nchw


def nchw_to_pil(
    *,
    nchw: torch.Tensor,
) -> Image.Image:
    """Convert a single-image (1, C, H, W) tensor back to a PIL image."""
    assert nchw.shape[0] == 1
    chw = nchw.squeeze(0).cpu()
    return transforms.ToPILImage()(chw)  # type: ignore


def concatenate_images(
    *,
    left: Image.Image,
    right: Image.Image,
    vertical: bool = False,
) -> Image.Image:
    width1, height1 = left.size
    width2, height2 = right.size
    if vertical:
        total_height = height1 + height2
        max_width = max(width1, width2)
        new_image = Image.new("RGB", (max_width, total_height))
        new_image.paste(left, (0, 0))
        new_image.paste(right, (0, height1))
    else:
        total_width = width1 + width2
        max_height = max(height1, height2)
        new_image = Image.new("RGB", (total_width, max_height))
        new_image.paste(left, (0, 0))
        new_image.paste(right, (width1, 0))
    return new_image


def to_latent(
    *,
    rgb_nchw: torch.Tensor,
    vae: SupportedAutoencoder,
) -> torch.Tensor:
    # Map pixel values from [0, 1] to the [-1, 1] range the VAE expects.
    rgb_nchw = VaeImageProcessor.normalize(rgb_nchw)  # type: ignore
    encoding_nchw = vae.encode(typing.cast(torch.FloatTensor, rgb_nchw))
    if isinstance(encoding_nchw, AutoencoderKLOutput):
        latent = encoding_nchw.latent_dist.sample()  # type: ignore
        assert isinstance(latent, torch.Tensor)
    elif isinstance(encoding_nchw, AutoencoderTinyOutput):
        latent = encoding_nchw.latents
        do_internal_vae_scaling = False  # Disabled by default; see the note below.
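        # NOTE: AutoencoderTiny.scale_latents maps raw latents into [0, 1] and
        # unscale_latents inverts the mapping, so the optional block below
        # simulates an 8-bit quantization round-trip of the latents (store as
        # bytes, then reload).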
        if do_internal_vae_scaling:
            latent = vae.scale_latents(latent).mul(255).round().byte()  # type: ignore
            latent = vae.unscale_latents(latent / 255.0)  # type: ignore
        assert isinstance(latent, torch.Tensor)
    else:
        assert False, f"Unknown encoding type: {type(encoding_nchw)}"
    return latent


def from_latent(
    *,
    latent_nchw: torch.Tensor,
    vae: SupportedAutoencoder,
) -> torch.Tensor:
    decoding_nchw = vae.decode(latent_nchw)  # type: ignore
    assert isinstance(decoding_nchw, DecoderOutput)
    # Map the decoded sample from [-1, 1] back to [0, 1] pixel values.
    rgb_nchw = VaeImageProcessor.denormalize(decoding_nchw.sample)  # type: ignore
    assert isinstance(rgb_nchw, torch.Tensor)
    return rgb_nchw


def main_kwargs(
    *,
    device: torch.device,
    input_image_path: str,
    pretrained_model_name_or_path: str,
    revision: Optional[str],
    variant: Optional[str],
    subfolder: Optional[str],
    use_tiny_nn: bool,
) -> None:
    vae = load_vae_model(
        device=device,
        model_name_or_path=pretrained_model_name_or_path,
        revision=revision,
        variant=variant,
        subfolder=subfolder,
        use_tiny_nn=use_tiny_nn,
    )
    original_pil = Image.open(input_image_path).convert("RGB")
    original_image = pil_to_nchw(
        device=device,
        image=original_pil,
    )
    print(f"Original image shape: {original_image.shape}")
    reconstructed_image: Optional[torch.Tensor] = None
    with torch.no_grad():
        latent_image = to_latent(rgb_nchw=original_image, vae=vae)
        print(f"Latent shape: {latent_image.shape}")
        reconstructed_image = from_latent(latent_nchw=latent_image, vae=vae)
    assert reconstructed_image is not None
    reconstructed_pil = nchw_to_pil(nchw=reconstructed_image)
    combined_image = concatenate_images(
        left=original_pil,
        right=reconstructed_pil,
        vertical=False,
    )
    combined_image.show("Original | Reconstruction")
    print(f"Reconstructed image shape: {reconstructed_image.shape}")


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Round-trip an image through a VAE (encode to latents, then decode).")
    parser.add_argument(
        "--input_image",
        type=str,
        required=True,
        help="Path to the input image for inference.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        help="Path to the pretrained VAE model (or a Hub model id).",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="Model revision (branch, tag, or commit id).",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Model file variant, e.g. 'fp16'.",
    )
    parser.add_argument(
        "--subfolder",
        type=str,
        default=None,
        help="Subfolder within the model repo or directory.",
    )
    parser.add_argument(
        "--use_cuda",
        action="store_true",
        help="Use CUDA if available.",
    )
    parser.add_argument(
        "--use_tiny_nn",
        action="store_true",
        help="Use AutoencoderTiny (TAESD) instead of AutoencoderKL.",
    )
    return parser.parse_args()


# EXAMPLE USAGE:
#
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "runwayml/stable-diffusion-v1-5" --subfolder "vae" --input_image "foo.png"
#
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "madebyollin/taesd" --use_tiny_nn --input_image "foo.png"
#
def main_cli() -> None:
    args = parse_args()

    input_image_path = args.input_image
    assert isinstance(input_image_path, str)

    pretrained_model_name_or_path = args.pretrained_model_name_or_path
    assert isinstance(pretrained_model_name_or_path, str)

    revision = args.revision
    assert isinstance(revision, (str, type(None)))

    variant = args.variant
    assert isinstance(variant, (str, type(None)))

    subfolder = args.subfolder
    assert isinstance(subfolder, (str, type(None)))

    use_cuda = args.use_cuda
    assert isinstance(use_cuda, bool)

    use_tiny_nn = args.use_tiny_nn
    assert isinstance(use_tiny_nn, bool)

    # Fall back to CPU when CUDA is requested but unavailable, matching the help text.
    device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

    main_kwargs(
        device=device,
        input_image_path=input_image_path,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        revision=revision,
        variant=variant,
        subfolder=subfolder,
        use_tiny_nn=use_tiny_nn,
    )
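

# Programmatic usage (a minimal sketch; "foo.png" and the model id below are
# placeholders taken from the CLI examples above, not fixed defaults):
#
#   main_kwargs(
#       device=torch.device("cpu"),
#       input_image_path="foo.png",
#       pretrained_model_name_or_path="madebyollin/taesd",
#       revision=None,
#       variant=None,
#       subfolder=None,
#       use_tiny_nn=True,
#   )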


if __name__ == "__main__":
    main_cli()