|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.distributed as dist |
|
from einops import rearrange |
|
from torch.distributions import LogisticNormal |
|
from tqdm import tqdm |
|
|
|
|
|
def _extract_into_tensor(arr, timesteps, broadcast_shape): |
|
""" |
|
Extract values from a 1-D numpy array for a batch of indices. |
|
:param arr: the 1-D numpy array. |
|
:param timesteps: a tensor of indices into the array to extract. |
|
:param broadcast_shape: a larger shape of K dimensions with the batch |
|
dimension equal to the length of timesteps. |
|
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. |
|
""" |
|
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float() |
|
while len(res.shape) < len(broadcast_shape): |
|
res = res[..., None] |
|
return res + torch.zeros(broadcast_shape, device=timesteps.device) |
|
|
|
|
|
def mean_flat(tensor: torch.Tensor, mask=None):
    """
    Average ``tensor`` over every dimension except the batch dimension.

    :param tensor: tensor whose first dimension is the batch.
    :param mask: optional ``(B, T)`` per-frame mask. When given, ``tensor`` must be
        5-D ``(B, C, T, H, W)``; only frames selected by the mask contribute, and the
        sum is normalized by (number of selected frames) * C * H * W.
    :return: tensor of shape ``(B,)``.
    """
    if mask is None:
        reduce_dims = list(range(1, tensor.dim()))
        return tensor.mean(dim=reduce_dims)

    assert tensor.dim() == 5
    assert tensor.shape[2] == mask.shape[1]
    # (b, c, t, h, w) -> (b, t, c*h*w): one flattened feature row per frame.
    per_frame = tensor.permute(0, 2, 1, 3, 4).flatten(start_dim=2)
    features_per_frame = per_frame.shape[-1]
    normalizer = mask.sum(dim=1) * features_per_frame
    masked_total = (per_frame * mask.unsqueeze(2)).sum(dim=1).sum(dim=1)
    return masked_total / normalizer
|
|
|
|
|
def timestep_transform(
    t,
    model_kwargs,
    base_resolution=512 * 512,
    base_num_frames=1,
    scale=1.0,
    num_timesteps=1,
):
    """
    Shift timesteps toward the noisy end for larger resolutions / longer clips.

    With ``s = t / num_timesteps`` and a ratio ``r`` built from the spatial size,
    the frame count and ``scale``, the result is ``r * s / (1 + (r - 1) * s)``
    rescaled back by ``num_timesteps``. ``r == 1`` leaves ``t`` unchanged.

    :param t: timestep tensor on ``[0, num_timesteps]``.
    :param model_kwargs: mapping with tensor entries ``"height"``, ``"width"`` and ``"num_frames"``.
    :param base_resolution: pixel count at which the spatial ratio is 1.
    :param base_num_frames: frame count at which the temporal ratio is 1.
    :param scale: extra multiplier applied to the ratio.
    :param num_timesteps: length of the timestep axis.
    :return: transformed timesteps on the same scale as ``t``.
    """
    frames = model_kwargs["num_frames"]
    if frames[0] == 1:
        # Single-frame (image) batches get no temporal shift.
        effective_frames = torch.ones_like(frames)
    else:
        # Map raw frame counts to a compressed length (17 input frames -> 5).
        effective_frames = frames // 17 * 5

    normalized = t / num_timesteps
    pixel_count = model_kwargs["height"] * model_kwargs["width"]
    spatial_ratio = (pixel_count / base_resolution).sqrt()
    temporal_ratio = (effective_frames / base_num_frames).sqrt()
    shift = spatial_ratio * temporal_ratio * scale

    shifted = shift * normalized / (1 + (shift - 1) * normalized)
    return shifted * num_timesteps
|
|
|
|
|
class RFlowScheduler:
    """
    Rectified-flow scheduler: training-timestep sampling, noising and the training loss.

    Timesteps are expressed on ``[0, num_timesteps]``. ``add_noise`` maps ``t = 0``
    to the clean sample and ``t = num_timesteps`` to pure noise.
    """

    def __init__(
        self,
        num_timesteps=1000,
        num_sampling_steps=10,
        use_discrete_timesteps=False,
        sample_method="uniform",
        loc=0.0,
        scale=1.0,
        use_timestep_transform=False,
        transform_scale=1.0,
    ):
        """
        :param num_timesteps: length of the timestep axis.
        :param num_sampling_steps: stored on the instance; not used by this class's methods.
        :param use_discrete_timesteps: draw integer training timesteps in ``[0, num_timesteps)``.
        :param sample_method: ``"uniform"`` or ``"logit-normal"`` continuous timestep sampling.
        :param loc: location of the logistic-normal distribution (``"logit-normal"`` only).
        :param scale: scale of the logistic-normal distribution (``"logit-normal"`` only).
        :param use_timestep_transform: apply ``timestep_transform`` to sampled training timesteps.
        :param transform_scale: ``scale`` forwarded to ``timestep_transform``.
        """
        self.num_timesteps = num_timesteps
        self.num_sampling_steps = num_sampling_steps
        self.use_discrete_timesteps = use_discrete_timesteps

        assert sample_method in ["uniform", "logit-normal"]
        assert (
            sample_method == "uniform" or not use_discrete_timesteps
        ), "Only uniform sampling is supported for discrete timesteps"
        self.sample_method = sample_method
        if sample_method == "logit-normal":
            self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
            # Take the first coordinate of each 2-way simplex sample: one value in (0, 1) per batch item.
            self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)

        self.use_timestep_transform = use_timestep_transform
        self.transform_scale = transform_scale

    def training_losses(self, model, x_start, model_kwargs=None, noise=None, mask=None, weights=None, t=None):
        """
        Compute the rectified-flow training loss for one timestep per sample.

        Arguments format copied from opensora/schedulers/iddpm/gaussian_diffusion.py/training_losses

        :param model: callable ``model(x_t, t, **model_kwargs)``; the first half of its
            output along dim 1 is taken as the velocity prediction.
        :param x_start: clean latents; the mask path requires 5-D ``(B, C, T, H, W)``.
        :param model_kwargs: extra keyword arguments for ``model`` (``None`` means none).
        :param noise: optional noise with the same shape as ``x_start``; drawn if ``None``.
        :param mask: optional ``(B, T)`` per-frame mask; frames where it is False are kept
            clean in the model input and excluded from the loss.
        :param weights: optional 1-D numpy array of per-timestep loss weights indexed by
            ``t`` (indexing presumably requires integer ``t``, i.e. discrete timesteps).
        :param t: optional timesteps on ``[0, num_timesteps]``. When given they are used
            as-is. When ``None`` they are sampled: integers from ``torch.randint`` in
            discrete mode, otherwise continuous values; ``timestep_transform`` is then
            applied if enabled.
        :return: dict whose ``"loss"`` entry is a per-sample loss of shape ``(B,)``.
        """
        # Normalize first so timestep_transform never receives None.
        if model_kwargs is None:
            model_kwargs = {}

        if t is None:
            batch_size = x_start.shape[0]
            if self.use_discrete_timesteps:
                t = torch.randint(0, self.num_timesteps, (batch_size,), device=x_start.device)
            elif self.sample_method == "uniform":
                t = torch.rand((batch_size,), device=x_start.device) * self.num_timesteps
            elif self.sample_method == "logit-normal":
                t = self.sample_t(x_start) * self.num_timesteps

            if self.use_timestep_transform:
                t = timestep_transform(t, model_kwargs, scale=self.transform_scale, num_timesteps=self.num_timesteps)

        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape

        x_t = self.add_noise(x_start, noise, t)
        if mask is not None:
            # Masked-out frames see the clean sample (t = 0) instead of the noised one.
            t0 = torch.zeros_like(t)
            x_t0 = self.add_noise(x_start, noise, t0)
            x_t = torch.where(mask[:, None, :, None, None], x_t, x_t0)

        terms = {}
        model_output = model(x_t, t, **model_kwargs)
        velocity_pred = model_output.chunk(2, dim=1)[0]
        # Velocity target: direction from noise to data.
        target = x_start - noise
        squared_error = (velocity_pred - target).pow(2)
        if weights is None:
            loss = mean_flat(squared_error, mask=mask)
        else:
            weight = _extract_into_tensor(weights, t, x_start.shape)
            loss = mean_flat(weight * squared_error, mask=mask)
        terms["loss"] = loss

        return terms

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Linearly interpolate between data and noise (compatible with diffusers add_noise()).

        Returns ``(1 - s) * original_samples + s * noise`` with ``s = timesteps / num_timesteps``
        per sample. Works for inputs of any rank whose first dimension is the batch.

        :param original_samples: clean samples, batch first.
        :param noise: noise with the same shape as ``original_samples``.
        :param timesteps: 1-D (or scalar) timesteps on ``[0, num_timesteps]``.
        :return: noised samples with the same shape as ``noise``.
        """
        timepoints = 1 - timesteps.float() / self.num_timesteps
        # Reshape to (B, 1, 1, ...) so the per-sample weight broadcasts over the remaining dims.
        timepoints = timepoints.reshape(-1, *([1] * (noise.dim() - 1)))
        return timepoints * original_samples + (1 - timepoints) * noise
|
|
|
|
|
class RFLOW:
    """
    Rectified-flow sampler (Euler integration) with classifier-free guidance.

    Builds an ``RFlowScheduler`` from the same settings (extra ``**kwargs`` go to it);
    ``training_losses`` delegates to that scheduler.
    """

    def __init__(
        self,
        num_sampling_steps=10,
        num_timesteps=1000,
        cfg_scale=4.0,
        use_discrete_timesteps=False,
        use_timestep_transform=False,
        **kwargs,
    ):
        self.num_sampling_steps = num_sampling_steps
        self.num_timesteps = num_timesteps
        self.cfg_scale = cfg_scale
        self.use_discrete_timesteps = use_discrete_timesteps
        self.use_timestep_transform = use_timestep_transform

        self.scheduler = RFlowScheduler(
            num_timesteps=num_timesteps,
            num_sampling_steps=num_sampling_steps,
            use_discrete_timesteps=use_discrete_timesteps,
            use_timestep_transform=use_timestep_transform,
            **kwargs,
        )

    def sample(
        self,
        model,
        z,
        model_args,
        y_null,
        device,
        mask=None,
        guidance_scale=None,
        progress=True,
        verbose=False,
    ):
        """
        Integrate the learned velocity field from ``t = num_timesteps`` down to 0.

        :param model: callable ``model(z_in, t, all_timesteps, **model_args)``. It must expose
            ``model.x_embedder.proj.weight``, whose dtype is used to round the timestep list.
            The first half of its output along dim 1 is the velocity prediction.
        :param z: initial latents; the ``[:, None, None, None, None]`` indexing below
            requires a 5-D tensor (presumably ``(B, C, T, H, W)``).
        :param model_args: conditioning kwargs; must contain ``"y"``. The caller's mapping is
            not modified: a shallow copy receives the CFG-doubled ``"y"`` and ``"x_mask"``.
        :param y_null: unconditional embedding concatenated after ``model_args["y"]`` on dim 0.
        :param device: device on which the timestep tensors are created.
        :param mask: optional ``(B, T)`` per-frame mask. A frame is updated only while
            ``mask * num_timesteps >= t``; otherwise it is held at its current value.
            Frames with ``mask == 1`` are treated as already noised.
        :param guidance_scale: CFG scale; defaults to ``self.cfg_scale``.
        :param progress: show a tqdm bar (only on rank 0 when torch.distributed is initialized).
        :param verbose: unused; kept for interface compatibility.
        :return: the sampled latents, same shape as ``z``.
        """
        if guidance_scale is None:
            guidance_scale = self.cfg_scale

        # Shallow copy so the "y" concat and "x_mask" writes below do not leak into the
        # caller's dict (a second call would otherwise double-concatenate "y").
        model_args = dict(model_args)
        # Conditional and unconditional branches run as one doubled batch.
        model_args["y"] = torch.cat([model_args["y"], y_null], 0)

        # Evenly spaced timesteps from num_timesteps down to num_timesteps / num_sampling_steps.
        timesteps = [(1.0 - i / self.num_sampling_steps) * self.num_timesteps for i in range(self.num_sampling_steps)]
        if self.use_discrete_timesteps:
            timesteps = [int(round(t)) for t in timesteps]
        timesteps = [torch.tensor([t] * z.shape[0], device=device) for t in timesteps]
        if self.use_timestep_transform:
            # Use the same shift scale the scheduler applies during training.
            timesteps = [
                timestep_transform(
                    t, model_args, scale=self.scheduler.transform_scale, num_timesteps=self.num_timesteps
                )
                for t in timesteps
            ]

        if mask is not None:
            noise_added = torch.zeros_like(mask, dtype=torch.bool)
            noise_added = noise_added | (mask == 1)

        # dist.get_rank() raises when no process group exists, so only consult it when initialized.
        show_progress = progress and (not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0)
        progress_wrap = tqdm if show_progress else (lambda x: x)

        dtype = model.x_embedder.proj.weight.dtype
        # One scalar per step. Every sample is given the same schedule above (timestep_transform
        # could only make them differ if per-sample sizes differ), so take the first element;
        # calling .item() on the whole batch tensor fails for batch sizes > 1.
        all_timesteps = [int(t.flatten()[0].to(dtype).item()) for t in timesteps]
        for i, t in progress_wrap(list(enumerate(timesteps))):
            if mask is not None:
                mask_t = mask * self.num_timesteps
                x0 = z.clone()
                x_noise = self.scheduler.add_noise(x0, torch.randn_like(x0), t)

                mask_t_upper = mask_t >= t.unsqueeze(1)
                model_args["x_mask"] = mask_t_upper.repeat(2, 1)
                # Noise each frame exactly once, the first step it becomes active.
                mask_add_noise = mask_t_upper & ~noise_added

                z = torch.where(mask_add_noise[:, None, :, None, None], x_noise, x0)
                noise_added = mask_t_upper

            z_in = torch.cat([z, z], 0)
            t = torch.cat([t, t], 0)

            output = model(z_in, t, all_timesteps, **model_args)

            pred = output.chunk(2, dim=1)[0]
            pred_cond, pred_uncond = pred.chunk(2, dim=0)
            v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

            # Euler step; the last step integrates all the way to t = 0.
            dt = timesteps[i] - timesteps[i + 1] if i < len(timesteps) - 1 else timesteps[i]
            dt = dt / self.num_timesteps
            z = z + v_pred * dt[:, None, None, None, None]

            if mask is not None:
                z = torch.where(mask_t_upper[:, None, :, None, None], z, x0)

        return z

    def training_losses(self, model, x_start, model_kwargs=None, noise=None, mask=None, weights=None, t=None):
        """Delegate to ``RFlowScheduler.training_losses`` with the same arguments."""
        return self.scheduler.training_losses(model, x_start, model_kwargs, noise, mask, weights, t)
|
|