# Plonk/models/samplers/riemannian_flow_sampler.py
# Last commit: c4c7cee by nicolas-dufour -- "squash: merge all unpushed commits"
# File size: 2.87 kB
import torch
from utils.manifolds import Sphere
from tqdm.auto import tqdm
def riemannian_flow_sampler(
    net,
    batch,
    manifold=Sphere(),
    conditioning_keys=None,
    scheduler=None,
    num_steps=250,
    cfg_rate=0,
    generator=None,
    return_trajectories=False,
):
    """Sample by stepping ``batch["y"]`` along the network output and projecting.

    The time grid runs from 1 down to 0 in ``num_steps`` equal increments and
    is mapped through ``scheduler`` to obtain the gamma values. At every step
    the network is evaluated at the current gamma, the state is moved by
    ``(gamma_next - gamma_now) * net_output`` and then passed through
    ``manifold.projx``.

    Args:
        net: Callable taking a dict with at least ``"y"`` and ``"gamma"``
            entries and returning a tensor that broadcasts against ``"y"``.
        batch: Dict holding the starting point under ``"y"``. When
            classifier-free guidance is inactive, its ``"y"`` and ``"gamma"``
            entries are overwritten in place on every step.
        manifold: Object exposing ``projx``; used to project each iterate.
        conditioning_keys: A single key of ``batch`` (used directly as a dict
            key despite the plural name). Required for guidance to be active.
        scheduler: Callable mapping the tensor of times to gamma values.
        num_steps: Number of integration steps.
        cfg_rate: Guidance strength; guidance is active only when this is
            ``> 0`` and ``conditioning_keys`` is not ``None``.
        generator: Accepted but not used by this function.
        return_trajectories: When true, also return the list of iterates
            (starting point included).

    Returns:
        The final float32 state, or ``(final_state, trajectory)`` when
        ``return_trajectories`` is true.

    Raises:
        ValueError: If ``scheduler`` is ``None``.
    """
    if scheduler is None:
        raise ValueError("Scheduler must be provided")

    state = batch["y"].to(torch.float32)
    history = [state.detach()] if return_trajectories else None

    # Uniform grid 1 -> 0, mapped through the scheduler once up front.
    grid = torch.arange(num_steps + 1, dtype=torch.float32, device=state.device)
    gammas = scheduler(1 - grid / num_steps)
    dtype = torch.float32

    # Guidance needs both a positive rate and a conditioning entry to blank out.
    use_cfg = cfg_rate > 0 and conditioning_keys is not None
    if use_cfg:
        conditioning = batch[conditioning_keys]
        # First half keeps the real conditioning, second half is all zeros
        # (the "unconditional" input). Only this key, "y" and "gamma" are
        # forwarded to the network in guided mode.
        stacked_batch = {
            conditioning_keys: torch.cat(
                [conditioning, torch.zeros_like(conditioning)], dim=0
            )
        }

    for gamma_now, gamma_next in zip(gammas[:-1], gammas[1:]):
        with torch.cuda.amp.autocast(dtype=dtype):
            if use_cfg:
                batch_size = state.shape[0]
                stacked_batch["y"] = torch.cat([state, state], dim=0)
                stacked_batch["gamma"] = gamma_now.expand(batch_size * 2)
                cond_out, uncond_out = net(stacked_batch).chunk(2, dim=0)
                # Extrapolate away from the unconditional prediction.
                net_out = cond_out * (1 + cfg_rate) - uncond_out * cfg_rate
            else:
                batch["y"] = state
                batch["gamma"] = gamma_now.expand(state.shape[0])
                net_out = net(batch)

        # Euler step in ambient coordinates followed by projection; the
        # original source notes manifold.expmap(x, dt * out) as the alternative.
        step_size = gamma_next - gamma_now
        state = manifold.projx(state + step_size * net_out)

        if return_trajectories:
            history.append(state.detach().to(torch.float32))

    final_state = state.to(torch.float32)
    if return_trajectories:
        return final_state, history
    return final_state
def ode_riemannian_flow_sampler(
    odefunc,
    x_1,
    manifold=Sphere(),
    scheduler=None,
    num_steps=1000,
):
    """Integrate ``odefunc`` from ``x_1`` with Euler steps and projection.

    The time grid runs from 0 up to 1 in ``num_steps`` equal increments. For
    each consecutive pair of times, ``odefunc`` is evaluated at the earlier
    time, the state is moved by
    ``(scheduler(t_next) - scheduler(t_now)) * odefunc(t_now, x)`` and then
    passed through ``manifold.projx``. A tqdm progress bar is shown while
    iterating.

    Args:
        odefunc: Callable ``odefunc(t, x)`` returning a tensor that broadcasts
            against ``x``.
        x_1: Starting tensor; converted to float32.
        manifold: Object exposing ``projx``; used to project each iterate.
        scheduler: Callable mapping a scalar time tensor to a gamma value.
        num_steps: Number of integration steps.

    Returns:
        The final state as a float32 tensor.

    Raises:
        ValueError: If ``scheduler`` is ``None``.
    """
    if scheduler is None:
        raise ValueError("Scheduler must be provided")

    x_cur = x_1.to(torch.float32)
    steps = (
        torch.arange(num_steps + 1, dtype=torch.float32, device=x_cur.device)
        / num_steps
    )
    dtype = torch.float32

    # ``total`` is a tqdm argument, not an ``enumerate`` one (enumerate only
    # takes ``iterable`` and ``start``). zip() has no length, so tqdm needs
    # the step count passed explicitly to render progress.
    for t_now, t_next in tqdm(zip(steps[:-1], steps[1:]), total=num_steps):
        with torch.cuda.amp.autocast(dtype=dtype):
            denoised = odefunc(t_now, x_cur)

        # Step size is measured in scheduler (gamma) units, not raw time.
        dt = scheduler(t_next) - scheduler(t_now)

        # Euler step in ambient coordinates followed by projection; the
        # original source notes manifold.expmap(x_cur, dt * denoised) as the
        # alternative.
        x_cur = manifold.projx(x_cur + dt * denoised)

    return x_cur.to(torch.float32)