import torch import torch.nn as nn import itertools import torch.nn.functional as F from typing import List from diffusers import ( StableDiffusionPipeline, StableDiffusionXLPipeline, StableDiffusionXLInstructPix2PixPipeline, StableDiffusionInstructPix2PixPipeline, ) from PIL import Image from .ipa_utils import is_torch2_available if is_torch2_available(): from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor else: from .attention_processor import IPAttnProcessor, AttnProcessor from diffusers.loaders import LoraLoaderMixin from diffusers.models.lora import LoRALinearLayer from diffusers.models.unet_2d_blocks import DownBlock2D # from .pipeline_stable_diffusion_xl_t2i_edit import StableDiffusionXLText2ImageAndEditPipeline # from .pipeline_stable_diffusion_t2i_edit import StableDiffusionText2ImageAndEditPipeline class IPAdapterSD(nn.Module): def __init__(self, unet, resampler) -> None: super().__init__() self.unet = unet self.resampler = resampler self.set_ip_adapter() self.set_trainable() def set_ip_adapter(self): attn_procs = {} unet_sd = self.unet.state_dict() for name in self.unet.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = self.unet.config.block_out_channels[-1] elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id] elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = self.unet.config.block_out_channels[block_id] if cross_attention_dim is None: attn_procs[name] = AttnProcessor() else: layer_name = name.split(".processor")[0] weights = { "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], } attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) attn_procs[name].load_state_dict(weights) self.unet.set_attn_processor(attn_procs) self.adapter = torch.nn.ModuleList(self.unet.attn_processors.values()) def set_trainable(self): self.unet.requires_grad_(False) self.resampler.requires_grad_(True) self.adapter.requires_grad_(True) def params_to_opt(self): return itertools.chain(self.resampler.parameters(), self.adapter.parameters()) def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise): image_embeds = self.resampler(image_embeds) # image_embeds = image_embeds.to(dtype=text_embeds.dtype) text_embeds = torch.cat([text_embeds, image_embeds], dim=1) # Predict the noise residual and compute loss noise_pred = self.unet(noisy_latents, timesteps, text_embeds).sample # if noise is not None: loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") # else: # loss = torch.tensor(0.0, device=noisy_latents) return {'total_loss': loss, 'noise_pred': noise_pred} def encode_image_embeds(self, image_embeds): dtype = image_embeds.dtype image_embeds = self.resampler(image_embeds) image_embeds = image_embeds.to(dtype=dtype) return image_embeds @classmethod def from_pretrained(cls, unet, resampler, pretrained_model_path=None, pretrained_resampler_path=None, pretrained_adapter_path=None): model = cls(unet=unet, resampler=resampler) if pretrained_model_path is not None: ckpt = torch.load(pretrained_model_path, map_location='cpu') missing, unexpected = model.load_state_dict(ckpt, strict=False) print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected)) if pretrained_resampler_path is not None: ckpt = torch.load(pretrained_resampler_path, map_location='cpu') missing, unexpected = model.resampler.load_state_dict(ckpt, strict=True) print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected)) if pretrained_adapter_path is not None: ckpt = torch.load(pretrained_adapter_path, map_location='cpu') missing, unexpected = model.adapter.load_state_dict(ckpt, strict=True) print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected)) return model @classmethod def from_pretrained_legacy(cls, unet, resampler, pretrained_model_path=None): model = cls(unet=unet, resampler=resampler) if pretrained_model_path is not None: ckpt = torch.load(pretrained_model_path, map_location='cpu') ckpt_image_proj = {} ckpt_ip_layers = {} for key, value in ckpt.items(): if key.startswith('image_proj_model'): new_key = key.replace('image_proj_model.', '') ckpt_image_proj[new_key] = value elif key.startswith('adapter_modules.'): new_key = key.replace('adapter_modules.', '') ckpt_ip_layers[new_key] = value missing, unexpected = model.resampler.load_state_dict(ckpt_image_proj, strict=True) print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected)) missing, unexpected = model.adapter.load_state_dict(ckpt_ip_layers, strict=True) print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected)) return model class IPAdapterSDPipe(nn.Module): def __init__( self, ip_adapter, discrete_model, vae, visual_encoder, text_encoder, tokenizer, scheduler, image_transform, device, dtype, ) -> None: super().__init__() self.ip_adapter = ip_adapter self.vae = vae self.visual_encoder = visual_encoder self.text_encoder = text_encoder self.tokenizer = tokenizer self.scheduler = scheduler self.image_transform = image_transform self.discrete_model = discrete_model self.device = device self.dtype = dtype self.sd_pipe = StableDiffusionPipeline(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=ip_adapter.unet, scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False) def set_scale(self, scale): for attn_processor in self.sd_pipe.unet.attn_processors.values(): if isinstance(attn_processor, IPAttnProcessor): attn_processor.scale = scale @torch.inference_mode() def get_image_embeds(self, image_pil=None, image_tensor=None, return_negative=True): assert int(image_pil is not None) + int(image_tensor is not None) == 1 if image_pil is not None: image_tensor = self.image_transform(image_pil).unsqueeze(0).to(self.device, dtype=self.dtype) if return_negative: image_tensor_neg = torch.zeros_like(image_tensor) image_tensor = torch.cat([image_tensor, image_tensor_neg], dim=0) with torch.cuda.amp.autocast(dtype=self.dtype): image_embeds = self.visual_encoder(image_tensor) image_embeds = self.discrete_model.encode_image_embeds(image_embeds) image_embeds = self.ip_adapter.encode_image_embeds(image_embeds) if return_negative: # bz = image_embeds.shape[0] # image_embeds_neg = image_embeds[bz // 2:] # image_embeds = image_embeds[0:bz // 2] image_embeds, image_embeds_neg = image_embeds.chunk(2) else: image_embeds_neg = None return image_embeds, image_embeds_neg def generate(self, image_pil=None, image_tensor=None, prompt=None, negative_prompt=None, scale=1.0, num_samples=1, seed=42, guidance_scale=7.5, num_inference_steps=30, **kwargs): self.set_scale(scale) assert int(image_pil is not None) + int(image_tensor is not None) == 1 if image_pil is not None: assert isinstance(image_pil, Image.Image) num_prompts = 1 else: num_prompts = image_tensor.shape[0] if prompt is None: # prompt = "best quality, high quality" prompt = "" if negative_prompt is None: negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" if not isinstance(prompt, List): prompt = [prompt] * num_prompts if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( image_pil=image_pil, image_tensor=image_tensor, return_negative=True, ) bs_embed, seq_len, _ = image_prompt_embeds.shape image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) with torch.inference_mode(): prompt_embeds, negative_prompt_embeds = self.sd_pipe.encode_prompt( prompt, device=self.device, num_images_per_prompt=num_samples, do_classifier_free_guidance=True, negative_prompt=negative_prompt, ) prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None images = self.sd_pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, **kwargs, ).images return images def compute_time_ids(original_size, crops_coords_top_left, target_resolution): # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids target_size = (target_resolution, target_resolution) add_time_ids = list(original_size + crops_coords_top_left + target_size) add_time_ids = torch.tensor([add_time_ids]) # add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) return add_time_ids class SDXLAdapter(nn.Module): def __init__(self, unet, resampler, full_ft=False) -> None: super().__init__() self.unet = unet self.resampler = resampler self.full_ft = full_ft self.set_trainable_v2() # self.set_adapter() # self.set_trainable() # def set_adapter(self): # adapter = [] # for name, module in self.unet.named_modules(): # if name.endswith('to_k') or name.endswith('to_v'): # if module is not None: # adapter.append(module) # self.adapter = torch.nn.ModuleList(adapter) # print(f'adapter: {self.adapter}') # def set_trainable(self): # self.unet.requires_grad_(False) # self.resampler.requires_grad_(True) # self.adapter.requires_grad_(True) def set_trainable_v2(self): self.resampler.requires_grad_(True) adapter_parameters = [] if self.full_ft: self.unet.requires_grad_(True) adapter_parameters.extend(self.unet.parameters()) else: self.unet.requires_grad_(False) for name, module in self.unet.named_modules(): if name.endswith('to_k') or name.endswith('to_v'): if module is not None: adapter_parameters.extend(module.parameters()) self.adapter_parameters = adapter_parameters for param in self.adapter_parameters: param.requires_grad_(True) # def params_to_opt(self): # return itertools.chain(self.resampler.parameters(), self.adapter.parameters()) def params_to_opt(self): return itertools.chain(self.resampler.parameters(), self.adapter_parameters) def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise, time_ids): image_embeds, pooled_image_embeds = self.resampler(image_embeds) unet_added_conditions = {"time_ids": time_ids, 'text_embeds': pooled_image_embeds} noise_pred = self.unet(noisy_latents, timesteps, image_embeds, added_cond_kwargs=unet_added_conditions).sample # if noise is not None: loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") # else: # loss = torch.tensor(0.0, device=noisy_latents) return {'total_loss': loss, 'noise_pred': noise_pred} def encode_image_embeds(self, image_embeds): image_embeds, pooled_image_embeds = self.resampler(image_embeds) return image_embeds, pooled_image_embeds @classmethod def from_pretrained(cls, unet, resampler, pretrained_model_path=None, **kwargs): model = cls(unet=unet, resampler=resampler, **kwargs) if pretrained_model_path is not None: ckpt = torch.load(pretrained_model_path, map_location='cpu') missing, unexpected = model.load_state_dict(ckpt, strict=False) print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected)) return model def init_pipe(self, vae, scheduler, visual_encoder, image_transform, discrete_model=None, dtype=torch.float16, device='cuda'): self.device = device self.dtype = dtype sdxl_pipe = StableDiffusionXLPipeline(tokenizer=None, tokenizer_2=None, text_encoder=None, text_encoder_2=None, vae=vae, unet=self.unet, scheduler=scheduler) self.sdxl_pipe = sdxl_pipe # .to(self.device, dtype=self.dtype) # print(sdxl_pipe.text_encoder_2, sdxl_pipe.text_encoder) self.visual_encoder = visual_encoder.to(self.device, dtype=self.dtype) if discrete_model is not None: self.discrete_model = discrete_model.to(self.device, dtype=self.dtype) else: self.discrete_model = None self.image_transform = image_transform @torch.inference_mode() def get_image_embeds(self, image_pil=None, image_tensor=None, image_embeds=None, return_negative=True, image_size=448 ): assert int(image_pil is not None) + int(image_tensor is not None) + int(image_embeds is not None) == 1 if image_pil is not None: image_tensor = self.image_transform(image_pil).unsqueeze(0).to(self.device, dtype=self.dtype) if image_tensor is not None: if return_negative: image_tensor_neg = torch.zeros_like(image_tensor) image_tensor = torch.cat([image_tensor, image_tensor_neg], dim=0) image_embeds = self.visual_encoder(image_tensor) elif return_negative: image_tensor_neg = torch.zeros( 1, 3, image_size, image_size ).to( image_embeds.device, dtype=image_embeds.dtype ) image_embeds_neg = self.visual_encoder(image_tensor_neg) image_embeds = torch.cat([image_embeds, image_embeds_neg], dim=0) if self.discrete_model is not None: image_embeds = self.discrete_model.encode_image_embeds(image_embeds) image_embeds, pooled_image_embeds = self.encode_image_embeds(image_embeds) if return_negative: image_embeds, image_embeds_neg = image_embeds.chunk(2) pooled_image_embeds, pooled_image_embeds_neg = pooled_image_embeds.chunk(2) else: image_embeds_neg = None pooled_image_embeds_neg = None return image_embeds, image_embeds_neg, pooled_image_embeds, pooled_image_embeds_neg def generate(self, image_pil=None, image_tensor=None, image_embeds=None, seed=42, height=1024, width=1024, guidance_scale=7.5, num_inference_steps=30, input_image_size=448, **kwargs): if image_pil is not None: assert isinstance(image_pil, Image.Image) image_prompt_embeds, uncond_image_prompt_embeds, pooled_image_prompt_embeds, \ pooled_uncond_image_prompt_embeds = self.get_image_embeds( image_pil=image_pil, image_tensor=image_tensor, image_embeds=image_embeds, return_negative=True, image_size=input_image_size, ) # print(image_prompt_embeds.shape, pooled_image_prompt_embeds.shape) generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None images = self.sdxl_pipe( prompt_embeds=image_prompt_embeds, negative_prompt_embeds=uncond_image_prompt_embeds, pooled_prompt_embeds=pooled_image_prompt_embeds, negative_pooled_prompt_embeds=pooled_uncond_image_prompt_embeds, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, height=height, width=width, **kwargs, ).images return images class SDXLText2ImageAndEditAdapter(nn.Module): def __init__(self, unet, resampler, lora_rank=16, fully_ft=False) -> None: super().__init__() self.unet = unet self.resampler = resampler self.lora_rank = lora_rank if fully_ft: self.set_fully_trainable() else: self.set_adapter() def set_adapter(self): self.unet.requires_grad_(False) adapter_parameters = [] in_channels = 8 out_channels = self.unet.conv_in.out_channels self.unet.register_to_config(in_channels=in_channels) with torch.no_grad(): new_conv_in = nn.Conv2d(in_channels, out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding) new_conv_in.weight.zero_() new_conv_in.weight[:, :4, :, :].copy_(self.unet.conv_in.weight) self.unet.conv_in = new_conv_in self.unet.conv_in.requires_grad_(True) print('Make conv_in trainable.') adapter_parameters.extend(self.unet.conv_in.parameters()) for name, module in self.unet.named_modules(): if isinstance(module, DownBlock2D): module.requires_grad_(True) adapter_parameters.extend(module.parameters()) print('Make DownBlock2D trainable.') for attn_processor_name, attn_processor in self.unet.attn_processors.items(): # Parse the attention module. attn_module = self.unet for n in attn_processor_name.split(".")[:-1]: attn_module = getattr(attn_module, n) # Set the `lora_layer` attribute of the attention-related matrices. attn_module.to_q.set_lora_layer( LoRALinearLayer(in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=self.lora_rank)) # attn_module.to_k.set_lora_layer( # LoRALinearLayer(in_features=attn_module.to_k.in_features, # out_features=attn_module.to_k.out_features, # rank=self.lora_rank)) # attn_module.to_v.set_lora_layer( # LoRALinearLayer(in_features=attn_module.to_v.in_features, # out_features=attn_module.to_v.out_features, # rank=self.lora_rank)) attn_module.to_out[0].set_lora_layer( LoRALinearLayer( in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features, rank=self.lora_rank, )) attn_module.to_k.requires_grad_(True) attn_module.to_v.requires_grad_(True) adapter_parameters.extend(attn_module.to_q.lora_layer.parameters()) adapter_parameters.extend(attn_module.to_k.parameters()) adapter_parameters.extend(attn_module.to_v.parameters()) adapter_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) self.adapter_parameters = adapter_parameters def set_fully_trainable(self): in_channels = 8 out_channels = self.unet.conv_in.out_channels self.unet.register_to_config(in_channels=in_channels) with torch.no_grad(): new_conv_in = nn.Conv2d(in_channels, out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding) new_conv_in.weight.zero_() new_conv_in.weight[:, :4, :, :].copy_(self.unet.conv_in.weight) self.unet.conv_in = new_conv_in self.unet.requires_grad_(True) self.adapter_parameters = self.unet.parameters() def params_to_opt(self): return itertools.chain(self.resampler.parameters(), self.adapter_parameters) def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise, time_ids, pooled_text_embeds=None): text_embeds, pooled_text_embeds = self.resampler(text_embeds, pooled_text_embeds=pooled_text_embeds) unet_added_conditions = {"time_ids": time_ids, 'text_embeds': pooled_text_embeds} noise_pred = self.unet(noisy_latents, timesteps, text_embeds, added_cond_kwargs=unet_added_conditions).sample loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") return {'total_loss': loss, 'noise_pred': noise_pred} def encode_text_embeds(self, text_embeds, pooled_text_embeds=None): text_embeds, pooled_text_embeds = self.resampler(text_embeds, pooled_text_embeds=pooled_text_embeds) return text_embeds, pooled_text_embeds @classmethod def from_pretrained(cls, unet, resampler, pretrained_model_path=None, **kwargs): model = cls(unet=unet, resampler=resampler, **kwargs) if pretrained_model_path is not None: ckpt = torch.load(pretrained_model_path, map_location='cpu') missing, unexpected = model.load_state_dict(ckpt, strict=False) print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected)) return model def init_pipe(self, vae, scheduler, text_encoder, text_encoder_2, tokenizer, tokenizer_2, dtype=torch.float16, device='cuda'): self.device = device self.dtype = dtype sdxl_pipe = StableDiffusionXLText2ImageAndEditPipeline( tokenizer=None, tokenizer_2=None, text_encoder=None, text_encoder_2=None, vae=vae, unet=self.unet, scheduler=scheduler, ) self.sdxl_pipe = sdxl_pipe self.sdxl_pipe.to(device, dtype=dtype) self.tokenizer = tokenizer self.tokenizer_2 = tokenizer_2 self.text_encoder = text_encoder self.text_encoder_2 = text_encoder_2 @torch.inference_mode() def get_text_embeds(self, prompt=None, negative_prompt='', text_embeds=None): assert int(prompt is not None) + int(text_embeds is not None) == 1 if prompt is not None: text_input_ids = self.tokenizer([prompt, negative_prompt], max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids text_input_ids_2 = self.tokenizer_2([prompt, negative_prompt], max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids encoder_output = self.text_encoder(text_input_ids.to(self.device), output_hidden_states=True) text_embeds = encoder_output.hidden_states[-2] encoder_output_2 = self.text_encoder_2(text_input_ids_2.to(self.device), output_hidden_states=True) pooled_text_embeds = encoder_output_2[0] text_embeds_2 = encoder_output_2.hidden_states[-2] text_embeds = torch.cat([text_embeds, text_embeds_2], dim=-1) else: text_input_ids = self.tokenizer(negative_prompt, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids text_input_ids_2 = self.tokenizer_2(negative_prompt, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids encoder_output = self.text_encoder(text_input_ids.to(self.device), output_hidden_states=True) text_embeds_neg = encoder_output.hidden_states[-2] encoder_output_2 = self.text_encoder_2(text_input_ids_2.to(self.device), output_hidden_states=True) text_embeds_neg_2 = encoder_output_2.hidden_states[-2] pooled_text_embeds = encoder_output_2[0] text_embeds_neg = torch.cat([text_embeds_neg, text_embeds_neg_2], dim=-1) text_embeds = torch.cat([text_embeds, text_embeds_neg], dim=0) text_embeds, pooled_text_embeds = self.encode_text_embeds(text_embeds, pooled_text_embeds=pooled_text_embeds) text_embeds, text_embeds_neg = text_embeds.chunk(2) pooled_text_embeds, pooled_text_embeds_neg = pooled_text_embeds.chunk(2) return text_embeds, text_embeds_neg, pooled_text_embeds, pooled_text_embeds_neg def generate(self, prompt=None, negative_prompt='', image=None, text_embeds=None, seed=42, height=1024, width=1024, guidance_scale=7.5, num_inference_steps=30, **kwargs): text_embeds, text_embeds_neg, pooled_text_embeds, pooled_text_embeds_neg = self.get_text_embeds( prompt=prompt, negative_prompt=negative_prompt, text_embeds=text_embeds) generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None images = self.sdxl_pipe( image=image, prompt_embeds=text_embeds, negative_prompt_embeds=text_embeds_neg, pooled_prompt_embeds=pooled_text_embeds, negative_pooled_prompt_embeds=pooled_text_embeds_neg, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, height=height, width=width, **kwargs, ).images return images class SD21Text2ImageAndEditAdapter(SDXLText2ImageAndEditAdapter): def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise): text_embeds, _ = self.resampler(text_embeds) # unet_added_conditions = {"time_ids": time_ids, 'text_embeds': pooled_text_embeds} noise_pred = self.unet(noisy_latents, timesteps, text_embeds).sample loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") return {'total_loss': loss, 'noise_pred': noise_pred} def init_pipe(self, vae, scheduler, text_encoder, tokenizer, feature_extractor, dtype=torch.float16, device='cuda'): self.device = device self.dtype = dtype sd_pipe = StableDiffusionText2ImageAndEditPipeline( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, unet=self.unet, feature_extractor=feature_extractor, safety_checker=None, requires_safety_checker=False, scheduler=scheduler, ) self.sd_pipe = sd_pipe self.sd_pipe.to(device, dtype=dtype) self.tokenizer = tokenizer self.text_encoder = text_encoder @torch.inference_mode() def get_text_embeds(self, prompt=None, negative_prompt='', text_embeds=None): assert int(prompt is not None) + int(text_embeds is not None) == 1 if prompt is not None: text_input_ids = self.tokenizer([prompt, negative_prompt], max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids encoder_output = self.text_encoder(text_input_ids.to(self.device)) text_embeds = encoder_output[0] else: text_input_ids = self.tokenizer(negative_prompt, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids encoder_output = self.text_encoder(text_input_ids.to(self.device)) text_embeds_neg = encoder_output[0] text_embeds = torch.cat([text_embeds, text_embeds_neg], dim=0) text_embeds, _ = self.encode_text_embeds(text_embeds) text_embeds, text_embeds_neg = text_embeds.chunk(2) return text_embeds, text_embeds_neg def generate(self, prompt=None, negative_prompt='', image=None, text_embeds=None, seed=42, height=1024, width=1024, guidance_scale=7.5, num_inference_steps=30, **kwargs): text_embeds, text_embeds_neg = self.get_text_embeds( prompt=prompt, negative_prompt=negative_prompt, text_embeds=text_embeds) generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None print(f'text_embeds: {text_embeds.shape}') print(f'text_embeds_neg: {text_embeds_neg.shape}') images = self.sd_pipe( image=image, prompt_embeds=text_embeds, negative_prompt_embeds=text_embeds_neg, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, height=height, width=width, **kwargs, ).images return images class SDXLAdapterWithLatentImage(SDXLAdapter): def __init__(self, unet, resampler, full_ft=False, set_trainable_late=False) -> None: nn.Module.__init__(self) self.unet = unet self.resampler = resampler self.full_ft = full_ft if not set_trainable_late: self.set_trainable() def set_trainable(self): self.resampler.requires_grad_(True) adapter_parameters = [] in_channels = 8 out_channels = self.unet.conv_in.out_channels self.unet.register_to_config(in_channels=in_channels) self.unet.requires_grad_(False) with torch.no_grad(): new_conv_in = nn.Conv2d(in_channels, out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding) new_conv_in.weight.zero_() new_conv_in.weight[:, :4, :, :].copy_(self.unet.conv_in.weight) self.unet.conv_in = new_conv_in self.unet.conv_in.requires_grad_(True) if self.full_ft: self.unet.requires_grad_(True) adapter_parameters.extend(self.unet.parameters()) else: adapter_parameters.extend(self.unet.conv_in.parameters()) for name, module in self.unet.named_modules(): if name.endswith('to_k') or name.endswith('to_v'): if module is not None: adapter_parameters.extend(module.parameters()) self.adapter_parameters = adapter_parameters @classmethod def from_pretrained(cls, unet, resampler, pretrained_model_path=None, set_trainable_late=False, **kwargs): model = cls(unet=unet, resampler=resampler, set_trainable_late=set_trainable_late, **kwargs) if pretrained_model_path is not None: ckpt = torch.load(pretrained_model_path, map_location='cpu') missing, unexpected = model.load_state_dict(ckpt, strict=False) print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected)) if set_trainable_late: model.set_trainable() return model def init_pipe(self, vae, scheduler, visual_encoder, image_transform, dtype=torch.float16, device='cuda'): self.device = device self.dtype = dtype sdxl_pipe = StableDiffusionXLText2ImageAndEditPipeline( tokenizer=None, tokenizer_2=None, text_encoder=None, text_encoder_2=None, vae=vae, unet=self.unet, scheduler=scheduler, ) self.sdxl_pipe = sdxl_pipe self.sdxl_pipe.to(device, dtype=dtype) self.discrete_model = None self.visual_encoder = visual_encoder.to(self.device, dtype=self.dtype) self.image_transform = image_transform def generate(self, image_pil=None, image_tensor=None, image_embeds=None, latent_image=None, seed=42, height=1024, width=1024, guidance_scale=7.5, num_inference_steps=30, input_image_size=448, **kwargs): if image_pil is not None: assert isinstance(image_pil, Image.Image) image_prompt_embeds, uncond_image_prompt_embeds, \ pooled_image_prompt_embeds, pooled_uncond_image_prompt_embeds = self.get_image_embeds( image_pil=image_pil, image_tensor=image_tensor, image_embeds=image_embeds, return_negative=True, image_size=input_image_size, ) # print(image_prompt_embeds.shape, pooled_image_prompt_embeds.shape) generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None images = self.sdxl_pipe( image=latent_image, prompt_embeds=image_prompt_embeds, negative_prompt_embeds=uncond_image_prompt_embeds, pooled_prompt_embeds=pooled_image_prompt_embeds, negative_pooled_prompt_embeds=pooled_uncond_image_prompt_embeds, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, height=height, width=width, **kwargs, ).images return images