Spaces:
Running
Running
import torch | |
from utils.manifolds import Sphere | |
from tqdm.auto import tqdm | |
def riemannian_flow_sampler(
    net,
    batch,
    manifold=None,
    conditioning_keys=None,
    scheduler=None,
    num_steps=250,
    cfg_rate=0,
    generator=None,
    return_trajectories=False,
):
    """Integrate a learned flow on a manifold with explicit Euler steps.

    Starting from ``batch["y"]``, a uniform time grid runs from 1 down to 0
    in ``num_steps`` steps and is mapped through ``scheduler`` to ``gamma``
    values. At each step ``net`` predicts a velocity. The state is moved by
    ``(gamma_next - gamma_now) * velocity`` in the ambient space and then
    mapped back with ``manifold.projx``. The projection is used in place of
    an exponential map.

    Args:
        net: Callable taking a dict with ``"y"`` (current states) and
            ``"gamma"`` (the current gamma expanded to one value per row).
            It returns a velocity tensor that is added to the states, so it
            must broadcast with ``"y"``.
        batch: Dict-like holding the initial states under ``"y"`` plus any
            extra entries ``net`` needs. It is not modified. In the guided
            branch only the conditioning entry, ``"y"`` and ``"gamma"`` are
            passed to ``net``.
        manifold: Object providing ``projx``. ``None`` (the default) uses a
            fresh ``Sphere()``.
        conditioning_keys: Key in ``batch`` of the conditioning tensor used
            for classifier-free guidance. Guidance is applied only when this
            is given AND ``cfg_rate > 0``; otherwise sampling is
            unconditional.
        scheduler: Callable that maps the 1-D tensor of times to a tensor of
            ``gamma`` values, one per time. Required.
        num_steps: Number of Euler steps.
        cfg_rate: Guidance weight ``w``. The velocity becomes
            ``(1 + w) * v_cond - w * v_uncond``, where the unconditional
            pass replaces the conditioning tensor with zeros.
        generator: Currently unused; accepted for interface compatibility.
        return_trajectories: If true, also return the list of states. The
            list holds the initial state plus one entry per step.

    Returns:
        The final states as float32. When ``return_trajectories`` is true,
        returns ``(states, trajectory)`` instead.

    Raises:
        ValueError: If ``scheduler`` is not provided.
    """
    if scheduler is None:
        raise ValueError("Scheduler must be provided")
    if manifold is None:
        # Resolved at call time rather than once at import time.
        manifold = Sphere()

    x_cur = batch["y"].to(torch.float32)
    traj = [x_cur.detach()] if return_trajectories else None

    # Time grid 1 -> 0 (inclusive of both ends), mapped through the scheduler.
    step_indices = torch.arange(num_steps + 1, dtype=torch.float32, device=x_cur.device)
    steps = 1 - step_indices / num_steps
    gammas = scheduler(steps)
    dtype = torch.float32

    use_cfg = cfg_rate > 0 and conditioning_keys is not None
    if use_cfg:
        # First half: real conditioning. Second half: zeros (unconditional).
        # Built once because it does not change between steps.
        net_batch = {
            conditioning_keys: torch.cat(
                [batch[conditioning_keys], torch.zeros_like(batch[conditioning_keys])],
                dim=0,
            )
        }
    else:
        # Shallow copy so intermediate states are never written back into
        # the caller's batch; a later call then still starts from the
        # original batch["y"].
        net_batch = dict(batch)

    for gamma_now, gamma_next in zip(gammas[:-1], gammas[1:]):
        with torch.autocast("cuda", dtype=dtype):
            if use_cfg:
                net_batch["y"] = torch.cat([x_cur, x_cur], dim=0)
                net_batch["gamma"] = gamma_now.expand(x_cur.shape[0] * 2)
                denoised_cond, denoised_uncond = net(net_batch).chunk(2, dim=0)
                denoised = denoised_cond * (1 + cfg_rate) - denoised_uncond * cfg_rate
            else:
                net_batch["y"] = x_cur
                net_batch["gamma"] = gamma_now.expand(x_cur.shape[0])
                denoised = net(net_batch)
        dt = gamma_next - gamma_now
        # Euler step in the ambient space, then project back onto the
        # manifold (stand-in for manifold.expmap(x_cur, dt * denoised)).
        x_cur = manifold.projx(x_cur + dt * denoised)
        if traj is not None:
            traj.append(x_cur.detach().to(torch.float32))

    x_cur = x_cur.to(torch.float32)
    if return_trajectories:
        return x_cur, traj
    return x_cur
def ode_riemannian_flow_sampler(
    odefunc,
    x_1,
    manifold=None,
    scheduler=None,
    num_steps=1000,
):
    """Integrate ``odefunc`` on a manifold with explicit Euler steps.

    Times run uniformly from 0 to 1 in ``num_steps`` steps. At each step
    ``odefunc(t, x)`` gives a velocity. The state moves by
    ``(scheduler(t_next) - scheduler(t_now)) * velocity`` in the ambient
    space and is then mapped back with ``manifold.projx``.

    NOTE(review): this grid runs 0 -> 1, while ``riemannian_flow_sampler``
    runs 1 -> 0. Confirm this matches the time convention ``odefunc`` expects.

    Args:
        odefunc: Callable ``(t, x) -> velocity``. ``t`` is a 0-d float32
            tensor holding the left end of the current step.
        x_1: Initial states; cast to float32.
        manifold: Object providing ``projx``. ``None`` (the default) uses a
            fresh ``Sphere()``.
        scheduler: Callable mapping a time tensor to ``gamma``. Required.
        num_steps: Number of Euler steps.

    Returns:
        The final states as a float32 tensor.

    Raises:
        ValueError: If ``scheduler`` is not provided.
    """
    if scheduler is None:
        raise ValueError("Scheduler must be provided")
    if manifold is None:
        # Resolved at call time rather than once at import time.
        manifold = Sphere()

    x_cur = x_1.to(torch.float32)
    times = (
        torch.arange(num_steps + 1, dtype=torch.float32, device=x_cur.device)
        / num_steps
    )
    dtype = torch.float32
    # Plain zip: the previous enumerate(..., total=...) raised TypeError,
    # since enumerate has no `total` parameter.
    for t_now, t_next in zip(times[:-1], times[1:]):
        with torch.autocast("cuda", dtype=dtype):
            denoised = odefunc(t_now, x_cur)
        gamma_now = scheduler(t_now)
        gamma_next = scheduler(t_next)
        dt = gamma_next - gamma_now
        # Euler step in the ambient space, then project back onto the
        # manifold (stand-in for manifold.expmap(x_cur, dt * denoised)).
        x_cur = manifold.projx(x_cur + dt * denoised)
    return x_cur.to(torch.float32)