# Source snapshot: commit 2ce7b1a (5,147 bytes).
from jax import random
import jax.numpy as np
from jax.nn.initializers import lecun_normal
from jax.numpy.linalg import eigh
def make_HiPPO(N):
    """Create a HiPPO-LegS matrix.

    Args:
        N (int32): state size

    Returns:
        N x N HiPPO-LegS matrix, negated: zero above the diagonal,
        -(n + 1) on the diagonal and -sqrt((2n + 1)(2k + 1)) below it.
    """
    # P_n = sqrt(2n + 1); the LegS entries come from the outer product P P^T.
    P = np.sqrt(1 + 2 * np.arange(N))
    A = P[:, np.newaxis] * P[np.newaxis, :]
    # Keep the lower triangle (diagonal entries are P_n^2 = 2n + 1) and
    # subtract n on the diagonal, leaving n + 1 there.
    A = np.tril(A) - np.diag(np.arange(N))
    return -A
def make_NPLR_HiPPO(N):
    """Make the components needed for the NPLR representation of HiPPO-LegS.

    Args:
        N (int32): state size

    Returns:
        hippo: N x N HiPPO-LegS matrix, as returned by ``make_HiPPO``
        P: length-N low-rank factor, P_n = sqrt(n + 1/2)
        B: length-N HiPPO input matrix, B_n = sqrt(2n + 1)
    """
    # make_HiPPO already returns the negated LegS matrix (-HiPPO).
    hippo = make_HiPPO(N)

    # Rank-1 term: hippo + P P^T is normal (a constant diagonal plus a
    # skew-symmetric part).
    P = np.sqrt(np.arange(N) + 0.5)

    # HiPPO also specifies the B matrix.
    B = np.sqrt(2 * np.arange(N) + 1.0)
    return hippo, P, B
def make_DPLR_HiPPO(N):
    """Make the components needed for the DPLR representation of HiPPO-LegS.

    Note, we will only use the diagonal part.

    Args:
        N (int32): state size

    Returns:
        Lambda: length-N complex eigenvalues (real part is the mean of the
            diagonal of S, imaginary part comes from diagonalizing S)
        P: low-rank term projected onto the eigenbasis, V^* P
        B: HiPPO input matrix projected onto the eigenbasis, V^* B
        V: eigenvectors from the diagonalization
        B_orig: HiPPO input matrix B before projection
    """
    A, P, B = make_NPLR_HiPPO(N)

    # Normal part of the NPLR decomposition: S = A + P P^T.
    S = A + P[:, np.newaxis] * P[np.newaxis, :]

    # The diagonal of S is constant for LegS, so its mean is the shared real
    # part of every eigenvalue.
    S_diag = np.diagonal(S)
    Lambda_real = np.mean(S_diag) * np.ones_like(S_diag)

    # Diagonalize S to V \Lambda V^*. -i S is Hermitian apart from its
    # constant diagonal (carried by Lambda_real above), so eigh's real
    # eigenvalues give the imaginary parts of Lambda.
    # NOTE(review): relies on eigh reading only the Hermitian part of its
    # input — confirm against the jax.numpy.linalg.eigh docs.
    Lambda_imag, V = eigh(S * -1j)

    P = V.conj().T @ P
    B_orig = B
    B = V.conj().T @ B
    return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig
def log_step_initializer(dt_min=0.001, dt_max=0.1):
    """Initialize the learnable timescale Delta by sampling log(Delta)
    uniformly between log(dt_min) and log(dt_max).

    Args:
        dt_min (float32): minimum value
        dt_max (float32): maximum value

    Returns:
        init function
    """

    def init(key, shape):
        """Init function.

        Args:
            key: jax random key
            shape (tuple): desired shape

        Returns:
            sampled log_step (float32), i.e. log(Delta), not Delta itself
        """
        # Uniform in log space: timescales are spread evenly across orders of
        # magnitude between dt_min and dt_max.
        return random.uniform(key, shape) * (
            np.log(dt_max) - np.log(dt_min)
        ) + np.log(dt_min)

    return init
def init_log_steps(key, input):
    """Initialize an array of learnable timescale parameters.

    Args:
        key: jax random key
        input: tuple (H, dt_min, dt_max) giving the number of timescales H
            and the sampling range passed to ``log_step_initializer``

    Returns:
        initialized array of log timescales (float32) of shape (H, 1)
    """
    H, dt_min, dt_max = input
    # The initializer only depends on the range, so build it once.
    init = log_step_initializer(dt_min=dt_min, dt_max=dt_max)
    log_steps = []
    for _ in range(H):
        # Split off a fresh subkey per timescale so each draw is independent.
        key, skey = random.split(key)
        log_step = init(skey, shape=(1,))
        log_steps.append(log_step)

    return np.array(log_steps)
def init_VinvB(init_fun, rng, shape, Vinv):
    """Initialize B_tilde = V^{-1} B. First samples B, then computes V^{-1} B.

    The complex product is parameterized with two real matrices: its real and
    imaginary parts are stacked on a trailing axis.

    Args:
        init_fun: the initialization function to use, e.g. lecun_normal()
        rng: jax random key to be used with init function.
        shape (tuple): desired shape (P, H) of the sampled B
        Vinv: (complex64) the inverse eigenvectors used for initialization
            (presumably (P, P) — confirm with caller)

    Returns:
        B_tilde as a real array of shape (P, H, 2): [..., 0] is the real part
        and [..., 1] the imaginary part of V^{-1} B.
    """
    B = init_fun(rng, shape)
    VinvB = Vinv @ B
    return np.stack((VinvB.real, VinvB.imag), axis=-1)
def trunc_standard_normal(key, shape):
    """Sample C row by row with a truncated normal (``lecun_normal``).

    Each of the H rows is drawn with its own subkey.

    NOTE(review): lecun_normal scales by fan-in, so the per-entry standard
    deviation is likely 1/sqrt(P) rather than 1 — confirm against the jax
    initializer docs.

    Args:
        key: jax random key
        shape (tuple): desired shape, of length 3, (H, P, _); the last entry
            is ignored and the output always has a trailing axis of size 2

    Returns:
        sampled C matrix (float32) of shape (H, P, 2) (for complex parameterization)
    """
    H, P, _ = shape
    sampler = lecun_normal()
    Cs = []
    for _ in range(H):
        key, skey = random.split(key)
        C = sampler(skey, shape=(1, P, 2))
        Cs.append(C)
    # Stacked result is (H, 1, P, 2); drop the singleton axis -> (H, P, 2).
    return np.array(Cs)[:, 0]
def init_CV(init_fun, rng, shape, V):
    """Initialize C_tilde = C V. First samples C, then computes C V.

    C is sampled as a real array whose last axis holds the real ([..., 0]) and
    imaginary ([..., 1]) parts. The complex product is returned in the same
    two-real-matrix layout.

    Args:
        init_fun: the initialization function to use, e.g. lecun_normal()
        rng: jax random key to be used with init function.
        shape (tuple): shape passed to init_fun, e.g. (H, P, 2); its last
            axis carries the real/imaginary parts
        V: (complex64) the eigenvectors used for initialization

    Returns:
        C_tilde as a real array of shape (H, P, 2): [..., 0] is the real part
        and [..., 1] the imaginary part of C V.
    """
    C_ = init_fun(rng, shape)
    C = C_[..., 0] + 1j * C_[..., 1]
    CV = C @ V
    return np.stack((CV.real, CV.imag), axis=-1)