import jax
import jax.numpy as np
from flax import linen as nn
from .layers import SequenceLayer


class StackedEncoderModel(nn.Module):
    """ Defines a stack of S5 layers to be used as an encoder.
        Args:
            ssm          (nn.Module): the SSM to be used (e.g. the S5 SSM)
            d_model      (int32):   the feature size of the layer inputs and outputs;
                                    we usually refer to this size as H
            n_layers     (int32):   the number of S5 layers to stack
            activation   (string):  type of activation function to use
            dropout      (float32): dropout rate
            training     (bool):    whether in training mode or not
            prenorm      (bool):    apply prenorm if true or postnorm if false
            batchnorm    (bool):    apply batchnorm if true or layernorm if false
            bn_momentum  (float32): the batchnorm momentum if batchnorm is used
            step_rescale (float32): allows for uniformly changing the timescale parameter,
                                    e.g. after training on a different resolution for
                                    the speech commands benchmark
    """
    ssm: nn.Module
    d_model: int
    n_layers: int
    activation: str = "gelu"
    dropout: float = 0.0
    training: bool = True
    prenorm: bool = False
    batchnorm: bool = False
    bn_momentum: float = 0.9
    step_rescale: float = 1.0

    def setup(self):
        """
        Initializes a linear encoder and the stack of S5 layers.
        """
        self.encoder = nn.Dense(self.d_model)
        self.layers = [
            SequenceLayer(
                ssm=self.ssm,
                dropout=self.dropout,
                d_model=self.d_model,
                activation=self.activation,
                training=self.training,
                prenorm=self.prenorm,
                batchnorm=self.batchnorm,
                bn_momentum=self.bn_momentum,
                step_rescale=self.step_rescale,
            )
            for _ in range(self.n_layers)
        ]

    def __call__(self, x, integration_timesteps):
        """
        Compute the LxH output of the stacked encoder given an Lxd_input
        input sequence.
        Args:
            x (float32): input sequence (L, d_input)
        Returns:
            output sequence (float32): (L, d_model)
        """
        x = self.encoder(x)
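        # Note: integration_timesteps is accepted here for interface consistency with
        # the other models in this file, but it is not passed on to the layers below.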
        for layer in self.layers:
            x = layer(x)
        return x


def masked_meanpool(x, lengths):
    """
    Helper function to perform mean pooling across the sequence length
    when sequences have variable lengths. We only want to pool across
    the original, pre-padding sequence length.
    Args:
        x (float32): input sequence (L, d_model)
        lengths (int32): the original length of the sequence before padding
    Returns:
        mean pooled output sequence (float32): (d_model)
    """
    L = x.shape[0]
    mask = np.arange(L) < lengths
    return np.sum(mask[..., None] * x, axis=0) / lengths
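
# A small worked example of the masking above: with L = 4 and lengths = 2,
#     mask = np.arange(4) < 2  ->  [True, True, False, False]
# so only the first two (non-padded) timesteps contribute to the sum, and dividing
# by lengths (rather than by L) gives the mean over the original sequence only.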


batch_masked_meanpool = jax.vmap(masked_meanpool)
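# batch_masked_meanpool maps masked_meanpool over a leading batch axis: given x of
# shape (bsz, L, d_model) and lengths of shape (bsz,), it returns (bsz, d_model).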


class ClassificationModel(nn.Module):
    """ S5 classification sequence model. This consists of the stacked encoder
        (a linear encoder plus a stack of S5 layers), mean pooling across the
        sequence length, a linear decoder, and a softmax operation.
        Args:
            ssm          (nn.Module): the SSM to be used (e.g. the S5 SSM)
            d_output     (int32):   the output dimension, i.e. the number of classes
            d_model      (int32):   the feature size of the layer inputs and outputs;
                                    we usually refer to this size as H
            n_layers     (int32):   the number of S5 layers to stack
            padded       (bool):    if true, padding was used
            activation   (string):  type of activation function to use
            dropout      (float32): dropout rate
            training     (bool):    whether in training mode or not
            mode         (str):     Options: [pool: use mean pooling, last: just take
                                    the last state]
            prenorm      (bool):    apply prenorm if true or postnorm if false
            batchnorm    (bool):    apply batchnorm if true or layernorm if false
            bn_momentum  (float32): the batchnorm momentum if batchnorm is used
            step_rescale (float32): allows for uniformly changing the timescale parameter,
                                    e.g. after training on a different resolution for
                                    the speech commands benchmark
    """
    ssm: nn.Module
    d_output: int
    d_model: int
    n_layers: int
    padded: bool
    activation: str = "gelu"
    dropout: float = 0.2
    training: bool = True
    mode: str = "pool"
    prenorm: bool = False
    batchnorm: bool = False
    bn_momentum: float = 0.9
    step_rescale: float = 1.0

    def setup(self):
        """
        Initializes the S5 stacked encoder and a linear decoder.
        """
        self.encoder = StackedEncoderModel(
            ssm=self.ssm,
            d_model=self.d_model,
            n_layers=self.n_layers,
            activation=self.activation,
            dropout=self.dropout,
            training=self.training,
            prenorm=self.prenorm,
            batchnorm=self.batchnorm,
            bn_momentum=self.bn_momentum,
            step_rescale=self.step_rescale,
        )
        self.decoder = nn.Dense(self.d_output)

    def __call__(self, x, integration_timesteps):
        """
        Compute the size d_output log softmax output given an
        Lxd_input input sequence.
        Args:
            x (float32): input sequence (L, d_input)
        Returns:
            output (float32): (d_output)
        """
        if self.padded:
            x, length = x

        x = self.encoder(x, integration_timesteps)
        if self.mode in ["pool"]:
            if self.padded:
                x = masked_meanpool(x, length)
            else:
                x = np.mean(x, axis=0)
        elif self.mode in ["last"]:
            if self.padded:
                raise NotImplementedError("Mode must be in ['pool'] for self.padded=True (for now...)")
            else:
                x = x[-1]
        else:
            raise NotImplementedError("Mode must be in ['pool', 'last']")

        x = self.decoder(x)
        return nn.log_softmax(x, axis=-1)


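# Vectorize the classifier over a batch of examples. Parameters (and batch statistics,
# when batchnorm is used) are shared across the batch (mapped axis None), while the
# dropout RNG is split so that each example receives an independent dropout mask. The
# "batch" axis name is exposed so that normalization layers inside the stack can
# compute statistics across the mapped batch dimension.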
BatchClassificationModel = nn.vmap(
    ClassificationModel,
    in_axes=(0, 0),
    out_axes=0,
    variable_axes={"params": None, "dropout": None, "batch_stats": None, "cache": 0, "prime": None},
    split_rngs={"params": False, "dropout": True},
    axis_name="batch",
)


class RetrievalDecoder(nn.Module):
    """
    Defines the decoder to be used for document matching tasks,
    e.g. the AAN task. This is defined as in the S4 paper, where we apply
    an MLP to a set of 4 features. The features are computed as described in
    Tay et al. 2020, https://arxiv.org/pdf/2011.04006.pdf.
    Args:
        d_output (int32): the output dimension, i.e. the number of classes
        d_model  (int32): the feature size of the layer inputs and outputs;
                          we usually refer to this size as H
    """
    d_model: int
    d_output: int

    def setup(self):
        """
        Initializes 2 dense layers to be used for the MLP.
        """
        self.layer1 = nn.Dense(self.d_model)
        self.layer2 = nn.Dense(self.d_output)

    def __call__(self, x):
        """
        Computes the input to the softmax function given a set of
        4 features. Note that this operates directly on the full batch.
        Args:
            x (float32): features (bsz, 4*d_model)
        Returns:
            output (float32): (bsz, d_output)
        """
        x = self.layer1(x)
        x = nn.gelu(x)
        return self.layer2(x)


class RetrievalModel(nn.Module):
    """ S5 retrieval classification model. This consists of the stacked encoder
        (a linear encoder plus a stack of S5 layers), mean pooling across the
        sequence length, construction of 4 features which are fed into an MLP,
        and a softmax operation. Note that unlike the standard classification model
        above, the apply function of this model operates directly on the batch of
        data (instead of calling vmap on this model).
        Args:
            ssm          (nn.Module): the SSM to be used (e.g. the S5 SSM)
            d_output     (int32):   the output dimension, i.e. the number of classes
            d_model      (int32):   the feature size of the layer inputs and outputs;
                                    we usually refer to this size as H
            n_layers     (int32):   the number of S5 layers to stack
            padded       (bool):    if true, padding was used
            activation   (string):  type of activation function to use
            dropout      (float32): dropout rate
            training     (bool):    whether in training mode or not
            prenorm      (bool):    apply prenorm if true or postnorm if false
            batchnorm    (bool):    apply batchnorm if true or layernorm if false
            bn_momentum  (float32): the batchnorm momentum if batchnorm is used
            step_rescale (float32): allows for uniformly changing the timescale parameter,
                                    e.g. after training on a different resolution for
                                    the speech commands benchmark
    """
    ssm: nn.Module
    d_output: int
    d_model: int
    n_layers: int
    padded: bool
    activation: str = "gelu"
    dropout: float = 0.2
    training: bool = True
    prenorm: bool = False
    batchnorm: bool = False
    bn_momentum: float = 0.9
    step_rescale: float = 1.0

    def setup(self):
        """
        Initializes the S5 stacked encoder and the retrieval decoder. Note that here we
        vmap over the stacked encoder model to work well with the retrieval decoder that
        operates directly on the batch.
        """
        BatchEncoderModel = nn.vmap(
            StackedEncoderModel,
            in_axes=(0, 0),
            out_axes=0,
            variable_axes={"params": None, "dropout": None, "batch_stats": None, "cache": 0, "prime": None},
            split_rngs={"params": False, "dropout": True},
            axis_name="batch",
        )

        self.encoder = BatchEncoderModel(
            ssm=self.ssm,
            d_model=self.d_model,
            n_layers=self.n_layers,
            activation=self.activation,
            dropout=self.dropout,
            training=self.training,
            prenorm=self.prenorm,
            batchnorm=self.batchnorm,
            bn_momentum=self.bn_momentum,
            step_rescale=self.step_rescale,
        )
        BatchRetrievalDecoder = nn.vmap(
            RetrievalDecoder,
            in_axes=0,
            out_axes=0,
            variable_axes={"params": None},
            split_rngs={"params": False},
        )

        self.decoder = BatchRetrievalDecoder(
            d_model=self.d_model,
            d_output=self.d_output,
        )

    def __call__(self, input, integration_timesteps):
        """
        Compute the size d_output log softmax outputs for a batch of document pairs.
        The encoded features are constructed as in Tay et al. 2020,
        https://arxiv.org/pdf/2011.04006.pdf.
        Args:
            input (float32, int32): tuple of the input sequences and the prepadded
                sequence lengths. The input sequences are of shape (2*bsz, L, d_input)
                (both documents of each pair are included) and lengths is (2*bsz,).
        Returns:
            output (float32): (bsz, d_output)
        """
        x, lengths = input
        x = self.encoder(x, integration_timesteps)
        outs = batch_masked_meanpool(x, lengths)
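        # The two documents of each pair are stacked along the leading axis, so `outs`
        # has shape (2*bsz, d_model). Splitting it in half recovers the per-document
        # encodings u and v, which are combined into the 4 features
        # [u, v, u - v, u * v] fed to the MLP decoder, following Tay et al. 2020.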
        outs0, outs1 = np.split(outs, 2)
        features = np.concatenate([outs0, outs1, outs0 - outs1, outs0 * outs1], axis=-1)
        out = self.decoder(features)
        return nn.log_softmax(out, axis=-1)