---
language:
- en
---
# S5: Simplified State Space Layers for Sequence Modeling

This repository provides the implementation for the paper *Simplified State Space Layers for Sequence Modeling*. The preprint is available [here](https://arxiv.org/abs/2208.04933).

![](./docs/figures/pngs/s5-matrix-blocks.png)
<p style="text-align: center;">
Figure 1: S5 uses a single multi-input, multi-output linear state-space model, coupled with non-linearities, to define a non-linear sequence-to-sequence transformation. Parallel scans are used for efficient offline processing.
</p>
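
The parallel scan mentioned in the caption works because a linear recurrence composes associatively. Below is a minimal, self-contained sketch of that idea using `jax.lax.associative_scan`; the shapes and names are illustrative only, not this repository's API.

```python
import jax
import jax.numpy as jnp

# Sketch: compute x_k = a_k * x_{k-1} + b_k for every k in parallel.
# For a diagonalized SSM (as in S5), a_k holds the diagonal of the
# discretized state matrix and b_k = B u_k. Shapes are illustrative.
L, P = 128, 4                           # sequence length, state dimension
a = jnp.full((L, P), 0.9)               # fixed diagonal dynamics per step
b = jax.random.normal(jax.random.PRNGKey(0), (L, P))  # per-step inputs B u_k

def binary_op(e_i, e_j):
    # Compose two recurrence steps: x -> a_j * (a_i * x + b_i) + b_j.
    a_i, b_i = e_i
    a_j, b_j = e_j
    return a_j * a_i, a_j * b_i + b_j

# Inclusive scan over the sequence axis: xs[k] is the state after step
# k + 1 (with x_0 = 0), computed in O(log L) parallel depth.
_, xs = jax.lax.associative_scan(binary_op, (a, b))
assert xs.shape == (L, P)
```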

The S5 layer builds on the prior S4 work ([paper](https://arxiv.org/abs/2111.00396)). Although it has since departed considerably, this repository originally started from much of the JAX implementation of S4 in the Annotated S4 blog by Rush and Karamcheti (available [here](https://github.com/srush/annotated-s4)).

## Requirements & Installation
To run the code on your own machine, run either `pip install -r requirements_cpu.txt` or `pip install -r requirements_gpu.txt`. The GPU installation of JAX can be tricky, so we include requirements that should work for most people; further installation instructions are available [here](https://github.com/google/jax#installation).

From within the root directory, run `pip install -e .` to install the package.
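
After installing, a quick way to confirm that JAX can see your accelerator is the standard device check below (generic JAX, not a script from this repository); a working GPU install should list GPU devices rather than only the CPU.

```python
import jax

print(jax.__version__)   # the installed JAX version
print(jax.devices())     # e.g. [CudaDevice(id=0)] on a working GPU install
```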

## Data Download
The raw data is downloaded differently for each dataset. The following datasets require no action:
- Text (IMDb)
- Image (CIFAR black & white)
- sMNIST
- psMNIST
- CIFAR (color)

The remaining datasets must be downloaded manually. To download _everything_, run `./bin/download_all.sh`. This will download quite a lot of data and will take some time.

Below is a summary of the steps for each dataset:
- ListOps: run `./bin/download_lra.sh` to download the full LRA dataset.
- Retrieval (AAN): run `./bin/download_aan.sh`.
- Pathfinder: run `./bin/download_lra.sh` to download the full LRA dataset.
- Path-X: run `./bin/download_lra.sh` to download the full LRA dataset.
- Speech Commands 35: run `./bin/download_sc35.sh` to download the Speech Commands data.

For every dataset except SC35, a cache is created in `./cache_dir` the first time the dataset is used. Converting the data (e.g. tokenizing) can be quite slow, so this cache contains the processed dataset. The cache location can be set with the `--dir_name` argument (the default is `--dir_name=./cache_dir`), so the cache can be moved to avoid re-applying the preprocessing every time the code is run somewhere new.

SC35 is slightly different: it does not use `--dir_name` and instead requires that the path `./raw_datasets/speech_commands/0.0.2/SpeechCommands` exists (i.e. the directory `./raw_datasets/speech_commands/0.0.2/SpeechCommands/zero` must exist). The cache is then stored in `./raw_datasets/speech_commands/0.0.2/SpeechCommands/processed_data`. This directory can be copied (preserving the directory path) to move the preprocessed dataset to a new location.

## Repository Structure
Directories and files that ship with the GitHub repo:
```
s5/                   Source code for models, datasets, etc.
    dataloading.py        Dataloading functions.
    layers.py             Defines the S5 layer, which wraps the S5 SSM with nonlinearity, norms, dropout, etc.
    seq_model.py          Defines deep sequence models that consist of stacks of S5 layers.
    ssm.py                S5 SSM implementation.
    ssm_init.py           Helper functions for initializing the S5 SSM.
    train.py              Training loop code.
    train_helpers.py      Functions for optimization, training and evaluation steps.
    dataloaders/          Code for processing each dataset, mainly derived from S4.
    utils/                A range of utility functions.
bin/                  Shell scripts for downloading data and running example experiments.
requirements_cpu.txt  Requirements for running in CPU mode (not advised).
requirements_gpu.txt  Requirements for running in GPU mode (installation can be highly system-dependent).
run_train.py          Training loop entrypoint.
```

Directories that may be created on-the-fly:
```
raw_datasets/      Raw data, as downloaded.
cache_dir/         Preprocessed caches of the data. Can be copied to new locations to avoid preprocessing.
wandb/             Local W&B log files.
```
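
To make the roles of `layers.py` and `seq_model.py` concrete, here is a heavily simplified sketch of the structure described above, with hypothetical names and a plain linear map standing in for the S5 SSM; it is not the repository's actual code.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    # Normalize features along the last axis.
    return (x - x.mean(-1, keepdims=True)) / jnp.sqrt(x.var(-1, keepdims=True) + eps)

def s5_layer(w, u):
    # layers.py idea: norm -> sequence map (here a stand-in for the S5 SSM)
    # -> nonlinearity -> residual connection.
    y = layer_norm(u) @ w
    return u + jax.nn.gelu(y)

def seq_model(ws, u):
    # seq_model.py idea: a deep model is a stack of such layers.
    for w in ws:
        u = s5_layer(w, u)
    return u

L, H, depth = 32, 8, 4                                    # toy sizes
keys = jax.random.split(jax.random.PRNGKey(0), depth + 1)
ws = [jax.random.normal(k, (H, H)) / H for k in keys[:-1]]
out = seq_model(ws, jax.random.normal(keys[-1], (L, H)))  # (L, H) -> (L, H)
```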

## Experiments

The configurations to run the LRA and 35-way Speech Commands experiments from the paper are located in `bin/run_experiments`. For example, to run the LRA text (character-level IMDb) experiment, run `./bin/run_experiments/run_lra_imdb.sh`. To log with W&B, adjust the default `USE_WANDB`, `wandb_entity`, and `wandb_project` arguments. Note: the pendulum regression dataloading and experiments will be added soon.
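
For reference, those W&B arguments typically correspond to a setup like the following generic `wandb` sketch (hypothetical entity and project values, not the repository's exact wiring):

```python
import wandb

USE_WANDB = True                                # disable to run without logging
if USE_WANDB:
    run = wandb.init(entity="my-team",          # wandb_entity
                     project="s5-experiments")  # wandb_project
    run.log({"val_accuracy": 0.0})              # metrics logged during training
    run.finish()
```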

## Citation
Please use the following when citing our work:
```
@misc{smith2022s5,
  doi = {10.48550/ARXIV.2208.04933},
  url = {https://arxiv.org/abs/2208.04933},
  author = {Smith, Jimmy T. H. and Warrington, Andrew and Linderman, Scott W.},
  keywords = {Machine Learning (cs.LG), FOS: Computer and information sciences},
  title = {Simplified State Space Layers for Sequence Modeling},
  publisher = {arXiv},
  year = {2022},
  copyright = {Creative Commons Attribution 4.0 International}
}
```

Please reach out if you have any questions.

-- The S5 authors.