# SEED-Story: src/models_ipa/adapter_modules.py
# Retrieved from the Hugging Face Hub (last commit by Andyson, message "huggingface hub", revision 0447610).
import torch
import torch.nn as nn
import itertools
import torch.nn.functional as F
from typing import List
from diffusers import (
StableDiffusionPipeline,
StableDiffusionXLPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionInstructPix2PixPipeline,
)
from PIL import Image
from .ipa_utils import is_torch2_available
if is_torch2_available():
from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
else:
from .attention_processor import IPAttnProcessor, AttnProcessor
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer
from diffusers.models.unet_2d_blocks import DownBlock2D
from transformers import AutoModel
# from .pipeline_stable_diffusion_xl_t2i_edit import StableDiffusionXLText2ImageAndEditPipeline
# from .pipeline_stable_diffusion_t2i_edit import StableDiffusionText2ImageAndEditPipeline
class IPAdapterSD(nn.Module):
    """IP-Adapter wrapper around a Stable Diffusion UNet.

    ``resampler`` turns image features into extra conditioning tokens that are
    concatenated after the text tokens (``forward``).  ``set_ip_adapter``
    replaces the UNet's attention processors: processors without a
    cross-attention input get a plain ``AttnProcessor``; all others get an
    ``IPAttnProcessor`` whose ``to_k_ip`` / ``to_v_ip`` weights start as copies
    of the layer's ``to_k`` / ``to_v`` weights.  The UNet is frozen; only the
    resampler and the processors (``self.adapter``) are trainable.
    """

    def __init__(self, unet, resampler) -> None:
        super().__init__()
        self.unet = unet
        self.resampler = resampler
        self.set_ip_adapter()
        self.set_trainable()

    @staticmethod
    def _hidden_size_for(name, block_out_channels):
        """Return the channel width of the UNet block owning processor ``name``.

        ``name`` is a key of ``unet.attn_processors``, e.g.
        ``"down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor"``.
        The block index is the whole second dotted component, so multi-digit
        indices are handled.  Returns ``None`` for names outside
        ``mid_block`` / ``up_blocks`` / ``down_blocks``.
        """
        if name.startswith("mid_block"):
            return block_out_channels[-1]
        if name.startswith("up_blocks"):
            # Up-block indices address block_out_channels in reverse order.
            block_id = int(name.split(".")[1])
            return list(reversed(block_out_channels))[block_id]
        if name.startswith("down_blocks"):
            block_id = int(name.split(".")[1])
            return block_out_channels[block_id]
        return None

    def set_ip_adapter(self):
        """Install IP attention processors on the UNet and collect them in ``self.adapter``.

        Raises:
            ValueError: if a processor that needs a hidden size belongs to a block
                whose width cannot be derived from its name.
        """
        attn_procs = {}
        unet_sd = self.unet.state_dict()
        block_out_channels = self.unet.config.block_out_channels
        for name in self.unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
                continue
            hidden_size = self._hidden_size_for(name, block_out_channels)
            if hidden_size is None:
                # Previously this silently reused the previous iteration's value.
                raise ValueError(f"Cannot infer hidden size for attention processor {name!r}")
            layer_name = name.split(".processor")[0]
            # Initialise the image key/value projections from the text ones.
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            processor = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
            processor.load_state_dict(weights)
            attn_procs[name] = processor
        self.unet.set_attn_processor(attn_procs)
        self.adapter = torch.nn.ModuleList(self.unet.attn_processors.values())

    def set_trainable(self):
        """Freeze the UNet; make the resampler and the attention processors trainable."""
        self.unet.requires_grad_(False)
        self.resampler.requires_grad_(True)
        self.adapter.requires_grad_(True)

    def params_to_opt(self):
        """Return an iterator over the resampler and adapter parameters."""
        return itertools.chain(self.resampler.parameters(), self.adapter.parameters())

    def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise):
        """Predict noise with image tokens appended to the text tokens and return the MSE loss.

        Returns:
            dict with ``'total_loss'`` (mean-squared error between the prediction
            and ``noise``, computed in float32) and ``'noise_pred'``.
        """
        image_embeds = self.resampler(image_embeds)
        # Image tokens are concatenated after the text tokens along dim 1.
        text_embeds = torch.cat([text_embeds, image_embeds], dim=1)
        noise_pred = self.unet(noisy_latents, timesteps, text_embeds).sample
        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
        return {'total_loss': loss, 'noise_pred': noise_pred}

    def encode_image_embeds(self, image_embeds):
        """Run the resampler and cast the result back to the input's dtype."""
        dtype = image_embeds.dtype
        image_embeds = self.resampler(image_embeds)
        image_embeds = image_embeds.to(dtype=dtype)
        return image_embeds

    @classmethod
    def from_pretrained(cls,
                        unet,
                        resampler,
                        pretrained_model_path=None,
                        pretrained_resampler_path=None,
                        pretrained_adapter_path=None):
        """Build a model and optionally load full / resampler / adapter checkpoints.

        Checkpoints are read with ``torch.load`` (pickle-based): only load trusted files.
        """
        model = cls(unet=unet, resampler=resampler)
        if pretrained_model_path is not None:
            ckpt = torch.load(pretrained_model_path, map_location='cpu')
            missing, unexpected = model.load_state_dict(ckpt, strict=False)
            print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
        if pretrained_resampler_path is not None:
            ckpt = torch.load(pretrained_resampler_path, map_location='cpu')
            missing, unexpected = model.resampler.load_state_dict(ckpt, strict=True)
            print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
        if pretrained_adapter_path is not None:
            ckpt = torch.load(pretrained_adapter_path, map_location='cpu')
            missing, unexpected = model.adapter.load_state_dict(ckpt, strict=True)
            print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
        return model

    @classmethod
    def from_pretrained_legacy(cls, unet, resampler, pretrained_model_path=None):
        """Load a checkpoint whose keys use the ``image_proj_model.`` / ``adapter_modules.`` prefixes.

        Other keys are ignored.  Checkpoints are read with ``torch.load``
        (pickle-based): only load trusted files.
        """
        model = cls(unet=unet, resampler=resampler)
        if pretrained_model_path is not None:
            ckpt = torch.load(pretrained_model_path, map_location='cpu')
            ckpt_image_proj = {}
            ckpt_ip_layers = {}
            for key, value in ckpt.items():
                if key.startswith('image_proj_model'):
                    new_key = key.replace('image_proj_model.', '')
                    ckpt_image_proj[new_key] = value
                elif key.startswith('adapter_modules.'):
                    new_key = key.replace('adapter_modules.', '')
                    ckpt_ip_layers[new_key] = value
            missing, unexpected = model.resampler.load_state_dict(ckpt_image_proj, strict=True)
            print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
            missing, unexpected = model.adapter.load_state_dict(ckpt_ip_layers, strict=True)
            print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
        return model
class IPAdapterSDPipe(nn.Module):
    """Inference wrapper that drives an ``IPAdapterSD`` through a ``StableDiffusionPipeline``.

    An image is encoded by ``visual_encoder`` -> ``discrete_model`` ->
    ``ip_adapter.encode_image_embeds`` into prompt tokens, which are appended to
    the prompt embeddings returned by ``sd_pipe.encode_prompt`` before sampling.
    """

    def __init__(
        self,
        ip_adapter,
        discrete_model,
        vae,
        visual_encoder,
        text_encoder,
        tokenizer,
        scheduler,
        image_transform,
        device,
        dtype,
    ) -> None:
        super().__init__()
        # Attribute order is kept: nn.Module registers submodules in assignment order.
        self.ip_adapter = ip_adapter
        self.vae = vae
        self.visual_encoder = visual_encoder
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.scheduler = scheduler
        self.image_transform = image_transform
        self.discrete_model = discrete_model
        self.device = device
        self.dtype = dtype
        pipe_components = dict(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=ip_adapter.unet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
        )
        self.sd_pipe = StableDiffusionPipeline(**pipe_components)

    def set_scale(self, scale):
        """Set ``scale`` on every ``IPAttnProcessor`` attached to the pipeline's UNet."""
        ip_processors = (
            processor
            for processor in self.sd_pipe.unet.attn_processors.values()
            if isinstance(processor, IPAttnProcessor)
        )
        for processor in ip_processors:
            processor.scale = scale

    @torch.inference_mode()
    def get_image_embeds(self, image_pil=None, image_tensor=None, return_negative=True):
        """Encode one image source into IP-Adapter prompt tokens.

        Exactly one of ``image_pil`` / ``image_tensor`` must be given.  When
        ``return_negative`` is true, an all-zero image of the same shape is
        encoded in the same batch and returned as the second element;
        otherwise the second element is ``None``.
        """
        assert sum([image_pil is not None, image_tensor is not None]) == 1
        pixels = image_tensor
        if image_pil is not None:
            pixels = self.image_transform(image_pil).unsqueeze(0).to(self.device, dtype=self.dtype)
        if return_negative:
            pixels = torch.cat([pixels, torch.zeros_like(pixels)], dim=0)
        with torch.cuda.amp.autocast(dtype=self.dtype):
            features = self.visual_encoder(pixels)
            features = self.discrete_model.encode_image_embeds(features)
            tokens = self.ip_adapter.encode_image_embeds(features)
        if not return_negative:
            return tokens, None
        positive, negative = tokens.chunk(2)
        return positive, negative

    def generate(self,
                 image_pil=None,
                 image_tensor=None,
                 prompt=None,
                 negative_prompt=None,
                 scale=1.0,
                 num_samples=1,
                 seed=42,
                 guidance_scale=7.5,
                 num_inference_steps=30,
                 **kwargs):
        """Sample ``num_samples`` images per input image with classifier-free guidance.

        Extra ``kwargs`` are forwarded to the underlying ``StableDiffusionPipeline`` call.
        """
        self.set_scale(scale)
        assert int(image_pil is not None) + int(image_tensor is not None) == 1
        if image_pil is None:
            num_prompts = image_tensor.shape[0]
        else:
            assert isinstance(image_pil, Image.Image)
            num_prompts = 1

        if prompt is None:
            prompt = ""
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
        prompt = prompt if isinstance(prompt, List) else [prompt] * num_prompts
        negative_prompt = negative_prompt if isinstance(negative_prompt, List) else [negative_prompt] * num_prompts

        cond_tokens, uncond_tokens = self.get_image_embeds(
            image_pil=image_pil,
            image_tensor=image_tensor,
            return_negative=True,
        )
        batch, n_tokens, _ = cond_tokens.shape

        def tile(tokens):
            # Duplicate each image's tokens num_samples times, one copy per sample.
            return tokens.repeat(1, num_samples, 1).view(batch * num_samples, n_tokens, -1)

        cond_tokens = tile(cond_tokens)
        uncond_tokens = tile(uncond_tokens)

        with torch.inference_mode():
            text_cond, text_uncond = self.sd_pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([text_cond, cond_tokens], dim=1)
            negative_prompt_embeds = torch.cat([text_uncond, uncond_tokens], dim=1)

        generator = None if seed is None else torch.Generator(self.device).manual_seed(seed)
        return self.sd_pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images
def compute_time_ids(original_size, crops_coords_top_left, target_resolution):
    """Build the SDXL ``add_time_ids`` micro-conditioning tensor.

    Adapted from ``StableDiffusionXLPipeline._get_add_time_ids``.  The result is
    ``torch.tensor([[*original_size, *crops_coords_top_left, *target_size]])``,
    i.e. a tensor with a single row.

    Args:
        original_size: size pair of the source image (any sequence; typically
            ``(height, width)``).
        crops_coords_top_left: crop offset pair (any sequence).
        target_resolution: a scalar for a square target, or a ``(height, width)``
            tuple/list.

    Returns:
        ``torch.Tensor`` of shape ``(1, N)``.  It is created on the default device;
        callers move it to the training device/dtype themselves.
    """
    if isinstance(target_resolution, (tuple, list)):
        target_size = tuple(target_resolution)
    else:
        target_size = (target_resolution, target_resolution)
    # Unpack instead of using `+` so that lists and tuples can be mixed freely.
    add_time_ids = [*original_size, *crops_coords_top_left, *target_size]
    return torch.tensor([add_time_ids])
class SDXLAdapter(nn.Module):
    """Conditions an SDXL UNet on image features via ``resampler``.

    ``resampler`` maps image features to a pair ``(tokens, pooled)``.  The pair
    is passed to the UNet as ``encoder_hidden_states`` and as the
    ``'text_embeds'`` added condition.  With ``full_ft=False`` only the resampler
    and the UNet's ``to_k`` / ``to_v`` projections are trainable; with
    ``full_ft=True`` the whole UNet is trainable too.
    """

    def __init__(self, unet, resampler, full_ft=False) -> None:
        super().__init__()
        self.unet = unet
        self.resampler = resampler
        self.full_ft = full_ft
        self.set_trainable_v2()

    def set_trainable_v2(self):
        """Set ``requires_grad`` flags and record trainable UNet params in ``self.adapter_parameters``."""
        self.resampler.requires_grad_(True)
        adapter_parameters = []
        if self.full_ft:
            self.unet.requires_grad_(True)
            adapter_parameters.extend(self.unet.parameters())
        else:
            self.unet.requires_grad_(False)
            for name, module in self.unet.named_modules():
                if name.endswith(('to_k', 'to_v')):
                    adapter_parameters.extend(module.parameters())
        self.adapter_parameters = adapter_parameters
        for param in self.adapter_parameters:
            param.requires_grad_(True)

    def params_to_opt(self):
        """Return an iterator over every parameter the optimizer should update."""
        return itertools.chain(self.resampler.parameters(), self.adapter_parameters)

    def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise, time_ids):
        """Predict noise from image conditioning and return the MSE loss against ``noise``.

        ``text_embeds`` is not used; it is kept for interface compatibility.

        Returns:
            dict with ``'total_loss'`` (float32 MSE) and ``'noise_pred'``.
        """
        image_embeds, pooled_image_embeds = self.resampler(image_embeds)
        unet_added_conditions = {"time_ids": time_ids, 'text_embeds': pooled_image_embeds}
        noise_pred = self.unet(noisy_latents, timesteps, image_embeds, added_cond_kwargs=unet_added_conditions).sample
        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
        return {'total_loss': loss, 'noise_pred': noise_pred}

    def encode_image_embeds(self, image_embeds):
        """Return ``(tokens, pooled)`` produced by the resampler."""
        image_embeds, pooled_image_embeds = self.resampler(image_embeds)
        return image_embeds, pooled_image_embeds

    @classmethod
    def from_pretrained(cls, pretrained_model_path=None, subfolder=None, **kwargs):
        """Build the model from ``kwargs`` and optionally load weights (non-strict).

        Paths containing ``'TencentARC/SEED-Story'`` are loaded through
        ``transformers.AutoModel.from_pretrained`` with ``subfolder``.  Any other
        path is read with ``torch.load`` (pickle-based: only load trusted files).
        """
        model = cls(**kwargs)
        if pretrained_model_path is not None:
            if 'TencentARC/SEED-Story' in pretrained_model_path:
                ckpt = AutoModel.from_pretrained(pretrained_model_path, subfolder=subfolder)
                missing, unexpected = model.load_state_dict(ckpt.state_dict(), strict=False)
                print('Detokenizer model, missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
            else:
                ckpt = torch.load(pretrained_model_path, map_location='cpu')
                missing, unexpected = model.load_state_dict(ckpt, strict=False)
                print('Detokenizer model, missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
        return model

    def init_pipe(self,
                  vae,
                  scheduler,
                  visual_encoder,
                  image_transform,
                  discrete_model=None,
                  dtype=torch.float16,
                  device='cuda'):
        """Attach the inference components used by ``get_image_embeds`` / ``generate``.

        The visual encoder and the optional discrete model are moved to
        ``device``/``dtype``.  The SDXL pipeline itself is not moved here.
        """
        self.device = device
        self.dtype = dtype
        sdxl_pipe = StableDiffusionXLPipeline(tokenizer=None,
                                              tokenizer_2=None,
                                              text_encoder=None,
                                              text_encoder_2=None,
                                              vae=vae,
                                              unet=self.unet,
                                              scheduler=scheduler)
        self.sdxl_pipe = sdxl_pipe
        self.visual_encoder = visual_encoder.to(self.device, dtype=self.dtype)
        if discrete_model is not None:
            self.discrete_model = discrete_model.to(self.device, dtype=self.dtype)
        else:
            self.discrete_model = None
        self.image_transform = image_transform

    @torch.inference_mode()
    def get_image_embeds(self,
                         image_pil=None,
                         image_tensor=None,
                         image_embeds=None,
                         return_negative=True,
                         image_size=448):
        """Encode one image source into ``(pos, neg, pooled_pos, pooled_neg)``.

        Exactly one of ``image_pil`` / ``image_tensor`` / ``image_embeds`` must be
        given.  The negative branch is the encoding of an all-zero image.  For a
        PIL/tensor input, the zero image has the input's shape.  For
        pre-computed ``image_embeds``, it is a single ``(1, 3, image_size,
        image_size)`` image broadcast to the embeddings' batch size.  When
        ``return_negative`` is false, both negative outputs are ``None``.
        """
        assert int(image_pil is not None) + int(image_tensor is not None) + int(image_embeds is not None) == 1
        if image_pil is not None:
            image_tensor = self.image_transform(image_pil).unsqueeze(0).to(self.device, dtype=self.dtype)
        if image_tensor is not None:
            if return_negative:
                image_tensor_neg = torch.zeros_like(image_tensor)
                image_tensor = torch.cat([image_tensor, image_tensor_neg], dim=0)
            image_embeds = self.visual_encoder(image_tensor)
        elif return_negative:
            # Build the zero image on the visual encoder's device/dtype (set in init_pipe).
            image_tensor_neg = torch.zeros(1, 3, image_size, image_size, device=self.device, dtype=self.dtype)
            image_embeds_neg = self.visual_encoder(image_tensor_neg)
            image_embeds_neg = image_embeds_neg.to(image_embeds.device, dtype=image_embeds.dtype)
            # One negative per positive, so chunk(2) below splits the batch
            # correctly for any batch size.
            image_embeds_neg = image_embeds_neg.expand(image_embeds.shape[0], *image_embeds_neg.shape[1:])
            image_embeds = torch.cat([image_embeds, image_embeds_neg], dim=0)

        if self.discrete_model is not None:
            image_embeds = self.discrete_model.encode_image_embeds(image_embeds)
        image_embeds, pooled_image_embeds = self.encode_image_embeds(image_embeds)

        if return_negative:
            image_embeds, image_embeds_neg = image_embeds.chunk(2)
            pooled_image_embeds, pooled_image_embeds_neg = pooled_image_embeds.chunk(2)
        else:
            image_embeds_neg = None
            pooled_image_embeds_neg = None
        return image_embeds, image_embeds_neg, pooled_image_embeds, pooled_image_embeds_neg

    def generate(self,
                 image_pil=None,
                 image_tensor=None,
                 image_embeds=None,
                 seed=42,
                 height=1024,
                 width=1024,
                 guidance_scale=7.5,
                 num_inference_steps=30,
                 input_image_size=448,
                 **kwargs):
        """Generate images conditioned on an image; extra ``kwargs`` go to the SDXL pipeline."""
        if image_pil is not None:
            assert isinstance(image_pil, Image.Image)
        image_prompt_embeds, uncond_image_prompt_embeds, pooled_image_prompt_embeds, \
            pooled_uncond_image_prompt_embeds = self.get_image_embeds(
                image_pil=image_pil,
                image_tensor=image_tensor,
                image_embeds=image_embeds,
                return_negative=True,
                image_size=input_image_size,
            )
        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.sdxl_pipe(
            prompt_embeds=image_prompt_embeds,
            negative_prompt_embeds=uncond_image_prompt_embeds,
            pooled_prompt_embeds=pooled_image_prompt_embeds,
            negative_pooled_prompt_embeds=pooled_uncond_image_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            height=height,
            width=width,
            **kwargs,
        ).images
        return images
class SDXLText2ImageAndEditAdapter(nn.Module):
    """Text-conditioned SDXL adapter whose UNet ``conv_in`` is widened to 8 input channels.

    ``resampler`` maps text embeddings (plus optional pooled embeddings) to the
    UNet's ``encoder_hidden_states`` and ``'text_embeds'`` added condition.

    * ``fully_ft=True``: the whole UNet is trainable.
    * Otherwise ``set_adapter`` makes trainable:
        - ``conv_in``;
        - every ``DownBlock2D``;
        - the ``to_k`` / ``to_v`` projections;
        - new LoRA layers on ``to_q`` and ``to_out[0]``.
    """

    def __init__(self, unet, resampler, lora_rank=16, fully_ft=False) -> None:
        super().__init__()
        self.unet = unet
        self.resampler = resampler
        self.lora_rank = lora_rank
        if fully_ft:
            self.set_fully_trainable()
        else:
            self.set_adapter()

    def _expand_conv_in(self, in_channels=8):
        """Replace ``unet.conv_in`` with a copy that accepts ``in_channels`` inputs.

        The pretrained weights fill the first input channels and the extra ones
        are zero.  The bias is copied, so on the original channels the new layer
        initially reproduces the pretrained layer's output.
        """
        old_conv_in = self.unet.conv_in
        self.unet.register_to_config(in_channels=in_channels)
        with torch.no_grad():
            new_conv_in = nn.Conv2d(in_channels, old_conv_in.out_channels, old_conv_in.kernel_size,
                                    old_conv_in.stride, old_conv_in.padding,
                                    device=old_conv_in.weight.device)
            new_conv_in.weight.zero_()
            new_conv_in.weight[:, :old_conv_in.in_channels, :, :].copy_(old_conv_in.weight)
            # nn.Conv2d initialises its bias randomly; keep the pretrained one instead.
            if old_conv_in.bias is not None:
                new_conv_in.bias.copy_(old_conv_in.bias)
            else:
                new_conv_in.bias.zero_()
        self.unet.conv_in = new_conv_in

    def set_adapter(self):
        """Freeze the UNet, then enable the adapter parameters listed in the class docstring."""
        self.unet.requires_grad_(False)
        adapter_parameters = []
        self._expand_conv_in()
        self.unet.conv_in.requires_grad_(True)
        print('Make conv_in trainable.')
        adapter_parameters.extend(self.unet.conv_in.parameters())

        for name, module in self.unet.named_modules():
            if isinstance(module, DownBlock2D):
                module.requires_grad_(True)
                adapter_parameters.extend(module.parameters())
                print('Make DownBlock2D trainable.')

        for attn_processor_name in self.unet.attn_processors:
            # Walk the dotted processor name (minus the trailing "processor")
            # down to the attention module that owns it.
            attn_module = self.unet
            for n in attn_processor_name.split(".")[:-1]:
                attn_module = getattr(attn_module, n)

            attn_module.to_q.set_lora_layer(
                LoRALinearLayer(in_features=attn_module.to_q.in_features,
                                out_features=attn_module.to_q.out_features,
                                rank=self.lora_rank))
            attn_module.to_out[0].set_lora_layer(
                LoRALinearLayer(
                    in_features=attn_module.to_out[0].in_features,
                    out_features=attn_module.to_out[0].out_features,
                    rank=self.lora_rank,
                ))
            # to_k / to_v are fully fine-tuned rather than LoRA-adapted.
            attn_module.to_k.requires_grad_(True)
            attn_module.to_v.requires_grad_(True)
            adapter_parameters.extend(attn_module.to_q.lora_layer.parameters())
            adapter_parameters.extend(attn_module.to_k.parameters())
            adapter_parameters.extend(attn_module.to_v.parameters())
            adapter_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
        self.adapter_parameters = adapter_parameters

    def set_fully_trainable(self):
        """Widen ``conv_in`` and make the whole UNet trainable."""
        self._expand_conv_in()
        self.unet.requires_grad_(True)
        # A list, not a generator: a generator would be exhausted after the first
        # params_to_opt() call and silently drop the UNet from later calls.
        self.adapter_parameters = list(self.unet.parameters())

    def params_to_opt(self):
        """Return an iterator over every parameter the optimizer should update."""
        return itertools.chain(self.resampler.parameters(), self.adapter_parameters)

    def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise, time_ids, pooled_text_embeds=None):
        """Predict noise from text conditioning and return the MSE loss against ``noise``.

        ``image_embeds`` is not used by this method.

        Returns:
            dict with ``'total_loss'`` (float32 MSE) and ``'noise_pred'``.
        """
        text_embeds, pooled_text_embeds = self.resampler(text_embeds, pooled_text_embeds=pooled_text_embeds)
        unet_added_conditions = {"time_ids": time_ids, 'text_embeds': pooled_text_embeds}
        noise_pred = self.unet(noisy_latents, timesteps, text_embeds, added_cond_kwargs=unet_added_conditions).sample
        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
        return {'total_loss': loss, 'noise_pred': noise_pred}

    def encode_text_embeds(self, text_embeds, pooled_text_embeds=None):
        """Return ``(tokens, pooled)`` produced by the resampler."""
        text_embeds, pooled_text_embeds = self.resampler(text_embeds, pooled_text_embeds=pooled_text_embeds)
        return text_embeds, pooled_text_embeds

    @classmethod
    def from_pretrained(cls, unet, resampler, pretrained_model_path=None, **kwargs):
        """Build the model and optionally load a checkpoint (non-strict).

        The checkpoint is read with ``torch.load`` (pickle-based): only load trusted files.
        """
        model = cls(unet=unet, resampler=resampler, **kwargs)
        if pretrained_model_path is not None:
            ckpt = torch.load(pretrained_model_path, map_location='cpu')
            missing, unexpected = model.load_state_dict(ckpt, strict=False)
            print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
        return model

    def init_pipe(self,
                  vae,
                  scheduler,
                  text_encoder,
                  text_encoder_2,
                  tokenizer,
                  tokenizer_2,
                  dtype=torch.float16,
                  device='cuda'):
        """Create the text-to-image/edit pipeline and attach the text encoders."""
        # Imported lazily: the module-level import of this pipeline is disabled in this
        # file, so the bare name is otherwise undefined here.
        try:
            from .pipeline_stable_diffusion_xl_t2i_edit import StableDiffusionXLText2ImageAndEditPipeline
        except ImportError as err:
            raise ImportError(
                'init_pipe requires StableDiffusionXLText2ImageAndEditPipeline from '
                'pipeline_stable_diffusion_xl_t2i_edit.py next to this module') from err
        self.device = device
        self.dtype = dtype
        sdxl_pipe = StableDiffusionXLText2ImageAndEditPipeline(
            tokenizer=None,
            tokenizer_2=None,
            text_encoder=None,
            text_encoder_2=None,
            vae=vae,
            unet=self.unet,
            scheduler=scheduler,
        )
        self.sdxl_pipe = sdxl_pipe
        self.sdxl_pipe.to(device, dtype=dtype)
        self.tokenizer = tokenizer
        self.tokenizer_2 = tokenizer_2
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2

    @torch.inference_mode()
    def get_text_embeds(self, prompt=None, negative_prompt='', text_embeds=None):
        """Return ``(pos, neg, pooled_pos, pooled_neg)`` after the resampler.

        Exactly one of ``prompt`` / ``text_embeds`` must be given.  Both
        tokenizers pad to ``self.tokenizer.model_max_length``, so the two
        encoders' hidden states have equal length and can be concatenated
        feature-wise.
        """
        assert int(prompt is not None) + int(text_embeds is not None) == 1
        if prompt is not None:
            text_input_ids = self.tokenizer([prompt, negative_prompt],
                                            max_length=self.tokenizer.model_max_length,
                                            padding="max_length",
                                            truncation=True,
                                            return_tensors="pt").input_ids
            text_input_ids_2 = self.tokenizer_2([prompt, negative_prompt],
                                                max_length=self.tokenizer.model_max_length,
                                                padding="max_length",
                                                truncation=True,
                                                return_tensors="pt").input_ids
            encoder_output = self.text_encoder(text_input_ids.to(self.device), output_hidden_states=True)
            text_embeds = encoder_output.hidden_states[-2]
            encoder_output_2 = self.text_encoder_2(text_input_ids_2.to(self.device), output_hidden_states=True)
            pooled_text_embeds = encoder_output_2[0]
            text_embeds_2 = encoder_output_2.hidden_states[-2]
            text_embeds = torch.cat([text_embeds, text_embeds_2], dim=-1)
        else:
            text_input_ids = self.tokenizer(negative_prompt,
                                            max_length=self.tokenizer.model_max_length,
                                            padding="max_length",
                                            truncation=True,
                                            return_tensors="pt").input_ids
            text_input_ids_2 = self.tokenizer_2(negative_prompt,
                                                max_length=self.tokenizer.model_max_length,
                                                padding="max_length",
                                                truncation=True,
                                                return_tensors="pt").input_ids
            encoder_output = self.text_encoder(text_input_ids.to(self.device), output_hidden_states=True)
            text_embeds_neg = encoder_output.hidden_states[-2]
            encoder_output_2 = self.text_encoder_2(text_input_ids_2.to(self.device), output_hidden_states=True)
            text_embeds_neg_2 = encoder_output_2.hidden_states[-2]
            # NOTE(review): only the negative prompt's pooled output is available here;
            # this relies on the resampler producing pooled outputs for the whole batch.
            pooled_text_embeds = encoder_output_2[0]
            text_embeds_neg = torch.cat([text_embeds_neg, text_embeds_neg_2], dim=-1)
            # One negative per supplied embedding, so chunk(2) below splits correctly.
            text_embeds_neg = text_embeds_neg.expand(text_embeds.shape[0], *text_embeds_neg.shape[1:])
            text_embeds = torch.cat([text_embeds, text_embeds_neg], dim=0)

        text_embeds, pooled_text_embeds = self.encode_text_embeds(text_embeds, pooled_text_embeds=pooled_text_embeds)
        text_embeds, text_embeds_neg = text_embeds.chunk(2)
        pooled_text_embeds, pooled_text_embeds_neg = pooled_text_embeds.chunk(2)
        return text_embeds, text_embeds_neg, pooled_text_embeds, pooled_text_embeds_neg

    def generate(self,
                 prompt=None,
                 negative_prompt='',
                 image=None,
                 text_embeds=None,
                 seed=42,
                 height=1024,
                 width=1024,
                 guidance_scale=7.5,
                 num_inference_steps=30,
                 **kwargs):
        """Generate (or edit ``image``) from a prompt; extra ``kwargs`` go to the pipeline."""
        text_embeds, text_embeds_neg, pooled_text_embeds, pooled_text_embeds_neg = self.get_text_embeds(
            prompt=prompt, negative_prompt=negative_prompt, text_embeds=text_embeds)
        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.sdxl_pipe(
            image=image,
            prompt_embeds=text_embeds,
            negative_prompt_embeds=text_embeds_neg,
            pooled_prompt_embeds=pooled_text_embeds,
            negative_pooled_prompt_embeds=pooled_text_embeds_neg,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            height=height,
            width=width,
            **kwargs,
        ).images
        return images
class SD21Text2ImageAndEditAdapter(SDXLText2ImageAndEditAdapter):
    """Variant for UNets without SDXL's added conditions.

    The resampler's second (pooled) output is discarded, and neither time ids
    nor pooled embeddings are passed to the UNet or the pipeline.
    """

    def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise):
        """Predict noise from text conditioning and return the MSE loss against ``noise``.

        ``image_embeds`` is not used by this method.
        """
        text_embeds, _ = self.resampler(text_embeds)
        noise_pred = self.unet(noisy_latents, timesteps, text_embeds).sample
        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
        return {'total_loss': loss, 'noise_pred': noise_pred}

    def init_pipe(self,
                  vae,
                  scheduler,
                  text_encoder,
                  tokenizer,
                  feature_extractor,
                  dtype=torch.float16,
                  device='cuda'):
        """Create the SD text-to-image/edit pipeline and attach the text encoder."""
        # Imported lazily: the module-level import of this pipeline is disabled in this
        # file, so the bare name is otherwise undefined here.
        try:
            from .pipeline_stable_diffusion_t2i_edit import StableDiffusionText2ImageAndEditPipeline
        except ImportError as err:
            raise ImportError(
                'init_pipe requires StableDiffusionText2ImageAndEditPipeline from '
                'pipeline_stable_diffusion_t2i_edit.py next to this module') from err
        self.device = device
        self.dtype = dtype
        sd_pipe = StableDiffusionText2ImageAndEditPipeline(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            unet=self.unet,
            feature_extractor=feature_extractor,
            safety_checker=None,
            requires_safety_checker=False,
            scheduler=scheduler,
        )
        self.sd_pipe = sd_pipe
        self.sd_pipe.to(device, dtype=dtype)
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder

    @torch.inference_mode()
    def get_text_embeds(self, prompt=None, negative_prompt='', text_embeds=None):
        """Return ``(pos, neg)`` text embeddings after the resampler.

        Exactly one of ``prompt`` / ``text_embeds`` must be given.  With
        ``text_embeds``, the encoded ``negative_prompt`` is broadcast to their
        batch size.
        """
        assert int(prompt is not None) + int(text_embeds is not None) == 1
        if prompt is not None:
            text_input_ids = self.tokenizer([prompt, negative_prompt],
                                            max_length=self.tokenizer.model_max_length,
                                            padding="max_length",
                                            truncation=True,
                                            return_tensors="pt").input_ids
            encoder_output = self.text_encoder(text_input_ids.to(self.device))
            text_embeds = encoder_output[0]
        else:
            text_input_ids = self.tokenizer(negative_prompt,
                                            max_length=self.tokenizer.model_max_length,
                                            padding="max_length",
                                            truncation=True,
                                            return_tensors="pt").input_ids
            encoder_output = self.text_encoder(text_input_ids.to(self.device))
            text_embeds_neg = encoder_output[0]
            # One negative per supplied embedding, so chunk(2) below splits correctly.
            text_embeds_neg = text_embeds_neg.expand(text_embeds.shape[0], *text_embeds_neg.shape[1:])
            text_embeds = torch.cat([text_embeds, text_embeds_neg], dim=0)
        text_embeds, _ = self.encode_text_embeds(text_embeds)
        text_embeds, text_embeds_neg = text_embeds.chunk(2)
        return text_embeds, text_embeds_neg

    def generate(self,
                 prompt=None,
                 negative_prompt='',
                 image=None,
                 text_embeds=None,
                 seed=42,
                 height=1024,
                 width=1024,
                 guidance_scale=7.5,
                 num_inference_steps=30,
                 **kwargs):
        """Generate (or edit ``image``) from a prompt; extra ``kwargs`` go to the pipeline."""
        text_embeds, text_embeds_neg = self.get_text_embeds(
            prompt=prompt, negative_prompt=negative_prompt, text_embeds=text_embeds)
        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.sd_pipe(
            image=image,
            prompt_embeds=text_embeds,
            negative_prompt_embeds=text_embeds_neg,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            height=height,
            width=width,
            **kwargs,
        ).images
        return images
class SDXLAdapterWithLatentImage(SDXLAdapter):
    """``SDXLAdapter`` whose UNet ``conv_in`` is widened to 8 input channels.

    ``generate`` passes ``latent_image`` to the pipeline as ``image``.  Presumably
    the extra channels carry that latent; confirm against the pipeline.  With
    ``set_trainable_late=True`` the widening is deferred until after
    ``from_pretrained`` has loaded a checkpoint into the original-width UNet.
    """

    def __init__(self, unet, resampler, full_ft=False, set_trainable_late=False) -> None:
        # Deliberately skips SDXLAdapter.__init__ (which would call set_trainable_v2).
        nn.Module.__init__(self)
        self.unet = unet
        self.resampler = resampler
        self.full_ft = full_ft
        if not set_trainable_late:
            self.set_trainable()

    def _expand_conv_in(self, in_channels=8):
        """Replace ``unet.conv_in`` with a copy that accepts ``in_channels`` inputs.

        The pretrained weights fill the first input channels and the extra ones
        are zero.  The bias is copied, so on the original channels the new layer
        initially reproduces the pretrained layer's output.
        """
        old_conv_in = self.unet.conv_in
        self.unet.register_to_config(in_channels=in_channels)
        with torch.no_grad():
            new_conv_in = nn.Conv2d(in_channels, old_conv_in.out_channels, old_conv_in.kernel_size,
                                    old_conv_in.stride, old_conv_in.padding,
                                    device=old_conv_in.weight.device)
            new_conv_in.weight.zero_()
            new_conv_in.weight[:, :old_conv_in.in_channels, :, :].copy_(old_conv_in.weight)
            # nn.Conv2d initialises its bias randomly; keep the pretrained one instead.
            if old_conv_in.bias is not None:
                new_conv_in.bias.copy_(old_conv_in.bias)
            else:
                new_conv_in.bias.zero_()
        self.unet.conv_in = new_conv_in

    def set_trainable(self):
        """Widen ``conv_in`` and record the trainable UNet parameters in ``self.adapter_parameters``.

        * ``full_ft=True``: the whole UNet is trainable.
        * Otherwise: ``conv_in`` plus every ``to_k`` / ``to_v`` projection.
        """
        self.resampler.requires_grad_(True)
        adapter_parameters = []
        self.unet.requires_grad_(False)
        self._expand_conv_in()
        self.unet.conv_in.requires_grad_(True)
        if self.full_ft:
            self.unet.requires_grad_(True)
            adapter_parameters.extend(self.unet.parameters())
        else:
            adapter_parameters.extend(self.unet.conv_in.parameters())
            for name, module in self.unet.named_modules():
                if name.endswith(('to_k', 'to_v')):
                    adapter_parameters.extend(module.parameters())
        self.adapter_parameters = adapter_parameters
        # The UNet was frozen above; unfreeze everything handed to the optimizer,
        # as SDXLAdapter.set_trainable_v2 does.
        for param in self.adapter_parameters:
            param.requires_grad_(True)

    @classmethod
    def from_pretrained(cls, unet, resampler, pretrained_model_path=None, set_trainable_late=False, **kwargs):
        """Build the model, optionally load a checkpoint (non-strict), then maybe widen ``conv_in``.

        The checkpoint is read with ``torch.load`` (pickle-based): only load trusted files.
        """
        model = cls(unet=unet, resampler=resampler, set_trainable_late=set_trainable_late, **kwargs)
        if pretrained_model_path is not None:
            ckpt = torch.load(pretrained_model_path, map_location='cpu')
            missing, unexpected = model.load_state_dict(ckpt, strict=False)
            print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
        if set_trainable_late:
            model.set_trainable()
        return model

    def init_pipe(self,
                  vae,
                  scheduler,
                  visual_encoder,
                  image_transform,
                  dtype=torch.float16,
                  device='cuda'):
        """Create the image-conditioned pipeline and attach the visual encoder."""
        # Imported lazily: the module-level import of this pipeline is disabled in this
        # file, so the bare name is otherwise undefined here.
        try:
            from .pipeline_stable_diffusion_xl_t2i_edit import StableDiffusionXLText2ImageAndEditPipeline
        except ImportError as err:
            raise ImportError(
                'init_pipe requires StableDiffusionXLText2ImageAndEditPipeline from '
                'pipeline_stable_diffusion_xl_t2i_edit.py next to this module') from err
        self.device = device
        self.dtype = dtype
        sdxl_pipe = StableDiffusionXLText2ImageAndEditPipeline(
            tokenizer=None,
            tokenizer_2=None,
            text_encoder=None,
            text_encoder_2=None,
            vae=vae,
            unet=self.unet,
            scheduler=scheduler,
        )
        self.sdxl_pipe = sdxl_pipe
        self.sdxl_pipe.to(device, dtype=dtype)
        self.discrete_model = None
        self.visual_encoder = visual_encoder.to(self.device, dtype=self.dtype)
        self.image_transform = image_transform

    def generate(self,
                 image_pil=None,
                 image_tensor=None,
                 image_embeds=None,
                 latent_image=None,
                 seed=42,
                 height=1024,
                 width=1024,
                 guidance_scale=7.5,
                 num_inference_steps=30,
                 input_image_size=448,
                 **kwargs):
        """Generate images conditioned on an image and ``latent_image``; extra ``kwargs`` go to the pipeline."""
        if image_pil is not None:
            assert isinstance(image_pil, Image.Image)
        image_prompt_embeds, uncond_image_prompt_embeds, \
            pooled_image_prompt_embeds, pooled_uncond_image_prompt_embeds = self.get_image_embeds(
                image_pil=image_pil,
                image_tensor=image_tensor,
                image_embeds=image_embeds,
                return_negative=True,
                image_size=input_image_size,
            )
        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.sdxl_pipe(
            image=latent_image,
            prompt_embeds=image_prompt_embeds,
            negative_prompt_embeds=uncond_image_prompt_embeds,
            pooled_prompt_embeds=pooled_image_prompt_embeds,
            negative_pooled_prompt_embeds=pooled_uncond_image_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            height=height,
            width=width,
            **kwargs,
        ).images
        return images