S5: Simplified State Space Layers for Sequence Modeling

This repository provides the implementation for the paper: Simplified State Space Layers for Sequence Modeling. The preprint is available here.

Figure 1: S5 uses a single multi-input, multi-output linear state-space model, coupled with non-linearities, to define a non-linear sequence-to-sequence transformation. Parallel scans are used for efficient offline processing.
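The parallel scan mentioned in the Figure 1 caption can be illustrated with a small, dependency-free sketch. This is not the repository's implementation. It assumes a diagonal linear recurrence x_k = a_k * x_{k-1} + b_k with a zero initial state, and the names combine, sequential_recurrence and scan_recurrence are hypothetical. The point is only that the recurrence can be expressed through an associative binary operator, so prefixes can be combined in independent passes; in JAX, jax.lax.associative_scan offers a scan of this kind.

```python
from typing import List, Sequence, Tuple

Elem = Tuple[complex, complex]  # (a_k, b_k) for x_k = a_k * x_{k-1} + b_k


def combine(left: Elem, right: Elem) -> Elem:
    """Associative operator: compose an earlier affine map with a later one."""
    a_l, b_l = left
    a_r, b_r = right
    return (a_r * a_l, a_r * b_l + b_r)


def sequential_recurrence(a: Sequence[complex], b: Sequence[complex]) -> List[complex]:
    """Reference: step through the recurrence one element at a time."""
    x = 0j
    out = []
    for a_k, b_k in zip(a, b):
        x = a_k * x + b_k
        out.append(x)
    return out


def scan_recurrence(a: Sequence[complex], b: Sequence[complex]) -> List[complex]:
    """Inclusive scan in O(log n) passes (Hillis-Steele style).

    Every combine within a pass is independent of the others, which is
    what makes the scan parallelizable on suitable hardware. Here the
    passes are simply executed as list comprehensions.
    """
    elems = list(zip(a, b))
    n = len(elems)
    step = 1
    while step < n:
        elems = [
            combine(elems[k - step], elems[k]) if k >= step else elems[k]
            for k in range(n)
        ]
        step *= 2
    # The second component of each prefix is the state x_k.
    return [b_k for _, b_k in elems]


# Usage: both routes should agree up to floating-point error.
a = [0.9 + 0.1j] * 6  # assumed constant (time-invariant) diagonal entry
b = [1.0, 0.5, -0.25, 2.0, 0.0, 1.5]
seq = sequential_recurrence(a, b)
par = scan_recurrence(a, b)
assert all(abs(s - p) < 1e-12 for s, p in zip(seq, par))
```

The sequential loop needs n dependent steps, whereas the scan needs roughly log2(n) passes whose inner work can run concurrently.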

The S5 layer builds on the prior S4 work (paper). While the code has since departed considerably, this repository originally started from much of the JAX implementation of S4 in the Annotated S4 blog by Rush and Karamcheti (available here).

Requirements & Installation

To run the code on your own machine, run either pip install -r requirements_cpu.txt or pip install -r requirements_gpu.txt. The GPU installation of JAX can be tricky, and so we include requirements that should work for most people, although further instructions are available here.

To install the package, run pip install -e . from within the root directory.

Data Download

Downloading the raw data is done differently for each dataset. The following datasets require no action:

  • Text (IMDb)
  • Image (Cifar black & white)
  • sMNIST
  • psMNIST
  • Cifar (Color)

The remaining datasets need to be manually downloaded. To download everything, run ./bin/download_all.sh. This will download quite a lot of data and will take some time.

Below is a summary of the steps for each dataset:

  • ListOps: run ./bin/download_lra.sh to download the full LRA dataset.
  • Retrieval (AAN): run ./bin/download_aan.sh.
  • Pathfinder: run ./bin/download_lra.sh to download the full LRA dataset.
  • Path-X: run ./bin/download_lra.sh to download the full LRA dataset.
  • Speech commands 35: run ./bin/download_sc35.sh to download the speech commands data.

With the exception of SC35, when a dataset is used for the first time, a cache is created in ./cache_dir. Converting the data (e.g. tokenizing) can be quite slow, so this cache contains the processed dataset. The cache can be moved and its new location specified with the --dir_name argument (the default is --dir_name=./cache_dir) to avoid repeating this preprocessing every time the code is run somewhere new.
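Moving the cache can be sketched as follows. This is a minimal illustration, not part of the repository: relocate_cache is a hypothetical helper, and the destination path is whatever location you choose. It copies the cache directory and returns the --dir_name flag to pass on the next run.

```python
import shutil
from pathlib import Path


def relocate_cache(src: str = "./cache_dir", dst: str = "./new_cache_dir") -> str:
    """Copy the processed-data cache to dst and return the matching flag.

    Hypothetical helper: only the --dir_name argument and the default
    ./cache_dir location come from the README.
    """
    src_path, dst_path = Path(src), Path(dst)
    if not src_path.is_dir():
        raise FileNotFoundError(f"No cache found at {src_path}")
    # dirs_exist_ok requires Python 3.8 or newer.
    shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
    return f"--dir_name={dst_path}"
```

The returned string can be appended to the training command so that the copied cache is used instead of being rebuilt.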

SC35 is slightly different. SC35 doesn't use --dir_name, and instead requires that the following path exists: ./raw_datasets/speech_commands/0.0.2/SpeechCommands (i.e. the directory ./raw_datasets/speech_commands/0.0.2/SpeechCommands/zero must exist). The cache is then stored in ./raw_datasets/speech_commands/0.0.2/SpeechCommands/processed_data. This directory can then be copied (preserving the directory path) to move the preprocessed dataset to a new location.
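The SC35 layout requirement above can be checked before launching a run. The sketch below is illustrative only: sc35_layout_ok and sc35_cache_dir are hypothetical helper names, and the paths are taken directly from the description above.

```python
from pathlib import Path

# Path required by SC35, relative to the repository root (from the README).
SC35_ROOT = Path("raw_datasets") / "speech_commands" / "0.0.2" / "SpeechCommands"


def sc35_layout_ok(repo_root: str = ".") -> bool:
    """Return True if the expected SpeechCommands/zero directory exists."""
    return (Path(repo_root) / SC35_ROOT / "zero").is_dir()


def sc35_cache_dir(repo_root: str = ".") -> Path:
    """Location where the processed SC35 cache is stored."""
    return Path(repo_root) / SC35_ROOT / "processed_data"


if __name__ == "__main__":
    if sc35_layout_ok():
        print(f"SC35 raw data found; cache location: {sc35_cache_dir()}")
    else:
        print("SC35 raw data missing; run ./bin/download_sc35.sh first.")
```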

Repository Structure

Directories and files that ship with the GitHub repo:

s5/                    Source code for models, datasets, etc.
    dataloading.py          Dataloading functions.
    layers.py               Defines the S5 layer which wraps the S5 SSM with nonlinearity, norms, dropout, etc.
    seq_model.py            Defines deep sequence models that consist of stacks of S5 layers.
    ssm.py                  S5 SSM implementation.
    ssm_init.py             Helper functions for initializing the S5 SSM.
    train.py                Training loop code.
    train_helpers.py        Functions for optimization, training and evaluation steps.
    dataloaders/            Code, mainly derived from S4, for processing each dataset.
    utils/                  Range of utility functions.
bin/                    Shell scripts for downloading data and running example experiments.
requirements_cpu.txt    Requirements for running in CPU mode (not advised).
requirements_gpu.txt    Requirements for running in GPU mode (installation can be highly system-dependent).
run_train.py            Training loop entrypoint.

Directories that may be created on-the-fly:

raw_datasets/       Raw data as downloaded.
cache_dir/          Precompiled caches of data.  Can be copied to new locations to avoid preprocessing.
wandb/              Local WandB log files.

Experiments

The configurations to run the LRA and 35-way Speech Commands experiments from the paper are located in bin/run_experiments. For example, to run the LRA text (character-level IMDb) experiment, run ./bin/run_experiments/run_lra_imdb.sh. To log with W&B, adjust the default USE_WANDB, wandb_entity, and wandb_project arguments. Note: the pendulum regression dataloading and experiments will be added soon.

Citation

Please use the following when citing our work:

@misc{smith2022s5,
  doi = {10.48550/ARXIV.2208.04933},
  url = {https://arxiv.org/abs/2208.04933},
  author = {Smith, Jimmy T. H. and Warrington, Andrew and Linderman, Scott W.},
  keywords = {Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences},
  title = {Simplified State Space Layers for Sequence Modeling},
  publisher = {arXiv},
  year = {2022},
  copyright = {Creative Commons Attribution 4.0 International}
}

Please reach out if you have any questions.

-- The S5 authors.
