|
import random |
|
from typing import Callable, Dict, List, Optional |
|
|
|
import torch |
|
from diffusers import DiffusionPipeline |
|
from diffusers.configuration_utils import ConfigMixin |
|
|
|
|
|
class SuperDiffPipeline(DiffusionPipeline, ConfigMixin):
    """Pipeline that samples latents under the joint guidance of two prompts.

    Every sampling step evaluates the model with three text conditionings
    (``prompt_1``, ``prompt_2`` and the empty prompt), derives a per-sample
    mixing weight ``kappa`` from those outputs and combines them into one
    stochastic update of the latents.  The final latents are decoded by the
    VAE into ``uint8`` images.

    The usual entry point is ``__call__``, which chains ``preprocess``,
    ``_forward`` and ``postprocess``.
    """

    def __init__(
        self,
        model: Callable,
        vae: Callable,
        text_encoder: Callable,
        scheduler: Callable,
        tokenizer: Callable,
        **kwargs,
    ) -> None:
        """Store the components and move the neural modules to the device.

        Parameters
        ----------
        model : Callable
            Denoising network.  It is called as
            ``model(x, t, encoder_hidden_states=...)`` and must return an
            object with a ``.sample`` attribute; ``model.config.in_channels``
            is read to size the initial latents.
        vae : Callable
            Autoencoder providing ``decode`` and ``config.scaling_factor``.
        text_encoder : Callable
            Encoder applied to token ids; element ``[0]`` of its output is
            used as the prompt embedding.
        scheduler : Callable
            Scheduler providing ``set_timesteps``, ``timesteps``, ``sigmas``
            and ``init_noise_sigma``.
        tokenizer : Callable
            Tokenizer matching ``text_encoder``.
        kwargs :
            Accepted for call compatibility; currently unused.

        Returns
        -------
        None
        """
        super().__init__()
        self.model = model
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.scheduler = scheduler

        device = "cuda" if torch.cuda.is_available() else "cpu"

        self.vae.to(device)
        self.model.to(device)
        self.text_encoder.to(device)

        # The sampling settings are registered as None here; preprocess()
        # stores the per-call values as instance attributes.
        self.register_to_config(
            device=device,
            batch_size=None,
            num_inference_steps=None,
            guidance_scale=None,
            lift=None,
            seed=None,
        )

    @torch.no_grad()
    def get_batch(self, latents: torch.Tensor, nrow: int, ncol: int) -> torch.Tensor:
        """Decode latents with the VAE and convert them to ``uint8`` images.

        Parameters
        ----------
        latents : torch.Tensor
            Latents to decode; they are divided by
            ``vae.config.scaling_factor`` before decoding.
        nrow : int
            Unused; kept for interface compatibility.
        ncol : int
            Unused; kept for interface compatibility.

        Returns
        -------
        torch.Tensor
            ``uint8`` tensor with the decoder's axis 1 moved to the last
            position (``permute(0, 2, 3, 1)``).
        """
        image = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        # Rescale with x / 2 + 0.5 and clamp to [0, 1].  squeeze() drops every
        # size-1 axis (e.g. a batch of one), so the leading axis is restored
        # below whenever the rank fell under four.
        image = (image / 2 + 0.5).clamp(0, 1).squeeze()
        if image.dim() < 4:
            image = image.unsqueeze(0)
        # The cast to uint8 truncates (it does not round).
        image = (image.permute(0, 2, 3, 1) * 255).to(torch.uint8)
        return image

    @torch.no_grad()
    def get_text_embedding(self, prompt: str) -> torch.Tensor:
        """Tokenize ``prompt`` and return the text-encoder output.

        Parameters
        ----------
        prompt : str
            Prompt text.  ``preprocess`` also passes a list of strings (one
            entry per batch element), which is forwarded to the tokenizer
            unchanged.

        Returns
        -------
        torch.Tensor
            Element ``[0]`` of the text-encoder output for the padded and
            truncated token ids.
        """
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        return self.text_encoder(text_input.input_ids.to(self.device))[0]

    @torch.no_grad()
    def get_vel(self, t: float, sigma: float, latents: torch.Tensor, embeddings: List):
        """Evaluate the model on sigma-scaled latents.

        Parameters
        ----------
        t : float
            Timestep forwarded to the model.
        sigma : float
            Noise level for this step; the latents are divided by
            ``sqrt(sigma**2 + 1)`` before the model call.
        latents : torch.Tensor
            Current latents.
        embeddings : List
            Sequence of embedding tensors; they are concatenated with
            ``torch.cat`` and passed as ``encoder_hidden_states``.

        Returns
        -------
        The ``.sample`` attribute of the model output.
        """
        embeds = torch.cat(embeddings)
        scaled_latents = latents / ((sigma**2 + 1) ** 0.5)
        return self.model(scaled_latents, t, encoder_hidden_states=embeds).sample

    def preprocess(
        self,
        prompt_1: str,
        prompt_2: str,
        seed: Optional[int] = None,
        num_inference_steps: int = 1000,
        batch_size: int = 1,
        lift: float = 0.0,
        height: int = 512,
        width: int = 512,
        guidance_scale: float = 7.5,
    ) -> Dict[str, torch.Tensor]:
        """Encode the prompts, seed the RNG and draw the initial latents.

        The sampling settings are stored on the instance because ``_forward``
        and ``postprocess`` read them from ``self``.

        Parameters
        ----------
        prompt_1 : str
            First prompt; its embedding is returned as ``"obj_embeddings"``.
        prompt_2 : str
            Second prompt; its embedding is returned as ``"bg_embeddings"``.
        seed : Optional[int]
            Seed for the random number generators.  When ``None`` a random
            seed in ``[0, 2**32 - 1]`` is drawn and stored in ``self.seed``.
        num_inference_steps : int
            Number of steps passed to ``scheduler.set_timesteps``.
        batch_size : int
            Number of samples generated for the prompt pair.
        lift : float
            Additive term used in the ``kappa`` computation of ``_forward``.
        height : int
            Requested height; the latent height is ``height // 8``.
        width : int
            Requested width; the latent width is ``width // 8``.
        guidance_scale : float
            Guidance weight used in ``_forward``.

        Returns
        -------
        Dict[str, torch.Tensor]
            Dictionary with the keys ``"latents"``, ``"obj_embeddings"``,
            ``"uncond_embeddings"`` and ``"bg_embeddings"``.
        """
        self.batch_size = batch_size
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.lift = lift
        self.seed = seed
        if self.seed is None:
            self.seed = random.randint(0, 2**32 - 1)

        obj_embeddings = self.get_text_embedding([prompt_1] * batch_size)
        bg_embeddings = self.get_text_embedding([prompt_2] * batch_size)
        uncond_embeddings = self.get_text_embedding([""] * batch_size)

        # Seed the default generators on every device.  torch.cuda.manual_seed
        # returns None and is silently ignored without CUDA, so it can neither
        # serve as a ``generator=`` argument nor make CPU runs reproducible.
        # Seeding the defaults also covers the torch.randn_like noise drawn at
        # every step of _forward.
        torch.manual_seed(self.seed)
        latents = torch.randn(
            (batch_size, self.model.config.in_channels, height // 8, width // 8),
            device=self.device,
        )

        self.scheduler.set_timesteps(num_inference_steps)
        latents = latents * self.scheduler.init_noise_sigma

        return {
            "latents": latents,
            "obj_embeddings": obj_embeddings,
            "uncond_embeddings": uncond_embeddings,
            "bg_embeddings": bg_embeddings,
        }

    def _forward(self, model_inputs: Dict) -> torch.Tensor:
        """Run the sampling loop over all scheduler timesteps.

        Parameters
        ----------
        model_inputs : Dict
            Dictionary produced by ``preprocess``.  The tensor stored under
            ``"latents"`` is updated in place.

        Returns
        -------
        torch.Tensor
            The latents after the last step.
        """
        latents = model_inputs["latents"]
        obj_embeddings = model_inputs["obj_embeddings"]
        uncond_embeddings = model_inputs["uncond_embeddings"]
        bg_embeddings = model_inputs["bg_embeddings"]

        history_shape = (self.num_inference_steps + 1, self.batch_size)
        kappa = 0.5 * torch.ones(history_shape, device=self.device)
        # NOTE(review): ll_obj / ll_bg are accumulated below but never read or
        # returned; they look like running log-density estimates - confirm
        # whether they should be exposed or can be dropped.
        ll_obj = torch.ones(history_shape, device=self.device)
        ll_bg = torch.ones(history_shape, device=self.device)

        with torch.no_grad():
            for i, t in enumerate(self.scheduler.timesteps):
                sigma = self.scheduler.sigmas[i]
                dsigma = self.scheduler.sigmas[i + 1] - sigma

                vel_obj = self.get_vel(t, sigma, latents, [obj_embeddings])
                vel_uncond = self.get_vel(t, sigma, latents, [uncond_embeddings])
                vel_bg = self.get_vel(t, sigma, latents, [bg_embeddings])

                noise = torch.sqrt(2 * torch.abs(dsigma) * sigma) * torch.randn_like(
                    latents
                )

                # Update that guidance towards the second prompt alone would
                # produce; it enters the numerator of kappa.
                dx_ind = (
                    2
                    * dsigma
                    * (vel_uncond + self.guidance_scale * (vel_bg - vel_uncond))
                    + noise
                )

                # Per-sample mixing weight for this step.
                # NOTE(review): presumably chosen so that both prompts see the
                # same estimated log-density change - confirm against the
                # derivation.  The denominator is zero when vel_obj == vel_bg.
                kappa[i + 1] = (
                    (torch.abs(dsigma) * (vel_bg - vel_obj) * (vel_bg + vel_obj)).sum(
                        (1, 2, 3)
                    )
                    - (dx_ind * (vel_obj - vel_bg)).sum((1, 2, 3))
                    + sigma * self.lift / self.num_inference_steps
                )
                kappa[i + 1] /= (
                    2
                    * dsigma
                    * self.guidance_scale
                    * ((vel_obj - vel_bg) ** 2).sum((1, 2, 3))
                )

                # Guided field: unconditional output plus the scaled mix of
                # the two prompt directions, weighted by kappa.
                vf = vel_uncond + self.guidance_scale * (
                    (vel_bg - vel_uncond)
                    + kappa[i + 1][:, None, None, None] * (vel_obj - vel_bg)
                )
                dx = 2 * dsigma * vf + noise
                latents += dx

                ll_obj[i + 1] = ll_obj[i] + (
                    -torch.abs(dsigma) / sigma * (vel_obj) ** 2
                    - (dx * (vel_obj / sigma))
                ).sum((1, 2, 3))
                ll_bg[i + 1] = ll_bg[i] + (
                    -torch.abs(dsigma) / sigma * (vel_bg) ** 2
                    - (dx * (vel_bg / sigma))
                ).sum((1, 2, 3))

        return latents

    def postprocess(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode the final latents into ``uint8`` images.

        Parameters
        ----------
        latents : torch.Tensor
            Latents returned by ``_forward``.

        Returns
        -------
        torch.Tensor
            ``uint8`` image tensor whose last axis has size 3.
        """
        image = self.get_batch(latents, 1, self.batch_size)

        assert image.shape[-1] == 3

        # No-op when get_batch already returned uint8.
        image = image.to(torch.uint8)

        return image

    def __call__(
        self,
        prompt_1: str,
        prompt_2: str,
        seed: Optional[int] = None,
        num_inference_steps: int = 1000,
        batch_size: int = 1,
        lift: float = 0.0,
        height: int = 512,
        width: int = 512,
        guidance_scale: float = 7.5,
    ) -> torch.Tensor:
        """Generate images for a pair of prompts.

        Runs ``preprocess``, ``_forward`` and ``postprocess`` in sequence.

        Parameters
        ----------
        prompt_1 : str
            First prompt.
        prompt_2 : str
            Second prompt.
        seed : Optional[int]
            Seed for the random number generators; a random one is drawn when
            ``None``.
        num_inference_steps : int
            Number of sampling steps.
        batch_size : int
            Number of images to generate.
        lift : float
            Additive term used in the ``kappa`` computation.
        height : int
            Requested height; the latent height is ``height // 8``.
        width : int
            Requested width; the latent width is ``width // 8``.
        guidance_scale : float
            Guidance weight.

        Returns
        -------
        torch.Tensor
            ``uint8`` image tensor whose last axis has size 3.
        """
        model_inputs = self.preprocess(
            prompt_1,
            prompt_2,
            seed,
            num_inference_steps,
            batch_size,
            lift,
            height,
            width,
            guidance_scale,
        )

        latents = self._forward(model_inputs)

        images = self.postprocess(latents)
        return images
|
|