from jax import random |
import jax.numpy as np |
from jax.nn.initializers import lecun_normal |
from jax.numpy.linalg import eigh |
def make_HiPPO(N):
    """Build the HiPPO-LegS state matrix.

    From https://github.com/srush/annotated-s4/blob/main/s4/s4.py

    Entry (i, j) of the un-negated matrix is sqrt(2i+1) * sqrt(2j+1) below
    the diagonal, i + 1 on the diagonal and 0 above it; the returned matrix
    is its negation.

    Args:
        N (int32): state size

    Returns:
        N x N HiPPO LegS matrix
    """
    n = np.arange(N)
    root = np.sqrt(2 * n + 1)
    # Outer product root_i * root_j, then keep only the lower triangle.
    full = np.outer(root, root)
    # The diagonal of `full` is 2n + 1; subtracting n leaves n + 1.
    lower = np.tril(full) - np.diag(n)
    return -lower
def make_NPLR_HiPPO(N):
    """Return the pieces of the NPLR representation of HiPPO-LegS.

    From https://github.com/srush/annotated-s4/blob/main/s4/s4.py

    Args:
        N (int32): state size

    Returns:
        Tuple (A, P, B):
            A: N x N HiPPO LegS matrix from make_HiPPO(N)
            P: low-rank factor, sqrt(n + 1/2) for n = 0..N-1
            B: HiPPO input matrix, sqrt(2n + 1) for n = 0..N-1
    """
    n = np.arange(N)
    low_rank = np.sqrt(n + 0.5)
    input_vec = np.sqrt(2 * n + 1.0)
    return make_HiPPO(N), low_rank, input_vec
def make_DPLR_HiPPO(N):
    """Return the pieces of the DPLR representation of HiPPO-LegS.

    From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
    Note, we will only use the diagonal part.

    Args:
        N (int32): state size

    Returns:
        Tuple (Lambda, P, B, V, B_orig):
            Lambda: complex eigenvalues; the real part is the mean of the
                diagonal of S = A + P P^T repeated N times, the imaginary
                part comes from eigh(S * -1j)
            P: low-rank term projected with V^H
            B: HiPPO input matrix projected with V^H
            V: eigenvectors returned by eigh
            B_orig: HiPPO input matrix before the projection
    """
    hippo, low_rank, b_input = make_NPLR_HiPPO(N)

    # Add the rank-one term P P^T back onto the HiPPO matrix.
    shifted = hippo + low_rank[:, np.newaxis] * low_rank[np.newaxis, :]

    # Every eigenvalue gets the same real part: the mean diagonal entry.
    diag_entries = np.diagonal(shifted)
    real_part = np.ones_like(diag_entries) * np.mean(diag_entries)

    # eigh on S * -1j supplies the imaginary parts and the eigenvectors.
    imag_part, V = eigh(shifted * -1j)

    V_h = V.conj().T
    projected_P = V_h @ low_rank
    projected_B = V_h @ b_input
    eigenvalues = real_part + 1j * imag_part
    return eigenvalues, projected_P, projected_B, V, b_input
def log_step_initializer(dt_min=0.001, dt_max=0.1):
    """Create an initializer for the learnable log-timescale log(Delta).

    The returned function draws values uniformly between log(dt_min) and
    log(dt_max), i.e. it returns the log of the step, not the step itself.

    Args:
        dt_min (float32): minimum value
        dt_max (float32): maximum value

    Returns:
        init function
    """

    def init(key, shape):
        """Sample log steps.

        Args:
            key: jax random key
            shape tuple: desired shape

        Returns:
            sampled log_step (float32)
        """
        log_lo = np.log(dt_min)
        log_span = np.log(dt_max) - log_lo
        # Rescale U[0, 1) samples onto [log(dt_min), log(dt_max)).
        unit = random.uniform(key, shape)
        return unit * log_span + log_lo

    return init
def init_log_steps(key, input):
    """Initialize an array of learnable log-timescale parameters.

    One value is drawn per feature, each with its own subkey obtained by
    repeatedly splitting `key`.

    Args:
        key: jax random key
        input: tuple (H, dt_min, dt_max) holding the number of timescales
            and the bounds handed to log_step_initializer

    Returns:
        initialized array of log timescales (float32): (H, 1)
    """
    H, dt_min, dt_max = input
    # The bounds are fixed, so one initializer serves every draw.
    sample = log_step_initializer(dt_min=dt_min, dt_max=dt_max)

    draws = []
    for _ in range(H):
        key, subkey = random.split(key)
        draws.append(sample(subkey, shape=(1,)))

    return np.array(draws)
def init_VinvB(init_fun, rng, shape, Vinv):
    """Initialize B_tilde = V^{-1} B.

    B is sampled first with `init_fun`, then multiplied on the left by
    `Vinv`. The complex product is stored as a real array whose last axis
    holds the real part at index 0 and the imaginary part at index 1.

    Args:
        init_fun: the initialization function to use, e.g. lecun_normal()
        rng: jax random key to be used with init function.
        shape (tuple): desired shape (P,H)
        Vinv: (complex64) the inverse eigenvectors used for initialization

    Returns:
        B_tilde as a real array of shape (P,H,2): [..., 0] is the real
        part and [..., 1] the imaginary part
    """
    sampled_B = init_fun(rng, shape)
    projected = Vinv @ sampled_B
    return np.stack((projected.real, projected.imag), axis=-1)
def trunc_standard_normal(key, shape):
    """Sample C one row at a time with lecun_normal().

    `key` is split once per row, and every row is drawn with shape
    (1, P, 2) from its own subkey. The third entry of `shape` is ignored;
    the last axis of the result always has size 2.

    NOTE(review): the function name suggests a unit standard deviation,
    but sampling is delegated to lecun_normal(), which applies its own
    scaling -- confirm this is the intended distribution.

    Args:
        key: jax random key
        shape (tuple): desired shape, of length 3, (H,P,_)

    Returns:
        sampled C matrix (float32) of shape (H,P,2) (for complex parameterization)
    """
    H, P, _ = shape
    sampler = lecun_normal()

    rows = []
    for _step in range(H):
        key, subkey = random.split(key)
        rows.append(sampler(subkey, shape=(1, P, 2)))

    # Stacking gives (H, 1, P, 2); drop the singleton axis.
    return np.array(rows)[:, 0]
def init_CV(init_fun, rng, shape, V):
    """Initialize C_tilde = C V.

    C is sampled first with `init_fun`; the last axis of that sample is
    read as (real, imaginary) parts and combined into a complex matrix,
    which is then multiplied on the right by `V`. The complex product is
    stored as a real array with the same real/imaginary layout on its last
    axis.

    Args:
        init_fun: the initialization function to use, e.g. lecun_normal()
        rng: jax random key to be used with init function.
        shape (tuple): shape handed to init_fun; its output must carry a
            trailing axis holding the real and imaginary parts, e.g. (H,P,2)
        V: (complex64) the eigenvectors used for initialization

    Returns:
        C_tilde as a real array of shape (H,P,2): [..., 0] is the real
        part and [..., 1] the imaginary part
    """
    sampled = init_fun(rng, shape)
    real_part = sampled[..., 0]
    imag_part = sampled[..., 1]
    projected = (real_part + 1j * imag_part) @ V
    return np.stack((projected.real, projected.imag), axis=-1)