from jax import random
import jax.numpy as np
from jax.nn.initializers import lecun_normal
from jax.numpy.linalg import eigh
def make_HiPPO(N):
    """ Create a HiPPO-LegS matrix.
        From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
        Args:
            N (int32): state size
        Returns:
            N x N HiPPO LegS matrix
    """
    P = np.sqrt(1 + 2 * np.arange(N))
    A = P[:, np.newaxis] * P[np.newaxis, :]
    A = np.tril(A) - np.diag(np.arange(N))
    return -A
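
# Quick sanity check (an illustrative sketch, not part of the original
# module). The entries follow A[n, k] = sqrt(2n+1)*sqrt(2k+1) for n > k and
# n+1 for n = k, negated on return, e.g. for N = 3:
#
#   A = make_HiPPO(3)
#   # [[-1.     0.     0.   ]
#   #  [-1.732 -2.     0.   ]
#   #  [-2.236 -3.873 -3.   ]]
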
def make_NPLR_HiPPO(N):
    """ Makes components needed for NPLR representation of HiPPO-LegS
        From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
        Args:
            N (int32): state size
        Returns:
            N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B
    """
    # Make -HiPPO
    hippo = make_HiPPO(N)

    # Add in a rank 1 term. Makes it Normal.
    P = np.sqrt(np.arange(N) + 0.5)

    # HiPPO also specifies the B matrix
    B = np.sqrt(2 * np.arange(N) + 1.0)
    return hippo, P, B
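
# Why the rank-1 term matters (a sketch using only this module): S = A + P P^T
# is normal because it is skew-symmetric apart from its diagonal, which is
# what lets make_DPLR_HiPPO below diagonalize it with a unitary matrix.
#
#   A, P, B = make_NPLR_HiPPO(4)
#   S = A + P[:, np.newaxis] * P[np.newaxis, :]
#   np.allclose(S + S.T, -np.eye(4))  # True: off-diagonals cancel, diag is -1/2
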
def make_DPLR_HiPPO(N):
    """ Makes components needed for DPLR representation of HiPPO-LegS
        From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
        Note, we will only use the diagonal part
        Args:
            N (int32): state size
        Returns:
            eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B,
            eigenvectors V, HiPPO B pre-conjugation
    """
    A, P, B = make_NPLR_HiPPO(N)

    # Add in the rank 1 term to form the normal matrix S
    S = A + P[:, np.newaxis] * P[np.newaxis, :]

    # Every diagonal entry of S equals -1/2, so the real part of every
    # eigenvalue is -1/2; take the mean for numerical robustness
    S_diag = np.diagonal(S)
    Lambda_real = np.mean(S_diag) * np.ones_like(S_diag)

    # Diagonalize S to V Lambda V^*
    Lambda_imag, V = eigh(S * -1j)

    # Rotate P and B into the eigenbasis of S
    P = V.conj().T @ P
    B_orig = B
    B = V.conj().T @ B
    return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig
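
# Verification sketch (illustrative only): Lambda and V reconstruct the
# normal matrix S, and every eigenvalue has real part -1/2.
#
#   Lambda, P, B, V, B_orig = make_DPLR_HiPPO(8)
#   A, P0, B0 = make_NPLR_HiPPO(8)
#   S = A + P0[:, np.newaxis] * P0[np.newaxis, :]
#   np.allclose(V @ np.diag(Lambda) @ V.conj().T, S, atol=1e-4)  # True
#   np.allclose(Lambda.real, -0.5)  # True
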
def log_step_initializer(dt_min=0.001, dt_max=0.1):
    """ Initialize the learnable timescale Delta by sampling
        log(Delta) uniformly between log(dt_min) and log(dt_max).
        Args:
            dt_min (float32): minimum value
            dt_max (float32): maximum value
        Returns:
            init function
    """
    def init(key, shape):
        """ Init function
            Args:
                key: jax random key
                shape (tuple): desired shape
            Returns:
                sampled log_step (float32)
        """
        return random.uniform(key, shape) * (
            np.log(dt_max) - np.log(dt_min)
        ) + np.log(dt_min)

    return init
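
# Example (a sketch): log(Delta) is uniform on [log(dt_min), log(dt_max)],
# so Delta itself is log-uniformly distributed in [dt_min, dt_max].
#
#   init = log_step_initializer(dt_min=0.001, dt_max=0.1)
#   log_dt = init(random.PRNGKey(0), shape=(1,))
#   dt = np.exp(log_dt)  # always lies in [0.001, 0.1]
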
def init_log_steps(key, input):
    """ Initialize an array of learnable timescale parameters
        Args:
            key: jax random key
            input: tuple containing the array shape H and
                   dt_min and dt_max
        Returns:
            initialized array of timescales (float32): (H, 1)
    """
    H, dt_min, dt_max = input
    log_steps = []
    for i in range(H):
        # Use a fresh subkey for each timescale
        key, skey = random.split(key)
        log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)(skey, shape=(1,))
        log_steps.append(log_step)

    return np.array(log_steps)
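
# Usage sketch (hypothetical values): one timescale per state dimension.
#
#   log_steps = init_log_steps(random.PRNGKey(0), (4, 0.001, 0.1))
#   # log_steps.shape == (4, 1); np.exp(log_steps) lies in [0.001, 0.1]
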
def init_VinvB(init_fun, rng, shape, Vinv):
    """ Initialize B_tilde=V^{-1}B. First samples B. Then computes V^{-1}B.
        Note we will parameterize this with two different matrices for complex
        numbers.
        Args:
            init_fun: the initialization function to use, e.g. lecun_normal()
            rng: jax random key to be used with init function.
            shape (tuple): desired shape (P,H)
            Vinv: (complex64) the inverse eigenvectors used for initialization
        Returns:
            B_tilde (float32) of shape (P,H,2), stacking the real and
            imaginary parts of V^{-1}B along the last axis
    """
    B = init_fun(rng, shape)
    VinvB = Vinv @ B
    VinvB_real = VinvB.real
    VinvB_imag = VinvB.imag
    return np.concatenate((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1)
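
# Usage sketch (hypothetical shapes; V comes from make_DPLR_HiPPO above, and
# since V is unitary its inverse is just its conjugate transpose):
#
#   P, H = 8, 4
#   _, _, _, V, _ = make_DPLR_HiPPO(P)
#   B_tilde = init_VinvB(lecun_normal(), random.PRNGKey(1), (P, H), V.conj().T)
#   # B_tilde.shape == (P, H, 2): [..., 0] real part, [..., 1] imaginary part
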
def trunc_standard_normal(key, shape):
    """ Sample C with a truncated normal distribution with standard deviation 1.
        Args:
            key: jax random key
            shape (tuple): desired shape, of length 3, (H,P,_)
        Returns:
            sampled C matrix (float32) of shape (H,P,2) (for complex parameterization)
    """
    H, P, _ = shape
    Cs = []
    for i in range(H):
        # Sample each of the H rows with its own subkey
        key, skey = random.split(key)
        C = lecun_normal()(skey, shape=(1, P, 2))
        Cs.append(C)
    return np.array(Cs)[:, 0]
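
# Usage sketch: the trailing axis of size 2 holds the real and imaginary
# channels of the complex parameterization.
#
#   C = trunc_standard_normal(random.PRNGKey(2), (4, 8, 2))
#   # C.shape == (4, 8, 2), float32
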
def init_CV(init_fun, rng, shape, V):
    """ Initialize C_tilde=CV. First sample C. Then compute CV.
        Note we will parameterize this with two different matrices for complex
        numbers.
        Args:
            init_fun: the initialization function to use, e.g. lecun_normal()
            rng: jax random key to be used with init function.
            shape (tuple): desired shape (H,P,2), where the last axis holds
                the real and imaginary parts
            V: (complex64) the eigenvectors used for initialization
        Returns:
            C_tilde (float32) of shape (H,P,2), stacking the real and
            imaginary parts of CV along the last axis
    """
    C_ = init_fun(rng, shape)
    C = C_[..., 0] + 1j * C_[..., 1]
    CV = C @ V
    CV_real = CV.real
    CV_imag = CV.imag
    return np.concatenate((CV_real[..., None], CV_imag[..., None]), axis=-1)
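
# A runnable smoke test tying the initializers together (an illustrative
# sketch, not part of the original module; P=8 and H=4 are arbitrary).
if __name__ == "__main__":
    P, H = 8, 4
    # Build the DPLR decomposition; V is unitary, so Vinv = V^*
    Lambda, _, _, V, _ = make_DPLR_HiPPO(P)
    Vinv = V.conj().T
    # Initialize the (real, imag)-stacked input and output matrices
    B_tilde = init_VinvB(lecun_normal(), random.PRNGKey(0), (P, H), Vinv)
    C_tilde = init_CV(trunc_standard_normal, random.PRNGKey(1), (H, P, 2), V)
    log_steps = init_log_steps(random.PRNGKey(2), (P, 0.001, 0.1))
    print(Lambda.shape, B_tilde.shape, C_tilde.shape, log_steps.shape)
    # Expected: (8,) (8, 4, 2) (4, 8, 2) (8, 1)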