"""Gradio demo for ZePo: zero-shot portrait stylization with faster sampling."""

# NOTE: the import order below is kept exactly as in the original file.
# `spaces` is imported before `torch`/`diffusers` there; presumably this is
# required by the hosting runtime -- TODO confirm before re-sorting imports.
import os
import spaces
import torch
import random
import numpy as np
import gradio as gr
from glob import glob
from datetime import datetime
from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers import DDIMScheduler, LCMScheduler, EulerDiscreteScheduler
import torch.nn.functional as F
from PIL import Image, ImageDraw
from utils.pipeline import ZePoPipeline
from utils.attn_control import AttentionStyle
from torchvision.utils import save_image
import utils.ptp_utils as ptp_utils
import torchvision.transforms as transforms

try:
    import xformers
    is_xformers = True
except ImportError:
    is_xformers = False

# NOTE(review): `margin-buttom` is not a valid CSS property, so that rule is
# currently ignored by browsers. Left untouched to keep rendering unchanged.
css = """
.toolbutton {
    margin-buttom: 0em 0em 0em 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""

# import sys
# sys.setrecursionlimit(100000)


class GlobalText:
    """Application state shared by all Gradio callbacks.

    Holds the output directories, the currently loaded pipeline and the
    file lists backing the content / style / results galleries.
    """

    # Number of thumbnails shown per gallery page.
    gallery_page_size = 12

    def __init__(self):
        # config dirs
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
        self.personalized_model_dir = './models/Stable-diffusion'
        self.lora_model_dir = './models/Lora'
        self.savedir = os.path.join(
            self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")
        )
        self.savedir_sample = os.path.join(self.savedir, "sample")
        # self.savedir_mask = os.path.join(self.savedir, "mask")

        self.stable_diffusion_list = ["SimianLuo/LCM_Dreamshaper_v7"]
        self.personalized_model_list = []
        self.lora_model_list = []

        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.pipeline = None
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.lora_model_state_dict = {}
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        # Placeholder shown for flagged results; it is converted to a float
        # array in [0, 1] when substituted inside generate().
        self.nsfw_image = Image.open('./data/nsfw.jpg')

        # Gallery state. Initialised here so the gallery index sliders can be
        # moved before the corresponding refresh button was ever clicked.
        self.source_paths = []
        self.max_source_index = 0
        self.style_paths = []
        self.max_style_index = 0
        self.results_paths = []
        self.max_results_index = 0

        # Attributes that are otherwise first assigned by later method calls.
        self.scheduler = None
        self.sample_count = 0
        self.personal_model_loaded = None

    @classmethod
    def _last_page_index(cls, count):
        """Return the index of the last non-empty gallery page for *count* items.

        Returns 0 when there are no items at all.
        """
        return max(0, (count - 1) // cls.gallery_page_size)

    def init_source_image_path(self, source_path):
        """Scan *source_path* for content images and return the first gallery page.

        Args:
            source_path: Directory whose entries (``*`` glob) are listed.

        Returns:
            The first ``gallery_page_size`` paths, sorted by name.
        """
        self.source_paths = sorted(glob(os.path.join(source_path, '*')))
        self.max_source_index = self._last_page_index(len(self.source_paths))
        return self.source_paths[0:self.gallery_page_size]

    def init_style_image_path(self, style_path):
        """Scan *style_path* for style images and return the first gallery page.

        Args:
            style_path: Directory whose entries (``*`` glob) are listed.

        Returns:
            The first ``gallery_page_size`` paths, sorted by name.
        """
        self.style_paths = sorted(glob(os.path.join(style_path, '*')))
        self.max_style_index = self._last_page_index(len(self.style_paths))
        return self.style_paths[0:self.gallery_page_size]

    def init_results_image_path(self):
        """List saved samples, newest first, and return the first gallery page.

        Reads ``self.savedir_sample``; that directory must already exist
        (generate() creates it before calling this method).
        """
        results_paths = [
            os.path.join(self.savedir_sample, file)
            for file in os.listdir(self.savedir_sample)
        ]
        self.results_paths = sorted(results_paths, key=os.path.getctime, reverse=True)
        self.max_results_index = self._last_page_index(len(self.results_paths))
        return self.results_paths[0:self.gallery_page_size]

    @spaces.GPU
    def load_base_pipeline(self, model_path):
        """Load a ZePo pipeline with an LCM scheduler from *model_path*.

        Args:
            model_path: Identifier or path handed to ``from_pretrained``.

        Returns:
            An empty ``gr.Dropdown()`` so the method can be used directly as
            a Gradio event handler whose output is the model dropdown.
        """
        time_start = datetime.now()
        self.scheduler = 'LCM'
        scheduler = LCMScheduler.from_pretrained(model_path, subfolder="scheduler")
        # NOTE(review): dtype and device are hard-coded here although
        # self.torch_dtype / self.device exist; kept as-is because the
        # interaction with the @spaces.GPU runtime is not visible from here.
        self.pipeline = ZePoPipeline.from_pretrained(
            model_path,
            scheduler=scheduler,
            torch_dtype=torch.float16,
        ).to('cuda')
        if is_xformers:
            self.pipeline.enable_xformers_memory_efficient_attention()
        time_end = datetime.now()
        print(f'Load {model_path} successful in {time_end-time_start}')
        return gr.Dropdown()

    def refresh_stable_diffusion(self, model_path):
        """(Re)load the base pipeline and return the default model name."""
        self.load_base_pipeline(model_path)
        return self.stable_diffusion_list[0]

    def update_base_model(self, base_model_dropdown):
        """Swap VAE / UNet / text encoder of the loaded pipeline for a checkpoint.

        Args:
            base_model_dropdown: Key of the personalized model to load.

        Returns:
            ``None`` when no pipeline is loaded yet, otherwise ``gr.Dropdown()``.
        """
        if self.pipeline is None:
            gr.Info("Please select a pretrained model path.")
            return None

        # NOTE(review): the argument is used both to index
        # personalized_model_list (a list here) and as a string (.split).
        # No caller is visible in this file, so the logic is left unchanged --
        # confirm the intended type before wiring this to the UI.
        base_model = self.personalized_model_list[base_model_dropdown]
        mid_model = StableDiffusionPipeline.from_single_file(base_model)
        self.pipeline.vae = mid_model.vae
        self.pipeline.unet = mid_model.unet
        self.pipeline.text_encoder = mid_model.text_encoder
        self.pipeline.to(self.device)
        self.personal_model_loaded = base_model_dropdown.split('.')[0]
        print(f'load {base_model_dropdown} model success!')
        return gr.Dropdown()

    @spaces.GPU
    def generate(self, source, style, num_steps, co_feat_step, strength,
                 start_ac_layer, end_ac_layer,
                 sty_guidance, cfg_scale, mix_q_scale,
                 Scheduler, save_intermediate, seed, de_bug,
                 target_prompt, negative_prompt_textbox,
                 width_slider, height_slider,
                 tome_sx, tome_sy, tome_ratio, tome,
                 ):
        """Run the pipeline on a source / style pair and save the results.

        Args:
            source: PIL image providing the content.
            style: PIL image providing the style.
            num_steps: Number of inference steps.
            co_feat_step: Passed to the pipeline as ``fix_step_index``.
            strength: Passed to the pipeline as ``strength``.
            start_ac_layer, end_ac_layer: Layer range for ``AttentionStyle``.
            sty_guidance: ``style_guidance`` of ``AttentionStyle``.
            cfg_scale: Passed to the pipeline as ``guidance_scale``.
            mix_q_scale: ``mix_q_scale`` of ``AttentionStyle``.
            Scheduler: One of ``'DDIM'``, ``'LCM'``, ``'EulerDiscrete'``.
            save_intermediate: Forwarded to the pipeline.
            seed: Seed as text; ``'-1'`` or empty draws a random seed.
            de_bug: Debug flag forwarded to controller and pipeline.
            target_prompt: Prompt, replicated three times for the batch.
            negative_prompt_textbox: Negative prompt.
            width_slider, height_slider: Working resolution in pixels.
            tome_sx, tome_sy, tome_ratio, tome: Forwarded to
                ``ptp_utils.register_attention_control``.

        Returns:
            A list with the three images of the result batch, in batch order.

        Raises:
            gr.Error: If an input image is missing or the seed is not an integer.
        """
        if source is None or style is None:
            raise gr.Error("Please provide both a source image and a style image.")

        os.makedirs(self.savedir, exist_ok=True)
        os.makedirs(self.savedir_sample, exist_ok=True)

        if self.pipeline is None:
            self.refresh_stable_diffusion(self.stable_diffusion_list[-1])
        model = self.pipeline

        if Scheduler == 'DDIM':
            model.scheduler = DDIMScheduler.from_config(model.scheduler.config)
            print("Successful adoption of DDIM scheduler")
        elif Scheduler == 'LCM':
            model.scheduler = LCMScheduler.from_config(model.scheduler.config)
            print("Successful adoption of LCM scheduler")
        elif Scheduler == 'EulerDiscrete':
            model.scheduler = EulerDiscreteScheduler.from_config(model.scheduler.config)

        # The seed arrives as text from the UI; '-1' / empty means "random".
        seed_text = "" if seed is None else str(seed).strip()
        if seed_text not in ("", "-1"):
            try:
                torch.manual_seed(int(seed_text))
            except ValueError as err:
                raise gr.Error("Seed must be an integer, or -1 for a random seed.") from err
        else:
            torch.seed()
        seed = torch.initial_seed()
        print(f"Seed: {seed}")

        self.sample_count = len(os.listdir(self.savedir_sample))

        prompts = [target_prompt] * 3

        # PIL requires integer sizes; slider values may arrive as floats.
        width, height = int(width_slider), int(height_slider)
        source = source.resize((width, height))
        style = style.resize((width, height))

        with torch.no_grad():
            controller = AttentionStyle(
                num_steps,
                start_ac_layer,
                end_ac_layer,
                style_guidance=sty_guidance,
                mix_q_scale=mix_q_scale,
                de_bug=de_bug,
            )
            ptp_utils.register_attention_control(
                model, controller,
                tome,
                sx=tome_sx,
                sy=tome_sy,
                ratio=tome_ratio,
                de_bug=de_bug,
            )

            time_begin = datetime.now()
            results = model(
                prompt=prompts,
                negative_prompt=negative_prompt_textbox,
                image=source,
                style=style,
                num_inference_steps=num_steps,
                eta=0.0,
                guidance_scale=cfg_scale,
                strength=strength,
                save_intermediate=save_intermediate,
                fix_step_index=co_feat_step,
                de_bug=de_bug,
                callback=None,
            )
            generate_image = results.images

            # Replace flagged outputs with the placeholder image. The flag
            # list is guarded because it may be None when no check was run
            # -- TODO confirm against ZePoPipeline's output type.
            for idx, has_nsfw_concept in enumerate(results.nsfw_content_detected or []):
                if has_nsfw_concept:
                    # PIL's resize takes (width, height).
                    placeholder = self.nsfw_image.resize((width, height))
                    generate_image[idx] = np.array(placeholder).astype(np.float32) / 255.0
            time_end = datetime.now()
            print('generate one image with time {}'.format(time_end-time_begin))

        save_file_name = f"{self.sample_count}_step{num_steps}_sl{start_ac_layer}_el{end_ac_layer}_ST{strength}_CF{co_feat_step}_STG{sty_guidance}_MQ{mix_q_scale}_CFG{cfg_scale}_seed{seed}.jpg"
        save_file_path = os.path.join(self.savedir, save_file_name)

        # Full triplet goes to savedir; only the last image goes to the
        # sample directory that backs the results gallery.
        save_image(torch.tensor(generate_image).permute(0, 3, 1, 2),
                   save_file_path, nrow=3, padding=0)
        save_image(torch.tensor(generate_image[2:]).permute(0, 3, 1, 2),
                   os.path.join(self.savedir_sample, save_file_name), nrow=3, padding=0)
        self.init_results_image_path()

        return [
            generate_image[0],
            generate_image[1],
            generate_image[2],
        ]


global_text = GlobalText()


def ui():
    """Build the Gradio Blocks demo and wire its events to ``global_text``.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from the statement order of a whitespace-collapsed
    # source; verify the visual layout against the intended design.
    num_gallery_images = GlobalText.gallery_page_size

    def gallery_page(paths, last_page, index):
        """Return the page of *paths* at *index*, clamped to [0, last_page]."""
        page = min(max(int(index), 0), last_page)
        return paths[page * num_gallery_images:(page + 1) * num_gallery_images]

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # [ZePo: Zero-Shot Portrait Stylization with Faster Sampling](https://arxiv.org/abs/2408.05492)
            Jin Liu, Huaibo Huang, Jie Cao, Ran He

            [Arxiv](https://arxiv.org/abs/2408.05492) | [Github](https://github.com/liujin112/ZePo)
            """
        )
        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 1. Select a pretrained model.
                """
            )
            with gr.Row():
                stable_diffusion_dropdown = gr.Dropdown(
                    label="Pretrained Model Path",
                    choices=global_text.stable_diffusion_list,
                    interactive=True,
                    allow_custom_value=True,
                )
                stable_diffusion_dropdown.change(
                    fn=global_text.load_base_pipeline,
                    inputs=[stable_diffusion_dropdown],
                    outputs=[stable_diffusion_dropdown],
                )

                stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

                def update_stable_diffusion(stable_diffusion_dropdown):
                    """Reload the selected model and keep the dropdown selection."""
                    if not stable_diffusion_dropdown:
                        # Nothing selected yet: select the default entry; the
                        # dropdown's change listener then loads it.
                        return global_text.stable_diffusion_list[0]
                    global_text.refresh_stable_diffusion(stable_diffusion_dropdown)
                    # Return the value so the output dropdown is not cleared.
                    return stable_diffusion_dropdown

                stable_diffusion_refresh_button.click(
                    fn=update_stable_diffusion,
                    inputs=[stable_diffusion_dropdown],
                    outputs=[stable_diffusion_dropdown],
                )

        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 2. Configs for ZePo.
                """
            )
            with gr.Tab("Configs"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            source_image = gr.Image(label="Source Image", elem_id="img2maskimg", sources="upload", type="pil", image_mode="RGB", height=256)
                            style_image = gr.Image(label="Style Image", elem_id="img2maskimg", sources="upload", type="pil", image_mode="RGB", height=256)
                        generate_image = gr.Image(label="Image with PortraitDiff", type="pil", interactive=True, image_mode="RGB", height=512)
                        with gr.Row():
                            recons_content = gr.Image(label="reconstructed content", type="pil", image_mode="RGB", height=256)
                            recons_style = gr.Image(label="reconstructed style", type="pil", image_mode="RGB", height=256)
                        prompt_textbox = gr.Textbox(label="Prompt", value='head', lines=1)
                        negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=1)

                    with gr.Row(equal_height=False):
                        with gr.Column():
                            with gr.Tab("Resolution"):
                                width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
                                height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
                                Scheduler = gr.Dropdown(
                                    ["DDIM", "LCM", "EulerDiscrete"],
                                    value="LCM",
                                    label="Scheduler", info="Select a Scheduler")

                            with gr.Tab("Content Gallery"):
                                with gr.Row():
                                    source_path = gr.Textbox(value='./data/content', label="Source Path")
                                    refresh_source_list_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
                                source_gallery_index = gr.Slider(label="Index", value=0, minimum=0, maximum=50, step=1)
                                source_image_gallery = gr.Gallery(value=[], columns=4, label="Source Image List")
                                refresh_source_list_button.click(
                                    fn=global_text.init_source_image_path,
                                    inputs=[source_path],
                                    outputs=[source_image_gallery],
                                )

                                def update_source_list(index):
                                    """Return the content-gallery page selected by the slider."""
                                    return gallery_page(
                                        global_text.source_paths,
                                        global_text.max_source_index,
                                        index,
                                    )

                                source_gallery_index.change(
                                    fn=update_source_list,
                                    inputs=[source_gallery_index],
                                    outputs=[source_image_gallery],
                                )

                            with gr.Tab("Style Gallery"):
                                with gr.Row():
                                    style_path = gr.Textbox(value='./data/style', label="style Path")
                                    refresh_style_list_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
                                style_gallery_index = gr.Slider(label="Index", value=0, minimum=0, maximum=50, step=1)
                                style_image_gallery = gr.Gallery(value=[], columns=4, label="style Image List")
                                refresh_style_list_button.click(
                                    fn=global_text.init_style_image_path,
                                    inputs=[style_path],
                                    outputs=[style_image_gallery],
                                )

                                def update_style_list(index):
                                    """Return the style-gallery page selected by the slider."""
                                    return gallery_page(
                                        global_text.style_paths,
                                        global_text.max_style_index,
                                        index,
                                    )

                                style_gallery_index.change(
                                    fn=update_style_list,
                                    inputs=[style_gallery_index],
                                    outputs=[style_image_gallery],
                                )

                            # with gr.Tab("Results Gallery"):
                            #     with gr.Row():
                            #         refresh_results_list_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
                            #     results_gallery_index = gr.Slider(label="Index", value=0, minimum=0, maximum=50, step=1)
                            #     num_gallery_images = 12
                            #     results_image_gallery = gr.Gallery(value=[], columns=4, label="style Image List")
                            #     refresh_results_list_button.click(fn=global_text.init_results_image_path, inputs=[], outputs=[results_image_gallery])
                            #     def update_results_list(index):
                            #         if int(index) < 0:
                            #             index = 0
                            #         if int(index) > global_text.max_results_index:
                            #             index = global_text.max_results_index
                            #         return global_text.results_paths[int(index)*num_gallery_images:(int(index)+1)*num_gallery_images]
                            #     results_gallery_index.change(fn=update_results_list, inputs=[results_gallery_index], outputs=[style_image_gallery])

                            with gr.Row():
                                generate_button = gr.Button(value="Generate", variant='primary')

                            with gr.Tab('Base Configs'):
                                num_steps = gr.Slider(label="Total Steps", value=4, minimum=0, maximum=25, step=1)
                                strength = gr.Slider(label="Noisy Ratio", value=0.5, minimum=0, maximum=1, step=0.01, info="How much noise applied to souce image, 50% for better balance.")
                                co_feat_step = gr.Slider(label="Consistency Feature Extract Step", value=99, minimum=0, maximum=999, step=1)
                                with gr.Row():
                                    start_ac_layer = gr.Slider(label="Start Layer of AC",
                                                               minimum=0,
                                                               maximum=16,
                                                               value=8,
                                                               step=1)
                                    end_ac_layer = gr.Slider(label="End Layer of AC",
                                                             minimum=0,
                                                             maximum=16,
                                                             value=16,
                                                             step=1)
                                with gr.Row():
                                    Style_Guidance = gr.Slider(label="Style Guidance Scale",
                                                               minimum=-1,
                                                               maximum=3,
                                                               value=1.2,
                                                               step=0.01,
                                                               )
                                    mix_q_scale = gr.Slider(label='Query Mix Ratio',
                                                            minimum=0,
                                                            maximum=2,
                                                            step=0.05,
                                                            value=1.0,
                                                            )
                                cfg_scale_slider = gr.Slider(label="CFG Scale", value=2.5, minimum=0, maximum=20, info="Classifier-free guidance scale.")
                                with gr.Row():
                                    save_intermediate = gr.Checkbox(label="save_intermediate", value=False)
                                    de_bug = gr.Checkbox(value=False, label='DeBug')

                            with gr.Tab('ToMe'):
                                with gr.Row():
                                    tome = gr.Checkbox(label="Token Merge", value=True)
                                    tome_ratio = gr.Slider(label='ratio: ', minimum=0, maximum=1, step=0.1, value=0.5)
                                with gr.Row():
                                    tome_sx = gr.Slider(label='sx:', minimum=0, maximum=64, step=2, value=2)
                                    tome_sy = gr.Slider(label='sy:', minimum=0, maximum=64, step=2, value=2)

                            with gr.Row():
                                seed_textbox = gr.Textbox(label="Seed", value=-1)
                                seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                            # Integer bound: random.randint rejects float
                            # arguments such as 1e16 on recent Python versions.
                            seed_button.click(fn=lambda: random.randint(1, 10**16), inputs=[], outputs=[seed_textbox])

            inputs = [
                source_image, style_image,
                num_steps, co_feat_step, strength,
                start_ac_layer, end_ac_layer,
                Style_Guidance, cfg_scale_slider, mix_q_scale,
                Scheduler, save_intermediate, seed_textbox, de_bug,
                prompt_textbox, negative_prompt_textbox,
                width_slider, height_slider,
                tome_sx, tome_sy, tome_ratio, tome,
            ]

            # generate() returns the batch in order; the third image is the
            # one shown as the final result.
            generate_button.click(
                fn=global_text.generate,
                inputs=inputs,
                outputs=[recons_style, recons_content, generate_image],
            )

            # NOTE(review): the third positional argument of gr.Examples is
            # `outputs`; a list of strings is passed here. Kept unchanged --
            # confirm whether example labels were intended instead.
            ex = gr.Examples(
                [
                    ["./data/content/27032.jpg", "./data/style/27.jpg", 4, 0.8, 0.5, 8, 8427921159605868845],
                    ["./data/content/29812.jpg", "./data/style/47.jpg", 4, 0.5, 0.65, 11, 8119359809263726691],
                ],
                [source_image, style_image, num_steps, strength, mix_q_scale, start_ac_layer, seed_textbox],
                [
                    "Example 1",
                ],
            )

    return demo


if __name__ == "__main__":
    demo = ui()
    demo.launch(show_error=True)