"""Gradio demo: train an edited model with ``train_esd`` and compare outputs.

The left column collects training parameters and starts ``train_esd``; the
right column generates one image with the original UNet weights and one with
the edited UNet weights for the same prompt and seed.
"""
import sys

# Kept ahead of the project imports below, matching the original statement
# order (presumably some of those modules live under ``stable_diffusion``).
sys.path.insert(0, 'stable_diffusion')

import gradio as gr
from train_esd import train_esd
from convertModels import convert_ldm_unet_checkpoint, create_unet_diffusers_config
from omegaconf import OmegaConf
from StableDiffuser import StableDiffuser
from diffusers import UNet2DConditionModel

ckpt_path = "stable_diffusion/models/ldm/sd-v1-4-full-ema.ckpt"
config_path = "stable_diffusion/configs/stable-diffusion/v1-inference.yaml"
diffusers_config_path = "stable_diffusion/config.json"


class Demo:
    """Build the Gradio UI, launch it, and serve the train/generate callbacks."""

    def __init__(self) -> None:
        """Create the Blocks app, lay out the widgets and launch the server.

        Note that ``launch()`` is called from here, so constructing a ``Demo``
        starts the web app.
        """
        # Re-entrancy flags checked by ``train`` and ``inference``. They are
        # plain booleans (no lock), so they are a best-effort guard only.
        self.training = False
        self.generating = False

        with gr.Blocks() as demo:
            self.layout()

        demo.queue(concurrency_count=10).launch()

    def disable(self):
        """Return updates that make the train and generate buttons non-interactive."""
        return [gr.update(interactive=False), gr.update(interactive=False)]

    def layout(self):
        """Create all widgets and register the button click handlers.

        Must be called inside an active ``gr.Blocks()`` context.
        """
        with gr.Row():
            # Training controls.
            with gr.Column(scale=1):
                self.prompt_input = gr.Text(
                    placeholder="Enter prompt...",
                    label="Prompt to Erase",
                    info="Prompt corresponding to concept to erase"
                )
                self.train_method_input = gr.Dropdown(
                    choices=['noxattn', 'selfattn', 'xattn', 'full'],
                    value='xattn',
                    label='Train Method',
                    info='Method of training'
                )
                self.neg_guidance_input = gr.Number(
                    value=1,
                    label="Negative Guidance",
                    info='Guidance of negative training used to train'
                )
                self.iterations_input = gr.Number(
                    value=1000,
                    precision=0,
                    label="Iterations",
                    info='iterations used to train'
                )
                self.lr_input = gr.Number(
                    value=1e-5,
                    label="Learning Rate",
                    info='Learning rate used to train'
                )
                self.progress_bar = gr.Text(interactive=False, label="Training Progress")
                self.train_button = gr.Button(
                    value="Train",
                )

            # Inference controls and image outputs.
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Column(scale=4):
                        self.prompt_input_infr = gr.Text(
                            placeholder="Enter prompt...",
                            label="Prompt",
                            info="Prompt to generate"
                        )
                    with gr.Column(scale=1):
                        # precision=0 makes the widget deliver an integer,
                        # consistent with ``iterations_input``.
                        self.seed_infr = gr.Number(
                            label="Seed",
                            value=42,
                            precision=0
                        )
                with gr.Row():
                    self.image_new = gr.Image(
                        label="New Image",
                        interactive=False
                    )
                    self.image_orig = gr.Image(
                        label="Orig Image",
                        interactive=False
                    )
                with gr.Row():
                    # Starts disabled; ``train`` enables it once models exist.
                    self.infr_button = gr.Button(
                        value="Generate",
                        interactive=False
                    )

        self.infr_button.click(
            self.inference,
            inputs=[
                self.prompt_input_infr,
                self.seed_infr
            ],
            outputs=[
                self.image_new,
                self.image_orig
            ]
        )

        # Two handlers on the same click: one greys out both buttons, the
        # other runs the training and re-enables them when it returns.
        self.train_button.click(
            self.disable,
            outputs=[self.train_button, self.infr_button]
        )
        self.train_button.click(
            self.train,
            inputs=[
                self.prompt_input,
                self.train_method_input,
                self.neg_guidance_input,
                self.iterations_input,
                self.lr_input
            ],
            outputs=[self.train_button, self.infr_button, self.progress_bar]
        )

    def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
        """Run ``train_esd`` and prepare the resulting weights for inference.

        Args:
            prompt: Text from the "Prompt to Erase" box.
            train_method: One of the dropdown choices.
            neg_guidance: Value of the "Negative Guidance" field.
            iterations: Value of the "Iterations" field.
            lr: Value of the "Learning Rate" field.
            pbar: Gradio progress tracker (mirrors tqdm progress).

        Returns:
            A list of three values for ``[train_button, infr_button,
            progress_bar]``. If a training run is already in progress the
            call is ignored and no-op updates are returned.
        """
        if self.training:
            # Leave every output untouched instead of overwriting it with None.
            return [gr.update(), gr.update(), gr.update()]

        self.training = True
        try:
            # The literal 3 is train_esd's third positional argument; its
            # meaning is defined by train_esd and is not visible in this file.
            model_orig, model_edited = train_esd(
                prompt,
                train_method,
                3,
                neg_guidance,
                iterations,
                lr,
                config_path,
                ckpt_path,
                diffusers_config_path,
                ['cuda', 'cuda']
            )

            original_config = OmegaConf.load(config_path)
            original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = 4
            unet_config = create_unet_diffusers_config(original_config, image_size=512)
            model_edited_sd = convert_ldm_unet_checkpoint(model_edited.state_dict(), unet_config)
            model_orig_sd = convert_ldm_unet_checkpoint(model_orig.state_dict(), unet_config)

            self.init_inference(model_edited_sd, model_orig_sd, unet_config)
        finally:
            # Always clear the flag so a failed run does not block every
            # later training request.
            self.training = False

        return [gr.update(interactive=True), gr.update(interactive=True), None]

    def init_inference(self, model_edited_sd, model_orig_sd, unet_config):
        """Store both state dicts and build the diffuser used by ``inference``.

        Args:
            model_edited_sd: Converted state dict of the edited model.
            model_orig_sd: Converted state dict of the original model.
            unet_config: Keyword arguments for ``UNet2DConditionModel``.
        """
        self.model_edited_sd = model_edited_sd
        self.model_orig_sd = model_orig_sd

        self.diffuser = StableDiffuser(42)
        # Replace the diffuser's UNet with a fresh one; weights are loaded
        # per-request in ``inference``.
        self.diffuser.unet = UNet2DConditionModel(**unet_config)
        self.diffuser.to('cuda')

        self.training = False

    def inference(self, prompt, seed, pbar = gr.Progress(track_tqdm=True)):
        """Generate one image with the edited weights and one with the original.

        Args:
            prompt: Text prompt to generate from.
            seed: Seed value from the UI; converted to ``int`` before use.
            pbar: Gradio progress tracker (mirrors tqdm progress).

        Returns:
            ``(edited_image, orig_image)``, matching the
            ``[image_new, image_orig]`` outputs. If a generation is already
            running the call is ignored and no-op updates are returned.
        """
        if self.generating:
            # Keep whatever images are currently displayed.
            return [gr.update(), gr.update()]

        self.generating = True
        try:
            self.diffuser.unet.load_state_dict(self.model_orig_sd)
            # gr.Number can deliver a float; store an integer seed.
            self.diffuser._seed = int(seed)

            images = self.diffuser(
                prompt,
                n_steps=50,
                reseed=True
            )
            orig_image = images[0][0]

            self.diffuser.unet.load_state_dict(self.model_edited_sd)

            images = self.diffuser(
                prompt,
                n_steps=50,
                reseed=True
            )
            edited_image = images[0][0]
        finally:
            # Always clear the flag so a failed generation does not lock the
            # Generate button's handler permanently.
            self.generating = False

        return edited_image, orig_image


demo = Demo()