---
license: gpl-3.0
---
# MicroRWKV
This is a custom architecture for the nanoRWKV project, derived from [RWKV-v4neo](https://github.com/BlinkDL/RWKV-LM/tree/main/RWKV-v4neo). It is based on the original nanoRWKV architecture, with some modifications.

![nanoRWKV](assets/nanoRWKV-loss.png)

This is RWKV "x051a", which does not require a custom CUDA kernel to train, so it works on any GPU / CPU.

![nanoGPT](assets/nanoRWKV.png)

> The [nanoGPT](https://github.com/karpathy/nanoGPT)-style implementation of [RWKV Language Model](https://www.rwkv.com) - an RNN with GPT-level LLM performance.

Dataset used - [TinyStories](https://arxiv.org/abs/2305.07759)

![nanoGPT](assets/current_loss.png)

RWKV is essentially an RNN, which gives it an unrivaled advantage at inference time. Here we benchmark the speed and memory usage of RWKV against its Transformer counterpart (the code can be found [here](https://github.com/AnshulRanjan2004/MicroRWKV/blob/main/benchmark_inference_time.py)). We can easily observe (a toy illustration follows the figure below):
- single-token generation latency of RWKV is constant.
- overall latency of RWKV is linear with respect to context length.
- overall memory usage of RWKV is constant.

![benchmark](assets/benchmark.png)
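
The linked script benchmarks the actual models; as a minimal, self-contained illustration of why recurrent inference behaves this way, the toy loop below uses a hypothetical stand-in recurrent step (not MicroRWKV itself) to show per-token latency staying roughly flat while total time grows linearly with context length:
```python
import time
import torch

def rnn_step(state, x, w):
    # Toy recurrent update: the work per token does not depend on how many
    # tokens came before, which is why per-token latency stays constant.
    return torch.tanh(state @ w + x)

d = 768                              # illustrative hidden size
w = torch.randn(d, d) * 0.01
x = torch.randn(1, d)

for ctx_len in (128, 512, 2048):
    state = torch.zeros(1, d)
    t0 = time.perf_counter()
    for _ in range(ctx_len):         # generate ctx_len tokens one at a time
        state = rnn_step(state, x, w)
    dt = time.perf_counter() - t0
    print(f"context {ctx_len:5d}: total {dt:.3f}s, per token {dt / ctx_len * 1e3:.3f} ms")
```
A Transformer, by contrast, attends over the whole prefix at every step, so its per-token cost and memory grow with context length.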

## Prerequisites
Before kicking off this project, make sure you are familiar with the following concepts:
- **RNN**: RNN stands for Recurrent Neural Network. It is a type of artificial neural network designed to work with sequential or time-series data. Check this [tutorial](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) about RNNs.
- **Transformer**: A Transformer is a type of deep learning model introduced in the paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). It is specifically designed for handling sequential data, as in natural language processing tasks, by using a mechanism called self-attention. Check this [post](http://jalammar.github.io/illustrated-transformer/) to learn more about the Transformer.
- **LLM**: LLM, short for Large Language Model, has taken the world by storm. Check this [Awesome-LLM repo](https://github.com/Hannibal046/Awesome-LLM) and [State of GPT](https://build.microsoft.com/en-US/sessions/db3f4859-cd30-4445-a0cd-553c3304f8e2).
- **nanoGPT**: the simplest, fastest repository for training/finetuning medium-sized GPTs, by the great [Andrej Karpathy](https://karpathy.ai). Here you can find the [code](https://github.com/karpathy/nanoGPT) and the [teaching video](https://www.youtube.com/watch?v=kCc8FmEb1nY).
- **RWKV Language Model**: an RNN with GPT-level LLM performance, which can also be trained directly like a GPT transformer (parallelizable). The model is created by independent researcher [Bo Peng](https://twitter.com/BlinkDL_AI). Get more information [here](https://www.rwkv.com).

## Model Structure
RWKV_TimeMix -> RWKV_ChannelMix -> Sliding Window Attention -> GroupedQAttention -> TinyMoE

Here is a brief description of each component:
1. RWKV_TimeMix: This component applies a time-based mixing operation to the input, which helps the model capture temporal dependencies.
2. RWKV_ChannelMix: The channel-based mixing operation is performed in this module, allowing the model to learn better representations across different feature channels.
3. Sliding Window Attention: This attention mechanism operates on a sliding window of the input, enabling the model to efficiently capture local and global dependencies.
4. GroupedQAttention: This attention module applies a grouped approach to the query, key, and value computations, improving the model's ability to capture multi-headed attention.
5. TinyMoE: The Tiny Mixture of Experts (TinyMoE) layer is a lightweight and efficient implementation of a Mixture of Experts (MoE) mechanism, which can help the model learn specialized representations.

## Detailed Explanation
1. RWKV_TimeMix:
This module applies a time-based mixing operation to the input, which helps the model capture temporal dependencies.
It uses several learnable parameters, such as `time_maa_k`, `time_maa_v`, `time_maa_r`, and `time_maa_g`, to control the mixing process.
The module also applies a time-decay mechanism using the `time_decay` parameter, which allows the model to give more importance to recent inputs.
The output of this module is then passed through a series of linear layers, including the receptance, key, value, and gate layers. (A simplified sketch follows the animation below.)

### Time Mixing
![](assets/time_mixing.gif)
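
To make the description above concrete, here is a minimal, single-head sketch of the token-shift/mixing and time-decay ideas. It is an illustrative simplification with assumed shapes, not the repo's actual `RWKV_TimeMix` (the x051a block is multi-headed and trained in a parallel form):
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TimeMixSketch(nn.Module):
    """Simplified, single-head illustration of time mixing (not the repo's module)."""

    def __init__(self, n_embd):
        super().__init__()
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))       # shift the sequence one step back
        # learnable interpolation weights between the current and previous token
        self.time_maa_k = nn.Parameter(torch.zeros(1, 1, n_embd))
        self.time_maa_v = nn.Parameter(torch.zeros(1, 1, n_embd))
        self.time_maa_r = nn.Parameter(torch.zeros(1, 1, n_embd))
        self.time_maa_g = nn.Parameter(torch.zeros(1, 1, n_embd))
        self.time_decay = nn.Parameter(torch.zeros(n_embd))  # per-channel decay rate
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.gate = nn.Linear(n_embd, n_embd, bias=False)
        self.output = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, x):                                    # x: (B, T, C)
        xx = self.time_shift(x) - x                          # difference with the previous token
        r = torch.sigmoid(self.receptance(x + xx * self.time_maa_r))
        k = self.key(x + xx * self.time_maa_k)
        v = self.value(x + xx * self.time_maa_v)
        g = F.silu(self.gate(x + xx * self.time_maa_g))
        decay = torch.exp(-torch.exp(self.time_decay))       # keep the decay in (0, 1)
        state = torch.zeros_like(x[:, 0])                    # running decayed sum of k * v
        outs = []
        for t in range(x.size(1)):                           # recurrent form, one token at a time
            state = decay * state + k[:, t] * v[:, t]
            outs.append(r[:, t] * state)
        return self.output(torch.stack(outs, dim=1) * g)
```
The `time_maa_*` parameters interpolate each projection's input between the current and the previous token, and `time_decay` controls how quickly the running key-value state forgets older tokens.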

2. RWKV_ChannelMix:
This module performs a channel-based mixing operation on the input, allowing the model to learn better representations across different feature channels.
It uses a time-shift operation and learnable parameters, such as `time_maa_k` and `time_maa_r`, to control the mixing process.
The module applies key, value, and receptance linear layers to the mixed input, and the output is then passed through a sigmoid activation function. (A simplified sketch follows the animation below.)

### Channel Mixing
![](assets/channel_mixing.gif)
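
For illustration, here is a minimal sketch of a channel-mix block along these lines. The hidden expansion factor and the squared-ReLU non-linearity are assumptions borrowed from the standard RWKV channel mix, so the repo's `RWKV_ChannelMix` may differ in detail:
```python
import torch
import torch.nn as nn

class ChannelMixSketch(nn.Module):
    """Illustrative channel-mixing block: token shift + gated feed-forward."""

    def __init__(self, n_embd, hidden_mult=4):
        super().__init__()
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))        # look one token back
        self.time_maa_k = nn.Parameter(torch.zeros(1, 1, n_embd))
        self.time_maa_r = nn.Parameter(torch.zeros(1, 1, n_embd))
        self.key = nn.Linear(n_embd, hidden_mult * n_embd, bias=False)
        self.value = nn.Linear(hidden_mult * n_embd, n_embd, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, x):                                     # x: (B, T, C)
        xx = self.time_shift(x) - x                           # difference with the previous token
        xk = x + xx * self.time_maa_k                         # mixed input for the key path
        xr = x + xx * self.time_maa_r                         # mixed input for the receptance path
        k = torch.relu(self.key(xk)) ** 2                     # squared-ReLU feed-forward
        return torch.sigmoid(self.receptance(xr)) * self.value(k)
```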

3. Sliding Window Attention:
This attention mechanism operates on a sliding window of the input, enabling the model to efficiently capture both local and global dependencies.
The module computes the query, key, and value matrices using a linear layer, and then applies a sliding-window attention operation to the input.
The output of the sliding-window attention is then passed through a final linear layer.
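
A minimal, single-head sketch of causal sliding-window attention as described above; the window size and the banded masking scheme are illustrative assumptions rather than the repo's exact implementation:
```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SlidingWindowAttentionSketch(nn.Module):
    """Single-head causal attention restricted to a fixed-size trailing window."""

    def __init__(self, n_embd, window_size=64):
        super().__init__()
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)  # fused q, k, v projection
        self.proj = nn.Linear(n_embd, n_embd, bias=False)     # final output projection
        self.window_size = window_size

    def forward(self, x):                                      # x: (B, T, C)
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(C)         # (B, T, T) attention scores
        idx = torch.arange(T, device=x.device)
        dist = idx.view(-1, 1) - idx.view(1, -1)               # how far back each key is
        # causal band mask: token t may attend to tokens in [t - window_size + 1, t]
        mask = (dist >= 0) & (dist < self.window_size)
        att = att.masked_fill(~mask, float("-inf"))
        att = F.softmax(att, dim=-1)
        return self.proj(att @ v)
```
Stacking several such layers lets information propagate beyond a single window, which is how local attention can still pick up longer-range dependencies.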

4. GroupedQAttention:
This attention module applies a grouped approach to the query, key, and value computations, improving the model's ability to capture multi-headed attention.
The module first computes the query, key, value, and weight matrices using a single linear layer, and then splits these matrices into groups.
The attention computation is then performed on each group, and the results are concatenated and passed through a final linear layer.
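
For illustration, here is a sketch of grouped-query attention in the commonly used sense: several query heads share a smaller set of key/value heads. It uses separate query and key/value projections rather than the single fused layer mentioned above, and the head counts are assumptions:
```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttentionSketch(nn.Module):
    """Causal grouped-query attention: n_head query heads, n_kv_head shared kv heads."""

    def __init__(self, n_embd=768, n_head=8, n_kv_head=2):
        super().__init__()
        assert n_embd % n_head == 0 and n_head % n_kv_head == 0
        self.n_head, self.n_kv_head = n_head, n_kv_head
        self.head_dim = n_embd // n_head
        self.q_proj = nn.Linear(n_embd, n_head * self.head_dim, bias=False)
        self.kv_proj = nn.Linear(n_embd, 2 * n_kv_head * self.head_dim, bias=False)
        self.out_proj = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, x):                                      # x: (B, T, C)
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k, v = self.kv_proj(x).chunk(2, dim=-1)
        k = k.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        # repeat each kv head so every group of query heads has a matching kv head
        rep = self.n_head // self.n_kv_head
        k, v = k.repeat_interleave(rep, dim=1), v.repeat_interleave(rep, dim=1)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        causal = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), 1)
        att = F.softmax(att.masked_fill(causal, float("-inf")), dim=-1)
        y = att @ v                                            # (B, n_head, T, head_dim)
        return self.out_proj(y.transpose(1, 2).reshape(B, T, -1))
```
Sharing key/value heads across groups of query heads shrinks the key/value projections (and the KV cache at inference time) while keeping many query heads.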

5. TinyMoE:
The Tiny Mixture of Experts (TinyMoE) layer is a lightweight and efficient implementation of a Mixture of Experts (MoE) mechanism, which can help the model learn specialized representations.
The module computes routing scores using a linear layer, and then uses these scores to weight a set of expert networks that produce the final output.
The module also includes an auxiliary loss term that encourages the experts to learn diverse representations, improving the overall performance of the model.
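
Below is a minimal sketch of such a layer with soft (weighted) routing over all experts and a simple load-balancing auxiliary penalty. The expert shape and the exact form of the auxiliary loss are illustrative assumptions; the repo's TinyMoE may route (e.g. only to the top NUM_ACTIVE_EXPERTS experts) and regularize differently:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyMoESketch(nn.Module):
    """Illustrative tiny mixture-of-experts layer with a router and an auxiliary loss."""

    def __init__(self, n_embd=768, num_experts=4, expert_dim=512):
        super().__init__()
        self.router = nn.Linear(n_embd, num_experts, bias=False)   # per-token routing scores
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(n_embd, expert_dim), nn.GELU(),
                          nn.Linear(expert_dim, n_embd))
            for _ in range(num_experts)
        ])

    def forward(self, x):                                          # x: (B, T, C)
        gate = F.softmax(self.router(x), dim=-1)                   # (B, T, E) mixing weights
        expert_out = torch.stack([e(x) for e in self.experts], dim=-2)  # (B, T, E, C)
        y = (gate.unsqueeze(-1) * expert_out).sum(dim=-2)          # weighted sum over experts
        # auxiliary penalty: discourage the router from collapsing onto one expert
        usage = gate.mean(dim=(0, 1))                              # average routing weight per expert
        aux_loss = (usage * usage).sum() * gate.size(-1)           # minimized when usage is uniform
        return y, aux_loss
```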

## Usage (Inference)
To use this model for inference, follow these steps:
1. Download the model weights and place them in the `out` directory.
2. Copy the values (`block_size`, `vocab_size`, etc.) from the table below into the `GPTConfig` class in `generate.py` (a sketch of these fields follows the table).
3. Then run the following command:
```bash
python generate.py --prompt="One day" --max_num_tokens=50 --model_name="ckpt-500"
```
Explanation:
This command generates text from the prompt "One day" using the model weights stored in the `out` directory. The `max_num_tokens` parameter sets the maximum number of tokens to generate, and the `model_name` parameter gives the name of the checkpoint file to load, without the extension - for example "ckpt-500", "ckpt-1000", or simply "ckpt".

## Tables
| name_model | BLOCK_SIZE | VOCAB_SIZE | N_LAYER | N_HEAD | N_EMBD | NUM_EXPERTS | NUM_ACTIVE_EXPERTS | EXPERT_DIM | DIM | DROPOUT | BIAS | DATASET |
|--------------|------------|------------|---------|--------|--------|-------------|--------------------|------------|-----|---------|-------|-----------------|
| ckpt-500.pth | 1024 | 50304 | 8 | 8 | 768 | 4 | 4 | 512 | 768 | 0.0 | False | tinystories_15k |
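
As an illustration of step 2 in the Usage section, this is roughly how the `ckpt-500.pth` row could be transcribed into a `GPTConfig`. The lowercase field names here are guesses derived from the column headers, so match them to whatever `generate.py` actually defines:
```python
from dataclasses import dataclass

@dataclass
class GPTConfig:
    # values taken from the ckpt-500.pth row of the table above
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 8
    n_head: int = 8
    n_embd: int = 768
    num_experts: int = 4
    num_active_experts: int = 4
    expert_dim: int = 512
    dim: int = 768
    dropout: float = 0.0
    bias: bool = False
```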

## Results
Prompt: One day

Generated text: One day: Sharing positive bought Isabel a rainbow hug. Her name was an vitamins, so only one favorite thing to cheer she were.

Lily picked up a hay and proudly went to a small portion. She was very happened. When Tommy said it

Generated text length: 227 | Inference time: 3 seconds

We obtained the following results:

![](https://github.com/AnshulRanjan2004/MicroRWKV/assets/91585064/0dfb4bc1-843a-4c99-9d9f-eda5a0d5c110)

![image](https://github.com/AnshulRanjan2004/MicroRWKV/assets/91585064/8b276bc2-93a2-4295-9449-8b7fbc90afcf)

| model | params | train loss | val loss |
| ----- | ------ | ---------- | -------- |
| GPT-2 | 124M | 2.82 | 2.86 |
| RWKV | 130M | 2.85 | 2.88 |

### Baselines

Existing OpenAI GPT-2 checkpoints and RWKV checkpoints allow us to put some baselines in place for OpenWebText. We can get the numbers as follows:
```
python train.py config/eval_rwkv4_{169m|430m|1b5|3b|7b|14b}.py
python train.py config/eval_gpt2{|_medium|_large|_xl}.py
```
and observe the following losses on the val set:
| model | RWKV | | | | | | GPT-2 | | | |
|:----------:|:----:|:----:|:----:|:----:|------|------|:-----:|:----:|:----:|:----:|
| parameters | 169M | 430M | 1.5B | 3B | 7B | 14B | 124M | 350M | 774M | 1.5B |
| val loss | 3.11 | 2.79 | 2.54 | 2.42 | 2.32 | 2.23 | 3.11 | 2.82 | 2.66 | 2.56 |

Note that neither model was trained on OpenWebText (RWKV was trained on The Pile and OpenAI GPT-2 on its private WebText corpus), so both could be improved further given the dataset domain gap.

## Dependencies
- torch
- numpy
- tiktoken

## Conclusion
The MicroRWKV model is a custom neural network architecture that combines several cutting-edge techniques, such as time-based and channel-based mixing, sliding window attention, grouped attention, and a Tiny Mixture of Experts (TinyMoE) layer. These components work together to enhance the model's ability to capture both local and global dependencies, as well as to learn specialized representations. The combination of these techniques results in a powerful and efficient model that can be used for a variety of natural language processing tasks.

## References
Here are some useful references (offering my sincerest gratitude):
- [nanoGPT](https://github.com/karpathy/nanoGPT) - the original nanoGPT implementation by [Andrej Karpathy](https://karpathy.ai)
- [RWKV: Reinventing RNNs for the Transformer Era](https://arxiv.org/abs/2305.13048) - the paper
- [How the RWKV language model works](https://johanwind.github.io/2023/03/23/rwkv_details.html) - a great blog post by [Johan Sokrates Wind](https://www.mn.uio.no/math/english/people/aca/johanswi/index.html)
- [Investigating the RWKV Language Model](https://ben.bolte.cc/rwkv-model) - a great post by [Ben Bolte](https://ben.bolte.cc)
- [An Attention Free Transformer](https://arxiv.org/abs/2105.14103) - the paper from Apple that inspired RWKV
- [RWKV-in-150-lines](https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py)
- [nanoT5](https://github.com/PiotrNawrot/nanoT5) - a follow-up to nanoGPT for the T5 model
- [Now that we have the Transformer, can RNNs be abandoned entirely?](https://www.zhihu.com/question/302392659/answer/2954997969) - a great Zhihu answer by [Songlin Yang](https://sustcsonglin.github.io)
- [The RNN/CNN duality of RWKV](https://zhuanlan.zhihu.com/p/614311961) - a great Zhihu post by [Songlin Yang](https://sustcsonglin.github.io)
- [Google's new work tries to "revive" RNNs: can RNNs shine again?](https://kexue.fm/archives/9554) - a great blog post by [苏剑林](https://kexue.fm/me.html)