arxiv:2410.01201

Were RNNs All We Needed?

Published on Oct 2

Abstract

The scalability limitations of Transformers regarding sequence length have renewed interest in recurrent sequence models that are parallelizable during training. As a result, many novel recurrent architectures, such as S4, Mamba, and Aaren, have been proposed that achieve comparable performance. In this work, we revisit traditional recurrent neural networks (RNNs) from over a decade ago: LSTMs (1997) and GRUs (2014). While these models were slow due to needing to backpropagate through time (BPTT), we show that by removing their hidden state dependencies from their input, forget, and update gates, LSTMs and GRUs no longer need BPTT and can be efficiently trained in parallel. Building on this, we introduce minimal versions (minLSTMs and minGRUs) that (1) use significantly fewer parameters than their traditional counterparts and (2) are fully parallelizable during training (175x faster for a sequence of length 512). Lastly, we show that these stripped-down versions of decade-old RNNs match the empirical performance of recent sequence models.
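A minimal sketch of the minGRU recurrence in its sequential form, following the equations described above (the gate and the candidate state depend only on the current input $x_t$, not on $h_{t-1}$; minLSTM is analogous with normalized forget/input gates):

```python
import torch
import torch.nn as nn

class MinGRUCell(nn.Module):
    """Sequential form of minGRU as described in the paper:
        z_t       = sigmoid(Linear(x_t))   # gate sees only the input
        h_tilde_t = Linear(x_t)            # candidate state, no tanh
        h_t       = (1 - z_t) * h_{t-1} + z_t * h_tilde_t
    Because neither z_t nor h_tilde_t depends on h_{t-1}, the recurrence is
    linear in h and can be evaluated with a parallel scan during training.
    """
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.to_z = nn.Linear(input_size, hidden_size)
        self.to_h = nn.Linear(input_size, hidden_size)

    def forward(self, x, h_prev):
        # x: (batch, input_size), h_prev: (batch, hidden_size)
        z = torch.sigmoid(self.to_z(x))
        h_tilde = self.to_h(x)
        return (1 - z) * h_prev + z * h_tilde
```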

Community

In case you are interested in this, here is a working implementation: https://github.com/lucidrains/minGRU-pytorch


I also have one here: https://github.com/cheind/mingru, featuring (see the usage sketch after this list):

  • Parallel: Efficient log-space parallel evaluation, plus a sequential implementation for testing. Automatically dispatches to the most efficient implementation.
  • Multilayer: Stack multiple MinGRU layers via the num_layers= argument. When num_layers>1, the output hidden states of layer $i$ are passed as inputs to layer $i+1$.
  • Dropout: Via the dropout= parameter; when > 0, dropout is applied to the inputs of each layer except the last.
  • Bias: Biases in linear layers can be enabled or disabled via the bias= argument.
  • Residuals: Residual connections between outputs of minGRU layers via the residual= argument.
  • Transforms: Custom (shared) transforms between outputs of minGRU layers via the layer_transforms= argument.
  • Compatibility: The interface of mingru is mostly compatible with that of torch.nn.GRU, except that the bidirectional and sequence-first options are not supported.
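A hypothetical usage sketch based on the feature list above. The argument names (num_layers=, dropout=, bias=, residual=) come from that list and from the claimed torch.nn.GRU-style interface; the class name and return signature here are assumptions, so check the repo's README for the exact API:

```python
import torch
import mingru  # https://github.com/cheind/mingru

# Assumed constructor mirroring torch.nn.GRU (batch-first only);
# argument names are taken from the feature list above.
rnn = mingru.MinGRU(
    input_size=64,
    hidden_size=128,
    num_layers=2,    # hidden states of layer i become the inputs of layer i+1
    dropout=0.1,     # applied to the inputs of every layer except the last
    bias=True,
    residual=True,   # residual connections between layer outputs
)

x = torch.randn(8, 512, 64)   # (batch, seq_len, input_size)
out, h = rnn(x)               # dispatches to the efficient parallel evaluation
```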

My summary of this paper:

๐Ÿ“œ ๐Ž๐ฅ๐-๐ฌ๐œ๐ก๐จ๐จ๐ฅ ๐‘๐๐๐ฌ ๐œ๐š๐ง ๐š๐œ๐ญ๐ฎ๐š๐ฅ๐ฅ๐ฒ ๐ซ๐ข๐ฏ๐š๐ฅ ๐Ÿ๐š๐ง๐œ๐ฒ ๐ญ๐ซ๐š๐ง๐ฌ๐Ÿ๐จ๐ซ๐ฆ๐ž๐ซ๐ฌ!

Researchers from Mila and Borealis AI have just shown that simplified versions of good old Recurrent Neural Networks (RNNs) can match the performance of today's transformers.

They took a fresh look at LSTMs (from 1997!) and GRUs (from 2014). They stripped these models down to their bare essentials, creating "minLSTM" and "minGRU". The key changes (compare the equations after this list):
❶ Removed dependencies on previous hidden states in the gates
❷ Dropped the tanh that had been added to restrict the output range in order to avoid vanishing gradients
❸ Ensured outputs are time-independent in scale (not sure I understood that well either, don't worry)

โšก๏ธ As a result, you can use a โ€œparallel scanโ€ algorithm to train these new, minimal RNNs, in parallel, taking 88% more memory but also making them 200x faster than their traditional counterparts for long sequences

🔥 The results are mind-blowing! Performance-wise, they go toe-to-toe with Transformers or Mamba.

And for Language Modeling, they need 2.5x fewer training steps than Transformers to reach the same performance! 🚀

🤔 Why does this matter?

By showing there are simpler models with similar performance to transformers, this challenges the narrative that we need advanced architectures for better performance!

💬 François Chollet wrote in a tweet about this paper:

“The fact that there are many recent architectures coming from different directions that roughly match Transformers is proof that architectures aren't fundamentally important in the curve-fitting paradigm (aka deep learning)”

“Curve-fitting is about embedding a dataset on a curve. The critical factor is the dataset, not the specific hard-coded bells and whistles that constrain the curve's shape.”

It’s the Bitter Lesson by Rich Sutton striking again: you don’t need fancy architectures, just scale up your model and data!

This is a very interesting topic. In academia, transformers seem to be the only choice, but with this architecture (older but more streamlined), we can achieve faster inference.

My summary

[summary screenshots]

