File size: 2,331 Bytes
f0e6b7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d89243e
 
 
 
 
 
 
 
 
f0e6b7a
 
 
 
d89243e
 
 
 
 
 
 
 
f280910
 
f0e6b7a
 
 
 
 
f280910
f0e6b7a
 
 
 
 
 
 
 
 
 
 
f280910
 
f0e6b7a
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch

default_num_train_timesteps = 1000  # default length of the training schedule built by make_sigmas


@torch.no_grad()
def make_sigmas(beta_start=0.00085, beta_end=0.012, num_train_timesteps=default_num_train_timesteps, device=None):
    """Build the per-timestep noise levels (sigmas) for a square-root-linear beta schedule.

    Betas are spaced linearly between ``sqrt(beta_start)`` and ``sqrt(beta_end)`` and
    then squared. The returned sigma for timestep ``t`` is
    ``sqrt((1 - alpha_bar_t) / alpha_bar_t)``, where ``alpha_bar_t`` is the cumulative
    product of ``1 - beta`` up to and including ``t``.

    Args:
        beta_start: beta at the first timestep.
        beta_end: beta at the last timestep.
        num_train_timesteps: number of entries in the schedule.
        device: torch device on which the schedule is created.

    Returns:
        A float32 tensor of shape ``(num_train_timesteps,)`` indexed by timestep.
    """
    root_betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32, device=device)
    alpha_bar = torch.cumprod(1.0 - root_betas**2, dim=0)

    # TODO: a closed-form expression for the sigmas would avoid the cumprod.
    return ((1 - alpha_bar) / alpha_bar) ** 0.5


# Module-wide switch: when truthy, the solver loop wraps its iteration in tqdm.
_with_tqdm = False


def set_with_tqdm(it):
    """Turn the tqdm progress bar for the ODE solver loops on or off.

    Args:
        it: the new value of the module flag; it is read (for truthiness) each
            time a solver loop starts.
    """
    global _with_tqdm
    _with_tqdm = it


def _sigma_to_t(sigma, sigmas):
    """Map noise level(s) ``sigma`` to fractional timestep(s) using the ``sigmas`` table.

    Interpolates linearly in log-sigma between the two bracketing table entries, so
    a value equal to ``sigmas[k]`` maps back to ``k``. Values outside the table's
    range are clamped to its first/last index.

    Assumes ``sigmas`` is a 1-D, strictly increasing, positive table indexed by
    timestep (as ``make_sigmas`` builds) and that ``sigma`` has the same dtype.
    """
    log_sigmas = sigmas.log()
    log_sigma = sigma.log()
    # First table index whose entry is >= sigma, kept in [1, len - 1] so that
    # (high - 1, high) is always a valid bracketing pair.
    high = torch.searchsorted(log_sigmas, log_sigma).clamp(1, log_sigmas.shape[0] - 1)
    low = high - 1
    w = (log_sigma - log_sigmas[low]) / (log_sigmas[high] - log_sigmas[low])
    return low.to(sigma.dtype) + w.clamp(0, 1)


@torch.no_grad()
def rk_ode_solver_diffusion_loop(eps_theta, timesteps, sigmas, x_T, rk_steps_weights):
    """Sample by integrating the diffusion ODE in sigma with an explicit Runge-Kutta scheme.

    The state is treated as ``x = x_0 + sigma * eps`` (the final step recovers
    ``x_0_hat = x_t - sigma * eps_hat``). The ODE integrated here is therefore
    ``dx/dsigma = eps_theta(x, sigma)``. This matches ``d = (x - D(x)) / sigma``
    with the denoiser ``D(x) = x - sigma * eps_hat``.

    Args:
        eps_theta: noise-prediction callable, invoked as
            ``eps_theta(x_t=..., t=..., sigma=...)``. Its result must be
            addable in place to a tensor shaped like ``x_T``. ``t`` is a
            length-1 tensor; for Runge-Kutta stages away from a grid point it is
            a fractional timestep whose table sigma equals the ``sigma`` passed.
        timesteps: 1-D integer tensor of timesteps indexing ``sigmas``. It is walked
            from the last entry down to the first (presumably ascending, so the
            walk goes from high to low noise; TODO confirm with callers).
        sigmas: tensor of noise levels indexed by timestep (e.g. from ``make_sigmas``).
            Intermediate stages assume it is strictly increasing and positive.
        x_T: starting sample, at noise level ``sigmas[timesteps[-1]]``.
        rk_steps_weights: sequence of ``(c, b)`` pairs. Stage k is evaluated at
            ``sigma + c * dsigma``, starting from ``x + k_prev * c * dsigma``, where
            ``k_prev`` is the previous stage's derivative (zero for the first
            stage). The step then uses ``dsigma * sum(b_k * k_k)``. This covers
            Euler, Heun and classical RK4.

    Returns:
        ``x_0_hat``, the clean-sample estimate at ``timesteps[0]``.

    Raises:
        ValueError: if ``timesteps`` is empty.
    """
    if len(timesteps) == 0:
        raise ValueError("timesteps must contain at least one timestep")

    x_t = x_T

    iter_over = range(len(timesteps) - 1, -1, -1)

    if _with_tqdm:
        from tqdm import tqdm

        iter_over = tqdm(iter_over)

    for i in iter_over:
        t = timesteps[i].unsqueeze(0)
        sigma = sigmas[t]

        if i == 0:
            # Last grid point: jump straight to the clean-sample estimate.
            eps_hat = eps_theta(x_t=x_t, t=t, sigma=sigma)
            x_0_hat = x_t - sigma * eps_hat
        else:
            # Signed change in sigma from this grid point to the next one.
            d_sigma = sigmas[timesteps[i - 1]] - sigma

            dx_by_dsigma = torch.zeros_like(x_t)
            k_prev = torch.zeros_like(x_t)

            for rk_step, rk_weight in rk_steps_weights:
                d_sigma_ = d_sigma * rk_step
                if rk_step == 0:
                    # The stage sits exactly on the current grid point.
                    sigma_ = sigma
                    t_ = t.to(sigma.dtype)
                else:
                    # Evaluate the model at the stage's own noise level. The timestep
                    # must be derived from that sigma, not by adding a sigma
                    # increment to t (the two live in different units).
                    sigma_ = sigma + d_sigma_
                    t_ = _sigma_to_t(sigma_, sigmas)
                x_t_ = x_t + k_prev * d_sigma_
                k_prev = eps_theta(x_t=x_t_, t=t_, sigma=sigma_)
                dx_by_dsigma += k_prev * rk_weight

            x_t = x_t + dx_by_dsigma * d_sigma

    return x_0_hat


def euler_ode_solver_diffusion_loop(*args, **kwargs):
    """Run ``rk_ode_solver_diffusion_loop`` with the one-stage forward-Euler tableau.

    Accepts every argument of ``rk_ode_solver_diffusion_loop`` except
    ``rk_steps_weights``, which this wrapper supplies.
    """
    return rk_ode_solver_diffusion_loop(*args, **kwargs, rk_steps_weights=[[0, 1]])


def heun_ode_solver_diffusion_loop(*args, **kwargs):
    """Run ``rk_ode_solver_diffusion_loop`` with Heun's two-stage (trapezoidal) tableau.

    Accepts every argument of ``rk_ode_solver_diffusion_loop`` except
    ``rk_steps_weights``, which this wrapper supplies.
    """
    return rk_ode_solver_diffusion_loop(*args, **kwargs, rk_steps_weights=[[0, 0.5], [1, 0.5]])


def rk4_ode_solver_diffusion_loop(*args, **kwargs):
    """Run ``rk_ode_solver_diffusion_loop`` with the classical four-stage RK4 tableau.

    Accepts every argument of ``rk_ode_solver_diffusion_loop`` except
    ``rk_steps_weights``, which this wrapper supplies.
    """
    return rk_ode_solver_diffusion_loop(
        *args,
        **kwargs,
        rk_steps_weights=[[0, 1 / 6], [1 / 2, 1 / 3], [1 / 2, 1 / 3], [1, 1 / 6]],
    )