# s5/ssm_init.py
from jax import random
import jax.numpy as np
from jax.nn.initializers import lecun_normal
from jax.numpy.linalg import eigh


def make_HiPPO(N):
""" Create a HiPPO-LegS matrix.
From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
Args:
N (int32): state size
Returns:
N x N HiPPO LegS matrix
"""
    # Entry (n, k) of the matrix built here is sqrt((2n+1)(2k+1)) for n > k
    # and n + 1 on the diagonal; HiPPO-LegS is its negation.
    P = np.sqrt(1 + 2 * np.arange(N))
    A = P[:, np.newaxis] * P[np.newaxis, :]
    A = np.tril(A) - np.diag(np.arange(N))
return -A
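
# Example (a sketch, values rounded to two decimals): make_HiPPO(3) returns
#   [[-1.  ,  0.  ,  0.  ],
#    [-1.73, -2.  ,  0.  ],
#    [-2.24, -3.87, -3.  ]]
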
def make_NPLR_HiPPO(N):
"""
Makes components needed for NPLR representation of HiPPO-LegS
From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
Args:
N (int32): state size
Returns:
N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B
"""
# Make -HiPPO
hippo = make_HiPPO(N)
    # Add in a rank-1 term; this makes the matrix normal
    P = np.sqrt(np.arange(N) + 0.5)
# HiPPO also specifies the B matrix
B = np.sqrt(2 * np.arange(N) + 1.0)
return hippo, P, B
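
# Note (added for clarity): hippo + P P^T is a normal matrix, namely a
# skew-symmetric matrix plus a constant diagonal, which is exactly what the
# normal-plus-low-rank (NPLR) form hippo = S - P P^T requires.
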
def make_DPLR_HiPPO(N):
"""
Makes components needed for DPLR representation of HiPPO-LegS
From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
Note, we will only use the diagonal part
Args:
            N (int32): state size
Returns:
eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B,
eigenvectors V, HiPPO B pre-conjugation
"""
A, P, B = make_NPLR_HiPPO(N)
S = A + P[:, np.newaxis] * P[np.newaxis, :]
    S_diag = np.diagonal(S)
    # S has a constant diagonal, so every eigenvalue shares that real part
    Lambda_real = np.mean(S_diag) * np.ones_like(S_diag)
    # Diagonalize S to V \Lambda V^*: S minus its diagonal is skew-symmetric,
    # so -1j * S is (effectively) Hermitian and eigh yields Lambda_imag and V
    Lambda_imag, V = eigh(S * -1j)
    # Rotate the low-rank factor and input matrix into the eigenbasis of S
    P = V.conj().T @ P
    B_orig = B
    B = V.conj().T @ B
return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig
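
# Sanity check (a sketch, not used elsewhere in this module): since
# S = V diag(Lambda) V^*, the HiPPO matrix can be rebuilt from the pieces:
#   Lambda, P, B, V, B_orig = make_DPLR_HiPPO(8)
#   P0 = np.sqrt(np.arange(8) + 0.5)  # low-rank factor before conjugation
#   A_recon = (V @ np.diag(Lambda) @ V.conj().T).real - P0[:, None] * P0[None, :]
#   # A_recon should match make_HiPPO(8) up to numerical error
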
def log_step_initializer(dt_min=0.001, dt_max=0.1):
""" Initialize the learnable timescale Delta by sampling
uniformly between dt_min and dt_max.
Args:
dt_min (float32): minimum value
dt_max (float32): maximum value
Returns:
init function
"""
def init(key, shape):
""" Init function
Args:
key: jax random key
                shape (tuple): desired shape
Returns:
sampled log_step (float32)
"""
return random.uniform(key, shape) * (
np.log(dt_max) - np.log(dt_min)
) + np.log(dt_min)
return init
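
# Usage (a sketch): draw one log-timescale uniformly from
# [log(dt_min), log(dt_max)):
#   init = log_step_initializer(dt_min=0.001, dt_max=0.1)
#   log_dt = init(random.PRNGKey(0), shape=(1,))
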
def init_log_steps(key, input):
""" Initialize an array of learnable timescale parameters
Args:
key: jax random key
        input: tuple (H, dt_min, dt_max) containing the number of
               timescales H and the timescale range dt_min, dt_max
    Returns:
        initialized array of log timescales (float32) of shape (H, 1)
"""
H, dt_min, dt_max = input
log_steps = []
for i in range(H):
key, skey = random.split(key)
log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)(skey, shape=(1,))
log_steps.append(log_step)
return np.array(log_steps)
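
# Note (added for clarity): each sampled log_step has shape (1,), so stacking
# H of them yields shape (H, 1); callers typically index [:, 0] when they
# need a flat (H,) vector.
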
def init_VinvB(init_fun, rng, shape, Vinv):
""" Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B.
Note we will parameterize this with two different matrices for complex
numbers.
Args:
init_fun: the initialization function to use, e.g. lecun_normal()
rng: jax random key to be used with init function.
shape (tuple): desired shape (P,H)
Vinv: (complex64) the inverse eigenvectors used for initialization
Returns:
        B_tilde of shape (P,H,2): real and imaginary parts stacked in the last axis
"""
B = init_fun(rng, shape)
VinvB = Vinv @ B
VinvB_real = VinvB.real
VinvB_imag = VinvB.imag
return np.concatenate((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1)
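
# Usage (a sketch; P, H, and rng are placeholders, and in practice Vinv comes
# from make_DPLR_HiPPO, whose V is unitary so that Vinv = V.conj().T):
#   Lambda, _, _, V, _ = make_DPLR_HiPPO(P)
#   B_tilde = init_VinvB(lecun_normal(), rng, (P, H), V.conj().T)
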
def trunc_standard_normal(key, shape):
""" Sample C with a truncated normal distribution with standard deviation 1.
Args:
key: jax random key
shape (tuple): desired shape, of length 3, (H,P,_)
Returns:
sampled C matrix (float32) of shape (H,P,2) (for complex parameterization)
"""
H, P, _ = shape
Cs = []
for i in range(H):
key, skey = random.split(key)
C = lecun_normal()(skey, shape=(1, P, 2))
Cs.append(C)
return np.array(Cs)[:, 0]
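
# Note (added for clarity): this delegates to lecun_normal, whose underlying
# distribution is a truncated normal scaled by fan-in (not unit standard
# deviation); looping over rows gives each of the H rows its own subkey.
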
def init_CV(init_fun, rng, shape, V):
""" Initialize C_tilde=CV. First sample C. Then compute CV.
Note we will parameterize this with two different matrices for complex
numbers.
Args:
init_fun: the initialization function to use, e.g. lecun_normal()
rng: jax random key to be used with init function.
        shape (tuple): desired shape (H,P,2) (real/imaginary parts in the last axis)
V: (complex64) the eigenvectors used for initialization
Returns:
        C_tilde of shape (H,P,2): real and imaginary parts stacked in the last axis
"""
C_ = init_fun(rng, shape)
C = C_[..., 0] + 1j * C_[..., 1]
CV = C @ V
CV_real = CV.real
CV_imag = CV.imag
return np.concatenate((CV_real[..., None], CV_imag[..., None]), axis=-1)
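

if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration, not part of the
    # original module); the toy sizes H and P_size are arbitrary.
    H, P_size = 4, 8
    key = random.PRNGKey(0)
    Lambda, P_lr, B, V, B_orig = make_DPLR_HiPPO(P_size)
    Vinv = V.conj().T  # V is unitary, so its inverse is its conjugate transpose
    B_tilde = init_VinvB(lecun_normal(), key, (P_size, H), Vinv)
    C_tilde = init_CV(trunc_standard_normal, key, (H, P_size, 2), V)
    log_steps = init_log_steps(key, (P_size, 0.001, 0.1))
    # Expected: Lambda (8,), B_tilde (8, 4, 2), C_tilde (4, 8, 2), log_steps (8, 1)
    print(Lambda.shape, B_tilde.shape, C_tilde.shape, log_steps.shape)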