chore: adding lolcats configs scrc and src
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- configs/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml +52 -0
- configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml +52 -0
- configs/experiment/eval_alpaca_clean.yaml +56 -0
- configs/experiment/finetune_lora_fqkvo_alpaca_clean.yaml +58 -0
- configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml +56 -0
- configs/experiment/no_distill_alpaca_clean.yaml +29 -0
- configs/model/base_llama3_1_8b.yaml +15 -0
- configs/model/base_llama3_8b.yaml +15 -0
- configs/model/base_mistral_7b.yaml +15 -0
- configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml +40 -0
- configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml +40 -0
- configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wsw64_fd64_w01.yaml +34 -0
- configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01.yaml +34 -0
- configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wsw64_fd64_w01.yaml +36 -0
- configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wtk64_fd64_w01.yaml +35 -0
- configs/model/distill_llama3_1_8b_lk_smd_fd64.yaml +35 -0
- configs/model/distill_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml +39 -0
- configs/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml +39 -0
- configs/model/distill_llama3_1_8b_lk_t2r.yaml +35 -0
- configs/model/distill_llama3_8b_lk_smd_fd64.yaml +29 -0
- configs/model/distill_llama3_8b_lk_smd_wsw64_fd64_w01.yaml +33 -0
- configs/model/distill_llama3_8b_lk_smd_wtk64_fd64_w01.yaml +33 -0
- configs/model/distill_llama3_8b_lk_t2r.yaml +29 -0
- configs/model/distill_mistral_7b_lk_smd_fd64.yaml +29 -0
- configs/model/distill_mistral_7b_lk_smd_wsw64_fd64_w01.yaml +35 -0
- configs/model/distill_mistral_7b_lk_smd_wtk64_fd64_w01.yaml +35 -0
- configs/model/distill_mistral_7b_lk_t2r.yaml +29 -0
- csrc/__init__.py +6 -0
- csrc/causal_attention.cpp +225 -0
- csrc/causal_attention.py +77 -0
- csrc/causal_attention_cuda.cu +1483 -0
- csrc/causal_attention_kv_cuda.cu +1483 -0
- csrc/setup.py +53 -0
- src/__init__.py +0 -0
- src/dataloaders/__init__.py +22 -0
- src/dataloaders/alpaca_clean.py +149 -0
- src/dataloaders/alpaca_clean_instruct.py +148 -0
- src/dataloaders/utils/__init__.py +4 -0
- src/dataloaders/utils/llama3.py +62 -0
- src/dataloaders/utils/packing.py +80 -0
- src/dataloaders/utils/setup.py +123 -0
- src/finetune.py +68 -0
- src/model/__init__.py +0 -0
- src/model/convert_model.py +173 -0
- src/model/feature_map.py +306 -0
- src/model/linear_attention/__init__.py +23 -0
- src/model/linear_attention/linear_attention.py +459 -0
- src/model/linear_attention/linear_window_attention_sw.py +339 -0
- src/model/linear_attention/linear_window_attention_sw_linear.py +522 -0
- src/model/linear_attention/linear_window_attention_sw_long.py +23 -0
configs/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset:
|
2 |
+
name: alpaca_clean
|
3 |
+
dataset_config:
|
4 |
+
name: default
|
5 |
+
path: yahma/alpaca-cleaned
|
6 |
+
chunk_size: 1024 # sequence length for distilling
|
7 |
+
concat_data: true
|
8 |
+
cache_dir: 'data/alpaca' # Change this to where you want to save
|
9 |
+
pretrained_model_config: # will be updated based on model_config
|
10 |
+
pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3-8B'
|
11 |
+
cache_dir: '/scratch/'
|
12 |
+
preprocess_config: null
|
13 |
+
|
14 |
+
dataloader:
|
15 |
+
batch_size: 1
|
16 |
+
num_workers: 2
|
17 |
+
drop_last: false
|
18 |
+
pin_memory: true
|
19 |
+
|
20 |
+
optimizer:
|
21 |
+
optim: adamw_torch_fused
|
22 |
+
lr: 0.01
|
23 |
+
weight_decay: 0.0
|
24 |
+
|
25 |
+
lr_scheduler:
|
26 |
+
lr_scheduler_type: reduce_lr_on_plateau
|
27 |
+
mode: min
|
28 |
+
factor: 0.1
|
29 |
+
patience: 10
|
30 |
+
min_lr: 0.00001
|
31 |
+
|
32 |
+
trainer: # HuggingFace Trainer-like arguments
|
33 |
+
name: distill_attention_xent_mse
|
34 |
+
reverse_kl: false
|
35 |
+
mse_factor: 1000
|
36 |
+
xent_factor: 0
|
37 |
+
|
38 |
+
bf16: true
|
39 |
+
train_split: train
|
40 |
+
val_split: validation
|
41 |
+
num_train_epochs: 2
|
42 |
+
gradient_accumulation_steps: 8
|
43 |
+
seed: 42
|
44 |
+
batch_size: 1
|
45 |
+
load_best_model_at_end: true
|
46 |
+
greater_is_better: false
|
47 |
+
metric_for_best_model: distill/eval/loss
|
48 |
+
logging_steps: 100
|
49 |
+
evaluation_strategy: steps
|
50 |
+
max_steps: -1
|
51 |
+
eval_steps: 100
|
52 |
+
max_eval_batches: null
|
configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset:
|
2 |
+
name: alpaca_clean
|
3 |
+
dataset_config:
|
4 |
+
name: default
|
5 |
+
path: yahma/alpaca-cleaned
|
6 |
+
chunk_size: 1024 # sequence length for distilling
|
7 |
+
concat_data: true
|
8 |
+
cache_dir: 'data/alpaca' # Change this to where you want to save
|
9 |
+
pretrained_model_config: # will be updated based on model_config
|
10 |
+
pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3.1-8B'
|
11 |
+
cache_dir: '/data_persistent2/sim_data/llama-3_1-8b/'
|
12 |
+
preprocess_config: null
|
13 |
+
|
14 |
+
dataloader:
|
15 |
+
batch_size: 1
|
16 |
+
num_workers: 2
|
17 |
+
drop_last: false
|
18 |
+
pin_memory: true
|
19 |
+
|
20 |
+
optimizer:
|
21 |
+
optim: adamw_torch_fused
|
22 |
+
lr: 0.01
|
23 |
+
weight_decay: 0.0
|
24 |
+
|
25 |
+
lr_scheduler:
|
26 |
+
lr_scheduler_type: reduce_lr_on_plateau
|
27 |
+
mode: min
|
28 |
+
factor: 0.1
|
29 |
+
patience: 10
|
30 |
+
min_lr: 0.00001
|
31 |
+
|
32 |
+
trainer: # HuggingFace Trainer-like arguments
|
33 |
+
name: distill_attention_xent_mse
|
34 |
+
reverse_kl: false
|
35 |
+
mse_factor: 1000
|
36 |
+
xent_factor: 1
|
37 |
+
|
38 |
+
bf16: true
|
39 |
+
train_split: train
|
40 |
+
val_split: validation
|
41 |
+
num_train_epochs: 2
|
42 |
+
gradient_accumulation_steps: 8
|
43 |
+
seed: 42
|
44 |
+
batch_size: 1
|
45 |
+
load_best_model_at_end: true
|
46 |
+
greater_is_better: false
|
47 |
+
metric_for_best_model: distill/eval/loss
|
48 |
+
logging_steps: 100
|
49 |
+
evaluation_strategy: steps
|
50 |
+
max_steps: -1
|
51 |
+
eval_steps: 100
|
52 |
+
max_eval_batches: null
|
configs/experiment/eval_alpaca_clean.yaml
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset:
|
2 |
+
name: alpaca_clean
|
3 |
+
dataset_config:
|
4 |
+
name: alpaca
|
5 |
+
path: yahma/alpaca-cleaned
|
6 |
+
chunk_size: 1024 # sequence length for distilling
|
7 |
+
concat_data: true
|
8 |
+
cache_dir: 'data/alpaca' # Change this to where you want to save
|
9 |
+
pretrained_model_config:
|
10 |
+
pretrained_model_name_or_path: 'mistralai/Mistral-7B-v0.1' # will be updated based on model_config
|
11 |
+
cache_dir: '/scratch/'
|
12 |
+
preprocess_config: null
|
13 |
+
|
14 |
+
dataloader:
|
15 |
+
batch_size: 1
|
16 |
+
num_workers: 2
|
17 |
+
drop_last: false
|
18 |
+
pin_memory: true
|
19 |
+
|
20 |
+
optimizer:
|
21 |
+
optim: adamw_torch_fused
|
22 |
+
lr: 1e-4
|
23 |
+
weight_decay: 0.0
|
24 |
+
|
25 |
+
lr_scheduler:
|
26 |
+
lr_scheduler_type: reduce_lr_on_plateau
|
27 |
+
mode: min
|
28 |
+
factor: 0.1
|
29 |
+
patience: 10
|
30 |
+
min_lr: 0.00001
|
31 |
+
|
32 |
+
trainer: # HuggingFace Trainer-like arguments
|
33 |
+
name: finetune_seq2seq
|
34 |
+
bf16: true
|
35 |
+
train_split: train
|
36 |
+
val_split: test
|
37 |
+
num_train_epochs: 2
|
38 |
+
gradient_accumulation_steps: 8
|
39 |
+
seed: 42
|
40 |
+
batch_size: 1
|
41 |
+
load_best_model_at_end: true
|
42 |
+
greater_is_better: true
|
43 |
+
metric_for_best_model: eval/rouge/geometric_mean
|
44 |
+
logging_steps: 100
|
45 |
+
evaluation_strategy: steps
|
46 |
+
max_steps: -1
|
47 |
+
eval_steps: 100
|
48 |
+
max_eval_batches: null
|
49 |
+
|
50 |
+
finetune:
|
51 |
+
method: lora
|
52 |
+
kwargs:
|
53 |
+
r: 8
|
54 |
+
lora_alpha: 16
|
55 |
+
lora_dropout: 0 # 0.05
|
56 |
+
target_modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj']
|
configs/experiment/finetune_lora_fqkvo_alpaca_clean.yaml
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset:
|
2 |
+
name: alpaca_clean
|
3 |
+
dataset_config:
|
4 |
+
name: default
|
5 |
+
path: yahma/alpaca-cleaned
|
6 |
+
chunk_size: 1024
|
7 |
+
concat_data: true
|
8 |
+
cache_dir: "data/alpaca"
|
9 |
+
pretrained_model_config:
|
10 |
+
pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" # will be updated based on model_config
|
11 |
+
cache_dir: "/data_persistent2/sim_data/"
|
12 |
+
preprocess_config: null
|
13 |
+
|
14 |
+
dataloader:
|
15 |
+
batch_size: 1
|
16 |
+
num_workers: 2
|
17 |
+
drop_last: false
|
18 |
+
pin_memory: true
|
19 |
+
|
20 |
+
optimizer:
|
21 |
+
optim: adamw_torch_fused
|
22 |
+
lr: 1e-4
|
23 |
+
weight_decay: 0.0
|
24 |
+
|
25 |
+
lr_scheduler:
|
26 |
+
lr_scheduler_type: reduce_lr_on_plateau
|
27 |
+
mode: min
|
28 |
+
factor: 0.1
|
29 |
+
patience: 10
|
30 |
+
min_lr: 0.00001
|
31 |
+
|
32 |
+
trainer: # HuggingFace Trainer-like arguments
|
33 |
+
name: default_lm
|
34 |
+
bf16: true
|
35 |
+
train_split: train
|
36 |
+
val_split: validation
|
37 |
+
num_train_epochs: 2
|
38 |
+
gradient_accumulation_steps: 8
|
39 |
+
seed: 42
|
40 |
+
batch_size: 1
|
41 |
+
load_best_model_at_end: true
|
42 |
+
greater_is_better: false
|
43 |
+
metric_for_best_model: eval/loss # eval/rouge/geometric_mean
|
44 |
+
logging_steps: 100
|
45 |
+
evaluation_strategy: steps
|
46 |
+
max_steps: -1
|
47 |
+
eval_steps: 100
|
48 |
+
max_eval_batches: null
|
49 |
+
num_save_ckpt_steps: 200
|
50 |
+
|
51 |
+
finetune:
|
52 |
+
method: lora
|
53 |
+
kwargs:
|
54 |
+
r: 8
|
55 |
+
lora_alpha: 16
|
56 |
+
lora_dropout: 0 # 0.05
|
57 |
+
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
|
58 |
+
trainable_weights: ['feature_map_q.mlp.layer', 'feature_map_k.mlp.layer', 'window_factors']
|
configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset:
|
2 |
+
name: alpaca_clean
|
3 |
+
dataset_config:
|
4 |
+
name: default
|
5 |
+
path: yahma/alpaca-cleaned
|
6 |
+
chunk_size: 1024
|
7 |
+
concat_data: true
|
8 |
+
cache_dir: "data/alpaca"
|
9 |
+
pretrained_model_config:
|
10 |
+
pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" # will be updated based on model_config
|
11 |
+
cache_dir: "/scratch/"
|
12 |
+
preprocess_config: null
|
13 |
+
|
14 |
+
dataloader:
|
15 |
+
batch_size: 1
|
16 |
+
num_workers: 2
|
17 |
+
drop_last: false
|
18 |
+
pin_memory: true
|
19 |
+
|
20 |
+
optimizer:
|
21 |
+
optim: adamw_torch_fused
|
22 |
+
lr: 1e-4
|
23 |
+
weight_decay: 0.0
|
24 |
+
|
25 |
+
lr_scheduler:
|
26 |
+
lr_scheduler_type: reduce_lr_on_plateau
|
27 |
+
mode: min
|
28 |
+
factor: 0.1
|
29 |
+
patience: 10
|
30 |
+
min_lr: 0.00001
|
31 |
+
|
32 |
+
trainer: # HuggingFace Trainer-like arguments
|
33 |
+
name: default_lm
|
34 |
+
bf16: true
|
35 |
+
train_split: train
|
36 |
+
val_split: validation
|
37 |
+
num_train_epochs: 2
|
38 |
+
gradient_accumulation_steps: 8
|
39 |
+
seed: 42
|
40 |
+
batch_size: 1
|
41 |
+
load_best_model_at_end: true
|
42 |
+
greater_is_better: false
|
43 |
+
metric_for_best_model: eval/loss # eval/rouge/geometric_mean
|
44 |
+
logging_steps: 100
|
45 |
+
evaluation_strategy: steps
|
46 |
+
max_steps: -1
|
47 |
+
eval_steps: 100
|
48 |
+
max_eval_batches: null
|
49 |
+
|
50 |
+
finetune:
|
51 |
+
method: lora
|
52 |
+
kwargs:
|
53 |
+
r: 8
|
54 |
+
lora_alpha: 16
|
55 |
+
lora_dropout: 0 # 0.05
|
56 |
+
target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
|
configs/experiment/no_distill_alpaca_clean.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset:
|
2 |
+
name: alpaca_clean
|
3 |
+
dataset_config:
|
4 |
+
name: alpaca
|
5 |
+
path: yahma/alpaca-cleaned
|
6 |
+
chunk_size: 1024 # sequence length for distilling
|
7 |
+
concat_data: true
|
8 |
+
cache_dir: 'data/alpaca' # Change this to where you want to save
|
9 |
+
pretrained_model_config:
|
10 |
+
pretrained_model_name_or_path: 'mistralai/Mistral-7B-v0.1' # will be updated based on model_config
|
11 |
+
cache_dir: '/scr-ssd/mzhang/models/mistral-v0.1'
|
12 |
+
preprocess_config: null
|
13 |
+
|
14 |
+
dataloader:
|
15 |
+
batch_size: 1
|
16 |
+
num_workers: 2
|
17 |
+
drop_last: false
|
18 |
+
pin_memory: true
|
19 |
+
|
20 |
+
optimizer:
|
21 |
+
optim: adamw_torch_fused
|
22 |
+
lr: 0.01
|
23 |
+
weight_decay: 0.0
|
24 |
+
|
25 |
+
lr_scheduler:
|
26 |
+
lr_scheduler_type: none
|
27 |
+
|
28 |
+
trainer: # HuggingFace Trainer-like arguments
|
29 |
+
name: null
|
configs/model/base_llama3_1_8b.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llama
|
2 |
+
model:
|
3 |
+
pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3.1-8B'
|
4 |
+
cache_dir: '/scr-ssd/mzhang/models/llama3' # Set this to where you want to save checkpoint weights
|
5 |
+
return_dict: true
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
device_map: auto
|
9 |
+
low_cpu_mem_usage: true
|
10 |
+
torch_dtype: bfloat16
|
11 |
+
attn_implementation: flash_attention_2
|
12 |
+
rope_theta: 500000.0
|
13 |
+
|
14 |
+
attention:
|
15 |
+
attention_type: softmax
|
configs/model/base_llama3_8b.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llama
|
2 |
+
model:
|
3 |
+
pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3-8B'
|
4 |
+
cache_dir: '/scr-ssd/mzhang/models/llama3' # Set this to where you want to save checkpoint weights
|
5 |
+
return_dict: true
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
device_map: auto
|
9 |
+
low_cpu_mem_usage: true
|
10 |
+
torch_dtype: bfloat16
|
11 |
+
attn_implementation: flash_attention_2
|
12 |
+
rope_theta: 500000.0
|
13 |
+
|
14 |
+
attention:
|
15 |
+
attention_type: softmax
|
configs/model/base_mistral_7b.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llama
|
2 |
+
model:
|
3 |
+
pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
|
4 |
+
cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
|
5 |
+
return_dict: true
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
device_map: auto
|
9 |
+
low_cpu_mem_usage: true
|
10 |
+
torch_dtype: bfloat16
|
11 |
+
attn_implementation: flash_attention_2
|
12 |
+
rope_theta: 10000.0
|
13 |
+
|
14 |
+
attention:
|
15 |
+
attention_type: softmax
|
configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Experimental config for chunked linear attention
|
2 |
+
name: llama
|
3 |
+
model:
|
4 |
+
pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
|
5 |
+
cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights
|
6 |
+
return_dict: true
|
7 |
+
load_in_8bit: false
|
8 |
+
load_in_4bit: false
|
9 |
+
device_map: auto
|
10 |
+
low_cpu_mem_usage: true
|
11 |
+
torch_dtype: bfloat16
|
12 |
+
attn_implementation: flash_attention_2
|
13 |
+
rope_theta: 500000.0
|
14 |
+
rope_scaling:
|
15 |
+
factor: 8.0
|
16 |
+
low_freq_factor: 1.0
|
17 |
+
high_freq_factor: 4.0
|
18 |
+
original_max_position_embeddings: 8192
|
19 |
+
rope_type: llama3
|
20 |
+
|
21 |
+
attention:
|
22 |
+
attention_type: lolcats_long_llama_window_sw
|
23 |
+
state_chunk_len: 1024
|
24 |
+
window_size: 64
|
25 |
+
affine_attention_factors: false
|
26 |
+
init_window_factor: -2.1972245773362196
|
27 |
+
feature_map: softmax_dim
|
28 |
+
feature_map_kwargs:
|
29 |
+
eps: 1e-12
|
30 |
+
# mlp: null # to set
|
31 |
+
fullspace: true
|
32 |
+
layer_idx: null # to set
|
33 |
+
learned_kernel: untied_head_einsum
|
34 |
+
learned_kernel_kwargs:
|
35 |
+
feature_dim: 64
|
36 |
+
skip_connection: false
|
37 |
+
bias: false
|
38 |
+
zero_init: false
|
39 |
+
tie_qk_kernels: false
|
40 |
+
train_qk: false
|
configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Experimental config for chunked linear attention
|
2 |
+
name: llama
|
3 |
+
model:
|
4 |
+
pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
|
5 |
+
cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights
|
6 |
+
return_dict: true
|
7 |
+
load_in_8bit: false
|
8 |
+
load_in_4bit: false
|
9 |
+
device_map: auto
|
10 |
+
low_cpu_mem_usage: true
|
11 |
+
torch_dtype: bfloat16
|
12 |
+
attn_implementation: flash_attention_2
|
13 |
+
rope_theta: 500000.0
|
14 |
+
rope_scaling:
|
15 |
+
factor: 8.0
|
16 |
+
low_freq_factor: 1.0
|
17 |
+
high_freq_factor: 4.0
|
18 |
+
original_max_position_embeddings: 8192
|
19 |
+
rope_type: llama3
|
20 |
+
|
21 |
+
attention:
|
22 |
+
attention_type: lolcats_long_llama_window_tk
|
23 |
+
state_chunk_len: 1024
|
24 |
+
window_size: 64
|
25 |
+
affine_attention_factors: false
|
26 |
+
init_window_factor: -2.1972245773362196
|
27 |
+
feature_map: softmax_dim
|
28 |
+
feature_map_kwargs:
|
29 |
+
eps: 1e-12
|
30 |
+
# mlp: null # to set
|
31 |
+
fullspace: true
|
32 |
+
layer_idx: null # to set
|
33 |
+
learned_kernel: untied_head_einsum
|
34 |
+
learned_kernel_kwargs:
|
35 |
+
feature_dim: 64
|
36 |
+
skip_connection: false
|
37 |
+
bias: false
|
38 |
+
zero_init: false
|
39 |
+
tie_qk_kernels: false
|
40 |
+
train_qk: false
|
configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wsw64_fd64_w01.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Experimental config for chunked linear attention
|
2 |
+
name: llama
|
3 |
+
model:
|
4 |
+
pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
|
5 |
+
cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights
|
6 |
+
return_dict: true
|
7 |
+
load_in_8bit: false
|
8 |
+
load_in_4bit: false
|
9 |
+
device_map: auto
|
10 |
+
low_cpu_mem_usage: true
|
11 |
+
torch_dtype: bfloat16
|
12 |
+
attn_implementation: flash_attention_2
|
13 |
+
rope_theta: 500000.0
|
14 |
+
|
15 |
+
attention:
|
16 |
+
attention_type: lolcats_long_llama_window_sw
|
17 |
+
state_chunk_len: 1024
|
18 |
+
window_size: 64
|
19 |
+
affine_attention_factors: false
|
20 |
+
init_window_factor: -2.1972245773362196
|
21 |
+
feature_map: softmax_dim
|
22 |
+
feature_map_kwargs:
|
23 |
+
eps: 1e-12
|
24 |
+
# mlp: null # to set
|
25 |
+
fullspace: true
|
26 |
+
layer_idx: null # to set
|
27 |
+
learned_kernel: untied_head_einsum
|
28 |
+
learned_kernel_kwargs:
|
29 |
+
feature_dim: 64
|
30 |
+
skip_connection: false
|
31 |
+
bias: false
|
32 |
+
zero_init: false
|
33 |
+
tie_qk_kernels: false
|
34 |
+
train_qk: false
|
configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Experimental config for chunked linear attention
|
2 |
+
name: llama
|
3 |
+
model:
|
4 |
+
pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
|
5 |
+
cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights
|
6 |
+
return_dict: true
|
7 |
+
load_in_8bit: false
|
8 |
+
load_in_4bit: false
|
9 |
+
device_map: auto
|
10 |
+
low_cpu_mem_usage: true
|
11 |
+
torch_dtype: bfloat16
|
12 |
+
attn_implementation: flash_attention_2
|
13 |
+
rope_theta: 500000.0
|
14 |
+
|
15 |
+
attention:
|
16 |
+
attention_type: lolcats_long_llama_window_tk
|
17 |
+
state_chunk_len: 1024
|
18 |
+
window_size: 64
|
19 |
+
affine_attention_factors: false
|
20 |
+
init_window_factor: -2.1972245773362196
|
21 |
+
feature_map: softmax_dim
|
22 |
+
feature_map_kwargs:
|
23 |
+
eps: 1e-12
|
24 |
+
# mlp: null # to set
|
25 |
+
fullspace: true
|
26 |
+
layer_idx: null # to set
|
27 |
+
learned_kernel: untied_head_einsum
|
28 |
+
learned_kernel_kwargs:
|
29 |
+
feature_dim: 64
|
30 |
+
skip_connection: false
|
31 |
+
bias: false
|
32 |
+
zero_init: false
|
33 |
+
tie_qk_kernels: false
|
34 |
+
train_qk: false
|
configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wsw64_fd64_w01.yaml
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Experimental config for chunked linear attention
|
2 |
+
name: llama
|
3 |
+
model:
|
4 |
+
pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
|
5 |
+
cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
|
6 |
+
return_dict: true
|
7 |
+
load_in_8bit: false
|
8 |
+
load_in_4bit: false
|
9 |
+
device_map: auto
|
10 |
+
low_cpu_mem_usage: true
|
11 |
+
torch_dtype: bfloat16
|
12 |
+
attn_implementation: flash_attention_2 # eager # so we can load attention weights
|
13 |
+
rope_theta: 10000.0
|
14 |
+
|
15 |
+
attention:
|
16 |
+
attention_type: lolcats_long_llama_window_sw
|
17 |
+
state_chunk_len: 512 # 1024
|
18 |
+
window_size: 64
|
19 |
+
affine_attention_factors: false
|
20 |
+
init_window_factor: -2.1972245773362196
|
21 |
+
train_window_factor: true
|
22 |
+
train_attention_weights: false
|
23 |
+
feature_map: softmax_dim
|
24 |
+
feature_map_kwargs:
|
25 |
+
eps: 1e-12
|
26 |
+
# mlp: null # to set
|
27 |
+
fullspace: true
|
28 |
+
layer_idx: null # to set
|
29 |
+
learned_kernel: untied_head_einsum
|
30 |
+
learned_kernel_kwargs:
|
31 |
+
feature_dim: 64
|
32 |
+
skip_connection: false
|
33 |
+
bias: false
|
34 |
+
zero_init: false
|
35 |
+
tie_qk_kernels: false
|
36 |
+
train_qk: false
|
configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wtk64_fd64_w01.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llama
|
2 |
+
model:
|
3 |
+
pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
|
4 |
+
cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
|
5 |
+
return_dict: true
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
device_map: auto
|
9 |
+
low_cpu_mem_usage: true
|
10 |
+
torch_dtype: bfloat16
|
11 |
+
attn_implementation: flash_attention_2 # eager # so we can load attention weights
|
12 |
+
rope_theta: 10000.0
|
13 |
+
|
14 |
+
attention:
|
15 |
+
attention_type: lolcats_long_llama_window_tk
|
16 |
+
state_chunk_len: 512 # 1024
|
17 |
+
window_size: 64
|
18 |
+
affine_attention_factors: false
|
19 |
+
init_window_factor: -2.1972245773362196
|
20 |
+
train_window_factor: true
|
21 |
+
train_attention_weights: false
|
22 |
+
feature_map: softmax_dim
|
23 |
+
feature_map_kwargs:
|
24 |
+
eps: 1e-12
|
25 |
+
# mlp: null # to set
|
26 |
+
fullspace: true
|
27 |
+
layer_idx: null # to set
|
28 |
+
learned_kernel: untied_head_einsum
|
29 |
+
learned_kernel_kwargs:
|
30 |
+
feature_dim: 64
|
31 |
+
skip_connection: false
|
32 |
+
bias: false
|
33 |
+
zero_init: false
|
34 |
+
tie_qk_kernels: false
|
35 |
+
train_qk: false
|
configs/model/distill_llama3_1_8b_lk_smd_fd64.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llama
|
2 |
+
model:
|
3 |
+
pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
|
4 |
+
cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights
|
5 |
+
return_dict: true
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
device_map: auto
|
9 |
+
low_cpu_mem_usage: true
|
10 |
+
torch_dtype: bfloat16
|
11 |
+
attn_implementation: eager
|
12 |
+
rope_theta: 500000.0
|
13 |
+
rope_scaling:
|
14 |
+
factor: 8.0
|
15 |
+
low_freq_factor: 1.0
|
16 |
+
high_freq_factor: 4.0
|
17 |
+
original_max_position_embeddings: 8192
|
18 |
+
rope_type: llama3
|
19 |
+
|
20 |
+
attention:
|
21 |
+
attention_type: lolcats_llama
|
22 |
+
feature_map: softmax_dim
|
23 |
+
feature_map_kwargs:
|
24 |
+
eps: 1e-12
|
25 |
+
# mlp: null # to set
|
26 |
+
fullspace: true
|
27 |
+
layer_idx: null # to set
|
28 |
+
learned_kernel: untied_head_einsum
|
29 |
+
learned_kernel_kwargs:
|
30 |
+
feature_dim: 64
|
31 |
+
skip_connection: false
|
32 |
+
bias: false
|
33 |
+
zero_init: false
|
34 |
+
tie_qk_kernels: false
|
35 |
+
train_qk: false
|
configs/model/distill_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llama
|
2 |
+
model:
|
3 |
+
pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
|
4 |
+
cache_dir: "/scratch/" # Set this to where you want to save checkpoint weights
|
5 |
+
return_dict: true
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
device_map: auto
|
9 |
+
low_cpu_mem_usage: true
|
10 |
+
torch_dtype: bfloat16
|
11 |
+
attn_implementation: eager
|
12 |
+
rope_theta: 500000.0
|
13 |
+
rope_scaling:
|
14 |
+
factor: 8.0
|
15 |
+
low_freq_factor: 1.0
|
16 |
+
high_freq_factor: 4.0
|
17 |
+
original_max_position_embeddings: 8192
|
18 |
+
rope_type: llama3
|
19 |
+
|
20 |
+
attention:
|
21 |
+
attention_type: lolcats_llama_window_sw
|
22 |
+
state_chunk_len: 1024
|
23 |
+
window_size: 64
|
24 |
+
affine_attention_factors: false
|
25 |
+
init_window_factor: -2.1972245773362196
|
26 |
+
feature_map: softmax_dim
|
27 |
+
feature_map_kwargs:
|
28 |
+
eps: 1e-12
|
29 |
+
# mlp: null # to set
|
30 |
+
fullspace: true
|
31 |
+
layer_idx: null # to set
|
32 |
+
learned_kernel: untied_head_einsum
|
33 |
+
learned_kernel_kwargs:
|
34 |
+
feature_dim: 64
|
35 |
+
skip_connection: false
|
36 |
+
bias: false
|
37 |
+
zero_init: false
|
38 |
+
tie_qk_kernels: false
|
39 |
+
train_qk: false
|
configs/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llama
|
2 |
+
model:
|
3 |
+
pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
|
4 |
+
cache_dir: "/scratch/" # Set this to where you want to save checkpoint weights
|
5 |
+
return_dict: true
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
device_map: auto
|
9 |
+
low_cpu_mem_usage: true
|
10 |
+
torch_dtype: bfloat16
|
11 |
+
attn_implementation: eager
|
12 |
+
rope_theta: 500000.0
|
13 |
+
rope_scaling:
|
14 |
+
factor: 8.0
|
15 |
+
low_freq_factor: 1.0
|
16 |
+
high_freq_factor: 4.0
|
17 |
+
original_max_position_embeddings: 8192
|
18 |
+
rope_type: llama3
|
19 |
+
|
20 |
+
attention:
|
21 |
+
attention_type: lolcats_llama_window_tk
|
22 |
+
state_chunk_len: 1024
|
23 |
+
window_size: 64
|
24 |
+
affine_attention_factors: false
|
25 |
+
init_window_factor: -2.1972245773362196
|
26 |
+
feature_map: softmax_dim
|
27 |
+
feature_map_kwargs:
|
28 |
+
eps: 1e-12
|
29 |
+
# mlp: null # to set
|
30 |
+
fullspace: true
|
31 |
+
layer_idx: null # to set
|
32 |
+
learned_kernel: untied_head_einsum
|
33 |
+
learned_kernel_kwargs:
|
34 |
+
feature_dim: 64
|
35 |
+
skip_connection: false
|
36 |
+
bias: false
|
37 |
+
zero_init: false
|
38 |
+
tie_qk_kernels: false
|
39 |
+
train_qk: false
|
configs/model/distill_llama3_1_8b_lk_t2r.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llama
|
2 |
+
model:
|
3 |
+
pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
|
4 |
+
cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights
|
5 |
+
return_dict: true
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
device_map: auto
|
9 |
+
low_cpu_mem_usage: true
|
10 |
+
torch_dtype: bfloat16
|
11 |
+
attn_implementation: eager
|
12 |
+
rope_theta: 500000.0
|
13 |
+
rope_scaling:
|
14 |
+
factor: 8.0
|
15 |
+
low_freq_factor: 1.0
|
16 |
+
high_freq_factor: 4.0
|
17 |
+
original_max_position_embeddings: 8192
|
18 |
+
rope_type: llama3
|
19 |
+
|
20 |
+
attention:
|
21 |
+
attention_type: lolcats_llama
|
22 |
+
feature_map: relu
|
23 |
+
feature_map_kwargs:
|
24 |
+
eps: 1e-12
|
25 |
+
# mlp: null # to set
|
26 |
+
fullspace: true
|
27 |
+
layer_idx: null # to set
|
28 |
+
learned_kernel: untied_head_einsum
|
29 |
+
learned_kernel_kwargs:
|
30 |
+
feature_dim: 128
|
31 |
+
skip_connection: false
|
32 |
+
bias: true
|
33 |
+
zero_init: false
|
34 |
+
tie_qk_kernels: false
|
35 |
+
train_qk: false
|
configs/model/distill_llama3_8b_lk_smd_fd64.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llama
|
2 |
+
model:
|
3 |
+
pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
|
4 |
+
cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights
|
5 |
+
return_dict: true
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
device_map: auto
|
9 |
+
low_cpu_mem_usage: true
|
10 |
+
torch_dtype: bfloat16
|
11 |
+
attn_implementation: flash_attention_2
|
12 |
+
rope_theta: 500000.0
|
13 |
+
|
14 |
+
attention:
|
15 |
+
attention_type: lolcats_llama
|
16 |
+
feature_map: softmax_dim
|
17 |
+
feature_map_kwargs:
|
18 |
+
eps: 1e-12
|
19 |
+
# mlp: null # to set
|
20 |
+
fullspace: true
|
21 |
+
layer_idx: null # to set
|
22 |
+
learned_kernel: untied_head_einsum
|
23 |
+
learned_kernel_kwargs:
|
24 |
+
feature_dim: 64
|
25 |
+
skip_connection: false
|
26 |
+
bias: false
|
27 |
+
zero_init: false
|
28 |
+
tie_qk_kernels: false
|
29 |
+
train_qk: false
|
configs/model/distill_llama3_8b_lk_smd_wsw64_fd64_w01.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llama
|
2 |
+
model:
|
3 |
+
pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
|
4 |
+
cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights
|
5 |
+
return_dict: true
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
device_map: auto
|
9 |
+
low_cpu_mem_usage: true
|
10 |
+
torch_dtype: bfloat16
|
11 |
+
attn_implementation: flash_attention_2
|
12 |
+
rope_theta: 500000.0
|
13 |
+
|
14 |
+
attention:
|
15 |
+
attention_type: lolcats_llama_window_sw
|
16 |
+
state_chunk_len: 1024
|
17 |
+
window_size: 64
|
18 |
+
affine_attention_factors: false
|
19 |
+
init_window_factor: -2.1972245773362196
|
20 |
+
feature_map: softmax_dim
|
21 |
+
feature_map_kwargs:
|
22 |
+
eps: 1e-12
|
23 |
+
# mlp: null # to set
|
24 |
+
fullspace: true
|
25 |
+
layer_idx: null # to set
|
26 |
+
learned_kernel: untied_head_einsum
|
27 |
+
learned_kernel_kwargs:
|
28 |
+
feature_dim: 64
|
29 |
+
skip_connection: false
|
30 |
+
bias: false
|
31 |
+
zero_init: false
|
32 |
+
tie_qk_kernels: false
|
33 |
+
train_qk: false
|
configs/model/distill_llama3_8b_lk_smd_wtk64_fd64_w01.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llama
|
2 |
+
model:
|
3 |
+
pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
|
4 |
+
cache_dir: '/scr-ssd/mzhang/models/llama3' # Set this to where you want to save checkpoint weights
|
5 |
+
return_dict: true
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
device_map: auto
|
9 |
+
low_cpu_mem_usage: true
|
10 |
+
torch_dtype: bfloat16
|
11 |
+
attn_implementation: flash_attention_2
|
12 |
+
rope_theta: 500000.0
|
13 |
+
|
14 |
+
attention:
|
15 |
+
attention_type: lolcats_llama_window_tk
|
16 |
+
state_chunk_len: 1024
|
17 |
+
window_size: 64
|
18 |
+
affine_attention_factors: false
|
19 |
+
init_window_factor: -2.1972245773362196
|
20 |
+
feature_map: softmax_dim
|
21 |
+
feature_map_kwargs:
|
22 |
+
eps: 1e-12
|
23 |
+
# mlp: null # to set
|
24 |
+
fullspace: true
|
25 |
+
layer_idx: null # to set
|
26 |
+
learned_kernel: untied_head_einsum
|
27 |
+
learned_kernel_kwargs:
|
28 |
+
feature_dim: 64
|
29 |
+
skip_connection: false
|
30 |
+
bias: false
|
31 |
+
zero_init: false
|
32 |
+
tie_qk_kernels: false
|
33 |
+
train_qk: false
|
configs/model/distill_llama3_8b_lk_t2r.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llama
|
2 |
+
model:
|
3 |
+
pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B"
|
4 |
+
cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights
|
5 |
+
return_dict: true
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
device_map: auto
|
9 |
+
low_cpu_mem_usage: true
|
10 |
+
torch_dtype: bfloat16
|
11 |
+
attn_implementation: flash_attention_2
|
12 |
+
rope_theta: 500000.0
|
13 |
+
|
14 |
+
attention:
|
15 |
+
attention_type: lolcats_llama
|
16 |
+
feature_map: relu
|
17 |
+
feature_map_kwargs:
|
18 |
+
eps: 1e-12
|
19 |
+
# mlp: null # to set
|
20 |
+
fullspace: true
|
21 |
+
layer_idx: null # to set
|
22 |
+
learned_kernel: untied_head_einsum
|
23 |
+
learned_kernel_kwargs:
|
24 |
+
feature_dim: 128
|
25 |
+
skip_connection: false
|
26 |
+
bias: true
|
27 |
+
zero_init: false
|
28 |
+
tie_qk_kernels: false
|
29 |
+
train_qk: false
|
configs/model/distill_mistral_7b_lk_smd_fd64.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llama
|
2 |
+
model:
|
3 |
+
pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
|
4 |
+
cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
|
5 |
+
return_dict: true
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
device_map: auto
|
9 |
+
low_cpu_mem_usage: true
|
10 |
+
torch_dtype: bfloat16
|
11 |
+
attn_implementation: flash_attention_2 # eager # so we can load attention weights
|
12 |
+
rope_theta: 10000.0
|
13 |
+
|
14 |
+
attention:
|
15 |
+
attention_type: lolcats_llama
|
16 |
+
feature_map: softmax_dim
|
17 |
+
feature_map_kwargs:
|
18 |
+
eps: 1e-12
|
19 |
+
# mlp: null # to set
|
20 |
+
fullspace: true
|
21 |
+
layer_idx: null # to set
|
22 |
+
learned_kernel: untied_head_einsum
|
23 |
+
learned_kernel_kwargs:
|
24 |
+
feature_dim: 64
|
25 |
+
skip_connection: false
|
26 |
+
bias: false
|
27 |
+
zero_init: false
|
28 |
+
tie_qk_kernels: false
|
29 |
+
train_qk: false
|
configs/model/distill_mistral_7b_lk_smd_wsw64_fd64_w01.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llama
|
2 |
+
model:
|
3 |
+
pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
|
4 |
+
cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
|
5 |
+
return_dict: true
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
device_map: auto
|
9 |
+
low_cpu_mem_usage: true
|
10 |
+
torch_dtype: bfloat16
|
11 |
+
attn_implementation: flash_attention_2 # eager # so we can load attention weights
|
12 |
+
rope_theta: 10000.0
|
13 |
+
|
14 |
+
attention:
|
15 |
+
attention_type: lolcats_llama_window_sw
|
16 |
+
state_chunk_len: 512 # 1024
|
17 |
+
window_size: 64
|
18 |
+
affine_attention_factors: false
|
19 |
+
init_window_factor: -2.1972245773362196
|
20 |
+
train_window_factor: true
|
21 |
+
train_attention_weights: false
|
22 |
+
feature_map: softmax_dim
|
23 |
+
feature_map_kwargs:
|
24 |
+
eps: 1e-12
|
25 |
+
# mlp: null # to set
|
26 |
+
fullspace: true
|
27 |
+
layer_idx: null # to set
|
28 |
+
learned_kernel: untied_head_einsum
|
29 |
+
learned_kernel_kwargs:
|
30 |
+
feature_dim: 64
|
31 |
+
skip_connection: false
|
32 |
+
bias: false
|
33 |
+
zero_init: false
|
34 |
+
tie_qk_kernels: false
|
35 |
+
train_qk: false
|
configs/model/distill_mistral_7b_lk_smd_wtk64_fd64_w01.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llama
|
2 |
+
model:
|
3 |
+
pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
|
4 |
+
cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
|
5 |
+
return_dict: true
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
device_map: auto
|
9 |
+
low_cpu_mem_usage: true
|
10 |
+
torch_dtype: bfloat16
|
11 |
+
attn_implementation: flash_attention_2 # eager # so we can load attention weights
|
12 |
+
rope_theta: 10000.0
|
13 |
+
|
14 |
+
attention:
|
15 |
+
attention_type: lolcats_llama_window_tk
|
16 |
+
state_chunk_len: 512 # 1024
|
17 |
+
window_size: 64
|
18 |
+
affine_attention_factors: false
|
19 |
+
init_window_factor: -2.1972245773362196
|
20 |
+
train_window_factor: true
|
21 |
+
train_attention_weights: false
|
22 |
+
feature_map: softmax_dim
|
23 |
+
feature_map_kwargs:
|
24 |
+
eps: 1e-12
|
25 |
+
# mlp: null # to set
|
26 |
+
fullspace: true
|
27 |
+
layer_idx: null # to set
|
28 |
+
learned_kernel: untied_head_einsum
|
29 |
+
learned_kernel_kwargs:
|
30 |
+
feature_dim: 64
|
31 |
+
skip_connection: false
|
32 |
+
bias: false
|
33 |
+
zero_init: false
|
34 |
+
tie_qk_kernels: false
|
35 |
+
train_qk: false
|
configs/model/distill_mistral_7b_lk_t2r.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: llama
|
2 |
+
model:
|
3 |
+
pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1"
|
4 |
+
cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights
|
5 |
+
return_dict: true
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
device_map: auto
|
9 |
+
low_cpu_mem_usage: true
|
10 |
+
torch_dtype: bfloat16
|
11 |
+
attn_implementation: flash_attention_2 # eager # so we can load attention weights
|
12 |
+
rope_theta: 10000.0
|
13 |
+
|
14 |
+
attention:
|
15 |
+
attention_type: lolcats_llama
|
16 |
+
feature_map: relu
|
17 |
+
feature_map_kwargs:
|
18 |
+
eps: 1e-12
|
19 |
+
# mlp: null # to set
|
20 |
+
fullspace: true
|
21 |
+
layer_idx: null # to set
|
22 |
+
learned_kernel: untied_head_einsum
|
23 |
+
learned_kernel_kwargs:
|
24 |
+
feature_dim: 128
|
25 |
+
skip_connection: false
|
26 |
+
bias: true
|
27 |
+
zero_init: false
|
28 |
+
tie_qk_kernels: false
|
29 |
+
train_qk: false
|
csrc/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
3 |
+
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
4 |
+
# Apoorv Vyas <avyas@idiap.ch>
|
5 |
+
#
|
6 |
+
from .causal_attention import causal_dot_product
|
csrc/causal_attention.cpp
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
//
|
2 |
+
// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
3 |
+
// Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
4 |
+
// Apoorv Vyas <avyas@idiap.ch>
|
5 |
+
//
|
6 |
+
|
7 |
+
#include <torch/extension.h>
|
8 |
+
|
9 |
+
|
10 |
+
/**
|
11 |
+
* Compute a*b^T and save it into out.
|
12 |
+
*
|
13 |
+
* a \in R^A
|
14 |
+
* b \in R^B
|
15 |
+
*/
|
16 |
+
inline void vvt_dot(float *a, float *b, float *out, int A, int B) {
|
17 |
+
for (int i=0; i<A; i++) {
|
18 |
+
float * bi = b;
|
19 |
+
for (int j=0; j<B; j++) {
|
20 |
+
*out += (*a) * (*bi);
|
21 |
+
out++;
|
22 |
+
bi++;
|
23 |
+
}
|
24 |
+
a++;
|
25 |
+
}
|
26 |
+
}
|
27 |
+
|
28 |
+
|
29 |
+
/**
|
30 |
+
* Implement a vector matrix product v*m and save it into out.
|
31 |
+
*
|
32 |
+
* v \in R^A
|
33 |
+
* m \in R^{AxB}
|
34 |
+
*/
|
35 |
+
inline void vm_dot(float *v, float *m, float *out, int A, int B) {
|
36 |
+
// TODO: Consider removing the zeroing part and assuming out already
|
37 |
+
// contains 0s
|
38 |
+
for (int i=0; i<B; i++) {
|
39 |
+
out[i] = 0;
|
40 |
+
}
|
41 |
+
|
42 |
+
for (int i=0; i<A; i++) {
|
43 |
+
float *oi = out;
|
44 |
+
for (int j=0; j<B; j++) {
|
45 |
+
*oi += (*v) * (*m);
|
46 |
+
oi++;
|
47 |
+
m++;
|
48 |
+
}
|
49 |
+
v++;
|
50 |
+
}
|
51 |
+
}
|
52 |
+
|
53 |
+
|
54 |
+
/**
|
55 |
+
* Implement a vector transposed-matrix product and save it into out.
|
56 |
+
*
|
57 |
+
* v \in R^B
|
58 |
+
* m \in R^{AxB}
|
59 |
+
*/
|
60 |
+
inline void vmt_dot(float *v, float *m, float *out, int A, int B) {
|
61 |
+
for (int i=0; i<A; i++) {
|
62 |
+
float *vi = v;
|
63 |
+
float s = 0;
|
64 |
+
for (int j=0; j<B; j++) {
|
65 |
+
s += (*vi) * (*m);
|
66 |
+
vi++;
|
67 |
+
m++;
|
68 |
+
}
|
69 |
+
// TODO: Should we be aggregating? See the comment on vm_dot.
|
70 |
+
*out = s;
|
71 |
+
out++;
|
72 |
+
}
|
73 |
+
}
|
74 |
+
|
75 |
+
|
76 |
+
/**
|
77 |
+
* Compute the causally masked dot products of queries, keys and values.
|
78 |
+
*
|
79 |
+
* Basically compute V_j' = (Q_{0:j} * K_{0:j}^T) * V_{0:j} for all j. The
|
80 |
+
* computation is done efficiently by changing the order of the dot products.
|
81 |
+
*/
|
82 |
+
void causal_dot_product(
|
83 |
+
const torch::Tensor queries,
|
84 |
+
const torch::Tensor keys,
|
85 |
+
const torch::Tensor values,
|
86 |
+
torch::Tensor product
|
87 |
+
) {
|
88 |
+
// Extract some shapes
|
89 |
+
int N = queries.size(0);
|
90 |
+
int H = queries.size(1);
|
91 |
+
int L = queries.size(2);
|
92 |
+
int E = queries.size(3);
|
93 |
+
int M = values.size(3);
|
94 |
+
|
95 |
+
// Create accessors for all the arguments
|
96 |
+
auto qa = queries.accessor<float, 4>();
|
97 |
+
auto ka = keys.accessor<float, 4>();
|
98 |
+
auto va = values.accessor<float, 4>();
|
99 |
+
auto pa = product.accessor<float, 4>();
|
100 |
+
|
101 |
+
#pragma omp parallel for collapse(2)
|
102 |
+
for (int n=0; n<N; n++) {
|
103 |
+
for (int h=0; h<H; h++) {
|
104 |
+
auto kv = torch::zeros({E, M}, queries.options());
|
105 |
+
float *kvp = kv.data_ptr<float>();
|
106 |
+
for (int l=0; l<L; l++) {
|
107 |
+
vvt_dot(
|
108 |
+
&ka[n][h][l][0],
|
109 |
+
&va[n][h][l][0],
|
110 |
+
kvp,
|
111 |
+
E,
|
112 |
+
M
|
113 |
+
);
|
114 |
+
vm_dot(
|
115 |
+
&qa[n][h][l][0],
|
116 |
+
kvp,
|
117 |
+
&pa[n][h][l][0],
|
118 |
+
E,
|
119 |
+
M
|
120 |
+
);
|
121 |
+
}
|
122 |
+
}
|
123 |
+
}
|
124 |
+
}
|
125 |
+
|
126 |
+
|
127 |
+
/**
|
128 |
+
* Compute the gradients of queries, keys and values given the gradient of the
|
129 |
+
* causal_dot_product output.
|
130 |
+
*
|
131 |
+
* Make sure that everything is computed in O(N D^2) complexity.
|
132 |
+
*/
|
133 |
+
void causal_dot_backward(
|
134 |
+
const torch::Tensor queries,
|
135 |
+
const torch::Tensor keys,
|
136 |
+
const torch::Tensor values,
|
137 |
+
const torch::Tensor grad_out,
|
138 |
+
torch::Tensor grad_queries,
|
139 |
+
torch::Tensor grad_keys,
|
140 |
+
torch::Tensor grad_values
|
141 |
+
) {
|
142 |
+
// Extract some shapes
|
143 |
+
int N = queries.size(0);
|
144 |
+
int H = queries.size(1);
|
145 |
+
int L = queries.size(2);
|
146 |
+
int E = queries.size(3);
|
147 |
+
int M = values.size(3);
|
148 |
+
|
149 |
+
// Create accessors for all the arguments
|
150 |
+
auto qa = queries.accessor<float, 4>();
|
151 |
+
auto ka = keys.accessor<float, 4>();
|
152 |
+
auto va = values.accessor<float, 4>();
|
153 |
+
auto ga = grad_out.accessor<float, 4>();
|
154 |
+
auto gqa = grad_queries.accessor<float, 4>();
|
155 |
+
auto gka = grad_keys.accessor<float, 4>();
|
156 |
+
auto gva = grad_values.accessor<float, 4>();
|
157 |
+
|
158 |
+
#pragma omp parallel for collapse(2)
|
159 |
+
for (int n=0; n<N; n++) {
|
160 |
+
for (int h=0; h<H; h++) {
|
161 |
+
auto kv = torch::zeros({E, M}, queries.options());
|
162 |
+
float *kvp = kv.data_ptr<float>();
|
163 |
+
|
164 |
+
// Compute the gradient wrt the queries
|
165 |
+
for (int l=0; l<L; l++) {
|
166 |
+
vvt_dot(
|
167 |
+
&ka[n][h][l][0],
|
168 |
+
&va[n][h][l][0],
|
169 |
+
kvp,
|
170 |
+
E,
|
171 |
+
M
|
172 |
+
);
|
173 |
+
vmt_dot(
|
174 |
+
&ga[n][h][l][0],
|
175 |
+
kvp,
|
176 |
+
&gqa[n][h][l][0],
|
177 |
+
E,
|
178 |
+
M
|
179 |
+
);
|
180 |
+
}
|
181 |
+
|
182 |
+
// Compute the gradient wrt the keys and values
|
183 |
+
kv.zero_();
|
184 |
+
for (int l=L-1; l>=0; l--) {
|
185 |
+
vvt_dot(
|
186 |
+
&qa[n][h][l][0],
|
187 |
+
&ga[n][h][l][0],
|
188 |
+
kvp,
|
189 |
+
E,
|
190 |
+
M
|
191 |
+
);
|
192 |
+
vmt_dot(
|
193 |
+
&va[n][h][l][0],
|
194 |
+
kvp,
|
195 |
+
&gka[n][h][l][0],
|
196 |
+
E,
|
197 |
+
M
|
198 |
+
);
|
199 |
+
vm_dot(
|
200 |
+
&ka[n][h][l][0],
|
201 |
+
kvp,
|
202 |
+
&gva[n][h][l][0],
|
203 |
+
E,
|
204 |
+
M
|
205 |
+
);
|
206 |
+
}
|
207 |
+
}
|
208 |
+
}
|
209 |
+
}
|
210 |
+
|
211 |
+
|
212 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
213 |
+
m.def(
|
214 |
+
"causal_dot_product",
|
215 |
+
&causal_dot_product,
|
216 |
+
"Compute the weighted sum of values but attending only to previous "
|
217 |
+
"values."
|
218 |
+
);
|
219 |
+
m.def(
|
220 |
+
"causal_dot_backward",
|
221 |
+
&causal_dot_backward,
|
222 |
+
"Compute the gradient of queries, keys and values given the gradient "
|
223 |
+
"of causal_dot_product."
|
224 |
+
);
|
225 |
+
}
|
csrc/causal_attention.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
3 |
+
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
4 |
+
# Apoorv Vyas <avyas@idiap.ch>
|
5 |
+
#
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
try:
|
10 |
+
from causal_attention_cuda import causal_dot_product as causal_dot_product_cuda
|
11 |
+
from causal_attention_cuda import causal_dot_backward as causal_dot_backward_cuda
|
12 |
+
except ImportError as e:
|
13 |
+
print(e)
|
14 |
+
causal_dot_product_cuda = causal_dot_backward_cuda = None
|
15 |
+
|
16 |
+
|
17 |
+
class CausalDotProduct(torch.autograd.Function):
|
18 |
+
"""Compute the weighted sum of values but attending only to previous
|
19 |
+
values."""
|
20 |
+
dot = {
|
21 |
+
# "cpu": causal_dot_product_cpu,
|
22 |
+
"cuda": causal_dot_product_cuda
|
23 |
+
}
|
24 |
+
dot_backward = {
|
25 |
+
# "cpu": causal_dot_backward_cpu,
|
26 |
+
"cuda": causal_dot_backward_cuda
|
27 |
+
}
|
28 |
+
|
29 |
+
@staticmethod
|
30 |
+
def forward(ctx, Q, K, V):
|
31 |
+
# Save the inputs for the gradient computation
|
32 |
+
ctx.save_for_backward(Q, K, V)
|
33 |
+
|
34 |
+
# Create the output tensor
|
35 |
+
device = Q.device
|
36 |
+
N, H, L, _ = Q.shape
|
37 |
+
_, _, _, M = V.shape
|
38 |
+
product = torch.zeros((N, H, L, M), dtype=Q.dtype, device=device)
|
39 |
+
|
40 |
+
# Actually perform the dot product
|
41 |
+
CausalDotProduct.dot[device.type](
|
42 |
+
Q.data,
|
43 |
+
K.data,
|
44 |
+
V.data,
|
45 |
+
product
|
46 |
+
)
|
47 |
+
# breakpoint()
|
48 |
+
# CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product)
|
49 |
+
|
50 |
+
return product
|
51 |
+
|
52 |
+
@staticmethod
|
53 |
+
def backward(ctx, grad_out):
|
54 |
+
# Extract the saved tensors
|
55 |
+
Q, K, V = ctx.saved_tensors
|
56 |
+
|
57 |
+
# Allocate memory for the gradients
|
58 |
+
grad_Q = torch.zeros_like(Q)
|
59 |
+
grad_K = torch.zeros_like(K)
|
60 |
+
grad_V = torch.zeros_like(V)
|
61 |
+
|
62 |
+
# Actually compute the gradients
|
63 |
+
CausalDotProduct.dot_backward[Q.device.type](
|
64 |
+
Q.data,
|
65 |
+
K.data,
|
66 |
+
V.data,
|
67 |
+
grad_out,
|
68 |
+
grad_Q,
|
69 |
+
grad_K,
|
70 |
+
grad_V
|
71 |
+
)
|
72 |
+
|
73 |
+
return grad_Q, grad_K, grad_V
|
74 |
+
|
75 |
+
|
76 |
+
# Alias the autograd functions to python style snake case naming
|
77 |
+
causal_dot_product = CausalDotProduct.apply
|
csrc/causal_attention_cuda.cu
ADDED
@@ -0,0 +1,1483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
//
|
2 |
+
// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
3 |
+
// Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
4 |
+
// Apoorv Vyas <avyas@idiap.ch>
|
5 |
+
//
|
6 |
+
|
7 |
+
//
|
8 |
+
// For modifications made inside namespace nvidia (authored by jdemouth):
|
9 |
+
//
|
10 |
+
// Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
|
11 |
+
//
|
12 |
+
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
13 |
+
// this software and associated documentation files (the "Software"), to deal in
|
14 |
+
// the Software without restriction, including without limitation the rights to
|
15 |
+
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
16 |
+
// the Software, and to permit persons to whom the Software is furnished to do so,
|
17 |
+
// subject to the following conditions:
|
18 |
+
//
|
19 |
+
// The above copyright notice and this permission notice shall be included in all
|
20 |
+
// copies or substantial portions of the Software.
|
21 |
+
//
|
22 |
+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
23 |
+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
24 |
+
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
25 |
+
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
26 |
+
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
27 |
+
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
28 |
+
//
|
29 |
+
|
30 |
+
#include <torch/extension.h>
|
31 |
+
#include <assert.h>
|
32 |
+
#include <stdio.h>
|
33 |
+
|
34 |
+
#define ENABLE_NVIDIA_OPTIMIZATIONS
|
35 |
+
|
36 |
+
#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
|
37 |
+
namespace nvidia {
|
38 |
+
|
39 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
40 |
+
|
41 |
+
constexpr int THREADS_PER_WARP = 32;
|
42 |
+
|
43 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
44 |
+
|
45 |
+
constexpr int LOW_OCCUPANCY_THRESHOLD = 40; // TODO: Make it HW specific (like 1/2 SMs).
|
46 |
+
|
47 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
48 |
+
|
49 |
+
static inline __device__ __host__ int div_up(int m, int n) {
|
50 |
+
return (m + n-1) / n;
|
51 |
+
}
|
52 |
+
|
53 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
54 |
+
|
55 |
+
static inline __device__ __host__ int round_up(int m, int n) {
|
56 |
+
return div_up(m, n) * n;
|
57 |
+
}
|
58 |
+
|
59 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
60 |
+
|
61 |
+
template< typename T >
|
62 |
+
struct Lmha_params {
|
63 |
+
|
64 |
+
// The output buffer. Dimensions [B, H, L, M].
|
65 |
+
T *out;
|
66 |
+
|
67 |
+
// The input Qs. Dimensions [B, H, L, E].
|
68 |
+
const T *q;
|
69 |
+
// The input Ks. Dimensions [B, H, L, E].
|
70 |
+
const T *k;
|
71 |
+
// The input Vs. Dimensions [B, H, L, M].
|
72 |
+
const T *v;
|
73 |
+
|
74 |
+
// The different dimensions.
|
75 |
+
int B, L, H, E, M;
|
76 |
+
|
77 |
+
// The strides for the different tensors.
|
78 |
+
int q_stride_B, q_stride_H, q_stride_L;
|
79 |
+
int k_stride_B, k_stride_H, k_stride_L;
|
80 |
+
int v_stride_B, v_stride_H, v_stride_L;
|
81 |
+
int o_stride_B, o_stride_H, o_stride_L;
|
82 |
+
};
|
83 |
+
|
84 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
85 |
+
|
86 |
+
template< int E, bool GO_BACKWARD, int WARPS, int COLS_PER_THREAD = 4 >
|
87 |
+
__global__ __launch_bounds__(WARPS * THREADS_PER_WARP)
|
88 |
+
void lmha_low_occupancy_kernel(Lmha_params<float> params) {
|
89 |
+
|
90 |
+
// The number of threads per block.
|
91 |
+
constexpr int THREADS_PER_BLOCK = WARPS * THREADS_PER_WARP;
|
92 |
+
// The number of rows per thread.
|
93 |
+
constexpr int ROWS_PER_THREAD = E / THREADS_PER_WARP;
|
94 |
+
// The number of steps per iteration.
|
95 |
+
constexpr int COLS_PER_ITER = WARPS * COLS_PER_THREAD;
|
96 |
+
|
97 |
+
// Make sure E is a multiple of the warp size.
|
98 |
+
static_assert(E % THREADS_PER_WARP == 0, "");
|
99 |
+
|
100 |
+
// Shared memory to store V/O.
|
101 |
+
__shared__ float smem_v[COLS_PER_ITER], smem_o[COLS_PER_ITER];
|
102 |
+
// Shared memory buffer to performance the reductions.
|
103 |
+
__shared__ float smem_reds[E * WARPS];
|
104 |
+
|
105 |
+
// The sequence processed by that block.
|
106 |
+
const int bi = blockIdx.z;
|
107 |
+
// The head processed by that block.
|
108 |
+
const int hi = blockIdx.y;
|
109 |
+
// The hidden cell in the V/output buffers.
|
110 |
+
const int vi = blockIdx.x;
|
111 |
+
|
112 |
+
// The linear index of the thread.
|
113 |
+
const int tidx = threadIdx.x;
|
114 |
+
|
115 |
+
// Decompose the block in warp/lane.
|
116 |
+
const int warp = tidx / THREADS_PER_WARP;
|
117 |
+
const int lane = tidx % THREADS_PER_WARP;
|
118 |
+
|
119 |
+
// The base offset loaded by the thread in Q and K.
|
120 |
+
int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + lane;
|
121 |
+
int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + lane;
|
122 |
+
|
123 |
+
// If we walk backward, account for the extra offset.
|
124 |
+
if( GO_BACKWARD ) {
|
125 |
+
offset_q += (params.L-1)*params.q_stride_L;
|
126 |
+
offset_k += (params.L-1)*params.k_stride_L;
|
127 |
+
}
|
128 |
+
|
129 |
+
// Position the warp at the beginning of the proper timestep.
|
130 |
+
if( GO_BACKWARD ) {
|
131 |
+
offset_q -= warp*COLS_PER_THREAD*params.q_stride_L;
|
132 |
+
offset_k -= warp*COLS_PER_THREAD*params.k_stride_L;
|
133 |
+
} else {
|
134 |
+
offset_q += warp*COLS_PER_THREAD*params.q_stride_L;
|
135 |
+
offset_k += warp*COLS_PER_THREAD*params.k_stride_L;
|
136 |
+
}
|
137 |
+
|
138 |
+
// Determine the base pointers for Q and K.
|
139 |
+
const float *ptr_q = ¶ms.q[offset_q];
|
140 |
+
const float *ptr_k = ¶ms.k[offset_k];
|
141 |
+
|
142 |
+
// Is a given row valid?
|
143 |
+
int valid_qk[ROWS_PER_THREAD];
|
144 |
+
#pragma unroll
|
145 |
+
for( int ii = 0; ii < ROWS_PER_THREAD; ++ii ) {
|
146 |
+
valid_qk[ii] = lane + ii*THREADS_PER_WARP < params.E;
|
147 |
+
}
|
148 |
+
|
149 |
+
// The offset to the position loaded by the thread in V.
|
150 |
+
int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + vi;
|
151 |
+
int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + vi;
|
152 |
+
|
153 |
+
// If we walk backward, account for the extra offset.
|
154 |
+
if( GO_BACKWARD ) {
|
155 |
+
offset_v += (params.L-1)*params.v_stride_L;
|
156 |
+
offset_o += (params.L-1)*params.o_stride_L;
|
157 |
+
}
|
158 |
+
|
159 |
+
// We load/store a strided matrix of COLS_PER_ITER x OUTPUTS_PER_BLOCK.
|
160 |
+
if( GO_BACKWARD ) {
|
161 |
+
offset_v -= tidx*params.v_stride_L;
|
162 |
+
offset_o -= tidx*params.o_stride_L;
|
163 |
+
} else {
|
164 |
+
offset_v += tidx*params.v_stride_L;
|
165 |
+
offset_o += tidx*params.o_stride_L;
|
166 |
+
}
|
167 |
+
|
168 |
+
// Determine the base pointer for V.
|
169 |
+
const float *ptr_v = ¶ms.v[offset_v];
|
170 |
+
// The output pointer.
|
171 |
+
float *ptr_o = ¶ms.out[offset_o];
|
172 |
+
|
173 |
+
// The running KVs.
|
174 |
+
float running_kv[ROWS_PER_THREAD];
|
175 |
+
#pragma unroll
|
176 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
177 |
+
running_kv[ri] = 0.f;
|
178 |
+
}
|
179 |
+
|
180 |
+
// Iterate over the timesteps. TODO: Use params.loop_count!!!
|
181 |
+
for( int iter = 0; iter < params.L; iter += COLS_PER_ITER ) {
|
182 |
+
|
183 |
+
// Each thread loads a matrix of elements.
|
184 |
+
float q[ROWS_PER_THREAD][COLS_PER_THREAD], k[ROWS_PER_THREAD][COLS_PER_THREAD];
|
185 |
+
|
186 |
+
// Trigger the memory loads for Q and K.
|
187 |
+
#pragma unroll
|
188 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
189 |
+
#pragma unroll
|
190 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
191 |
+
|
192 |
+
// For Q/K, each warp loads from various timesteps.
|
193 |
+
int ti = iter + warp*COLS_PER_THREAD;
|
194 |
+
if( GO_BACKWARD ) {
|
195 |
+
ti = params.L - 1 - ti;
|
196 |
+
}
|
197 |
+
|
198 |
+
// Is it a valid access?
|
199 |
+
int valid;
|
200 |
+
if( GO_BACKWARD ) {
|
201 |
+
valid = valid_qk[ri] && ti - ci >= 0;
|
202 |
+
} else {
|
203 |
+
valid = valid_qk[ri] && ti + ci < params.L;
|
204 |
+
}
|
205 |
+
|
206 |
+
// The extra offset to add.
|
207 |
+
if( GO_BACKWARD ) {
|
208 |
+
offset_q = ri*THREADS_PER_WARP - ci*params.q_stride_L;
|
209 |
+
offset_k = ri*THREADS_PER_WARP - ci*params.k_stride_L;
|
210 |
+
} else {
|
211 |
+
offset_q = ri*THREADS_PER_WARP + ci*params.q_stride_L;
|
212 |
+
offset_k = ri*THREADS_PER_WARP + ci*params.k_stride_L;
|
213 |
+
}
|
214 |
+
|
215 |
+
// Load Q/K if they are valid.
|
216 |
+
q[ri][ci] = valid ? ptr_q[offset_q] : 0.f;
|
217 |
+
k[ri][ci] = valid ? ptr_k[offset_k] : 0.f;
|
218 |
+
}
|
219 |
+
}
|
220 |
+
|
221 |
+
// For the V tensor, we assign contiguous thread to different loads. So, ti is different.
|
222 |
+
int ti = iter + tidx;
|
223 |
+
if( GO_BACKWARD ) {
|
224 |
+
ti = params.L - 1 - ti;
|
225 |
+
}
|
226 |
+
|
227 |
+
// Is it a valid access?
|
228 |
+
int valid_vo = tidx < COLS_PER_ITER;
|
229 |
+
if( GO_BACKWARD ) {
|
230 |
+
valid_vo &= ti >= 0;
|
231 |
+
} else {
|
232 |
+
valid_vo &= ti < params.L;
|
233 |
+
}
|
234 |
+
|
235 |
+
// Trigger the loads for V.
|
236 |
+
float ldg_v = valid_vo ? *ptr_v : 0.f;
|
237 |
+
|
238 |
+
// Move the load pointers.
|
239 |
+
if( GO_BACKWARD ) {
|
240 |
+
ptr_q -= COLS_PER_ITER*params.q_stride_L;
|
241 |
+
ptr_k -= COLS_PER_ITER*params.k_stride_L;
|
242 |
+
ptr_v -= COLS_PER_ITER*params.v_stride_L;
|
243 |
+
} else {
|
244 |
+
ptr_q += COLS_PER_ITER*params.q_stride_L;
|
245 |
+
ptr_k += COLS_PER_ITER*params.k_stride_L;
|
246 |
+
ptr_v += COLS_PER_ITER*params.v_stride_L;
|
247 |
+
}
|
248 |
+
|
249 |
+
// Store to shared memory.
|
250 |
+
if( tidx < COLS_PER_ITER ) {
|
251 |
+
smem_v[tidx] = ldg_v;
|
252 |
+
}
|
253 |
+
|
254 |
+
// Make sure V is in shared memory.
|
255 |
+
__syncthreads();
|
256 |
+
|
257 |
+
// Read V from shared memory.
|
258 |
+
float v[COLS_PER_THREAD];
|
259 |
+
#pragma unroll
|
260 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
261 |
+
v[ci] = smem_v[warp*COLS_PER_THREAD + ci];
|
262 |
+
}
|
263 |
+
|
264 |
+
// Each thread computes local K*V products.
|
265 |
+
float kv[ROWS_PER_THREAD][COLS_PER_THREAD];
|
266 |
+
#pragma unroll
|
267 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
268 |
+
#pragma unroll
|
269 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
270 |
+
kv[ri][ci] = 0.f;
|
271 |
+
}
|
272 |
+
}
|
273 |
+
|
274 |
+
// Update the K*V^T product.
|
275 |
+
#pragma unroll
|
276 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
277 |
+
#pragma unroll
|
278 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
279 |
+
kv[ri][ci] += k[ri][ci] * v[ci];
|
280 |
+
}
|
281 |
+
}
|
282 |
+
|
283 |
+
// We must perform the prefix sums within the thread-block. Start with the thread.
|
284 |
+
#pragma unroll
|
285 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
286 |
+
#pragma unroll
|
287 |
+
for( int ci = 1; ci < COLS_PER_THREAD; ++ci ) {
|
288 |
+
kv[ri][ci] += kv[ri][ci-1];
|
289 |
+
}
|
290 |
+
}
|
291 |
+
|
292 |
+
// Store the partial sums to shared memory. Unless we have no inter-warp reduction to perform.
|
293 |
+
#pragma unroll
|
294 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
295 |
+
smem_reds[warp*E + ri*THREADS_PER_WARP + lane] = kv[ri][COLS_PER_THREAD-1];
|
296 |
+
}
|
297 |
+
|
298 |
+
// Make sure the data is in shared memory.
|
299 |
+
__syncthreads();
|
300 |
+
|
301 |
+
// Each thread deals with one or more column(s) of the matrix.
|
302 |
+
constexpr int SUMS_PER_THREAD = (E + THREADS_PER_BLOCK-1) / THREADS_PER_BLOCK;
|
303 |
+
#pragma unroll
|
304 |
+
for( int ii = 0, idx = tidx; ii < SUMS_PER_THREAD; ++ii, idx += THREADS_PER_BLOCK ) {
|
305 |
+
if( idx < E ) {
|
306 |
+
float sum = smem_reds[idx];
|
307 |
+
#pragma unroll
|
308 |
+
for( int jj = 1; jj < WARPS; ++jj ) {
|
309 |
+
smem_reds[idx + jj*E] = sum += smem_reds[idx + jj*E];
|
310 |
+
}
|
311 |
+
}
|
312 |
+
}
|
313 |
+
|
314 |
+
// Make sure the reductions are stored in shared memory.
|
315 |
+
__syncthreads();
|
316 |
+
|
317 |
+
// Each thread updates his partial products.
|
318 |
+
#pragma unroll
|
319 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
320 |
+
float sum = running_kv[ri];
|
321 |
+
if( warp > 0 ) {
|
322 |
+
sum += smem_reds[(warp-1)*E + lane + ri*THREADS_PER_WARP];
|
323 |
+
}
|
324 |
+
#pragma unroll
|
325 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
326 |
+
kv[ri][ci] += sum;
|
327 |
+
}
|
328 |
+
}
|
329 |
+
|
330 |
+
// Compute the partial output values for that thread.
|
331 |
+
float sum[COLS_PER_THREAD];
|
332 |
+
#pragma unroll
|
333 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
334 |
+
sum[ci] = q[0][ci] * kv[0][ci];
|
335 |
+
#pragma unroll
|
336 |
+
for( int ri = 1; ri < ROWS_PER_THREAD; ++ri ) {
|
337 |
+
sum[ci] += q[ri][ci] * kv[ri][ci];
|
338 |
+
}
|
339 |
+
}
|
340 |
+
|
341 |
+
// Run the parallel reductions inside the warp.
|
342 |
+
#pragma unroll
|
343 |
+
for( int mask = THREADS_PER_WARP / 2; mask >= 1; mask /= 2 ) {
|
344 |
+
#pragma unroll
|
345 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
346 |
+
sum[ci] += __shfl_xor_sync(uint32_t(-1), sum[ci], mask);
|
347 |
+
}
|
348 |
+
}
|
349 |
+
|
350 |
+
// Store the final output to shared memory.
|
351 |
+
if( lane == 0 ) {
|
352 |
+
#pragma unroll
|
353 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
354 |
+
smem_o[warp*COLS_PER_THREAD + ci] = sum[ci];
|
355 |
+
}
|
356 |
+
}
|
357 |
+
|
358 |
+
// Make sure the data is in shared memory.
|
359 |
+
__syncthreads();
|
360 |
+
|
361 |
+
// Store the output.
|
362 |
+
if( valid_vo ) {
|
363 |
+
*ptr_o = smem_o[tidx];
|
364 |
+
}
|
365 |
+
|
366 |
+
// Each thread updates his running kv.
|
367 |
+
#pragma unroll
|
368 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
369 |
+
running_kv[ri] += smem_reds[(WARPS-1)*E + lane + ri*THREADS_PER_WARP];
|
370 |
+
}
|
371 |
+
|
372 |
+
// Move to next location.
|
373 |
+
if( GO_BACKWARD ) {
|
374 |
+
ptr_o -= COLS_PER_ITER*params.o_stride_L;
|
375 |
+
} else {
|
376 |
+
ptr_o += COLS_PER_ITER*params.o_stride_L;
|
377 |
+
}
|
378 |
+
}
|
379 |
+
}
|
380 |
+
|
381 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
382 |
+
|
383 |
+
template< int E, bool GO_BACKWARD, int WARPS >
|
384 |
+
int lmha_low_occupancy_(const Lmha_params<float> ¶ms) {
|
385 |
+
|
386 |
+
// Make sure we are not going to launch an invalid grid.
|
387 |
+
if( params.H > 65535 || params.B > 65535 ) {
|
388 |
+
return 1;
|
389 |
+
}
|
390 |
+
|
391 |
+
// Prepare the grid and trigger the CUDA kernel.
|
392 |
+
dim3 grid;
|
393 |
+
grid.x = params.M;
|
394 |
+
grid.y = params.H;
|
395 |
+
grid.z = params.B;
|
396 |
+
lmha_low_occupancy_kernel<E, GO_BACKWARD, WARPS><<<grid, WARPS*THREADS_PER_WARP>>>(params);
|
397 |
+
return 0;
|
398 |
+
}
|
399 |
+
|
400 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
401 |
+
|
402 |
+
template< int E, bool GO_BACKWARD >
|
403 |
+
int lmha_low_occupancy_(const Lmha_params<float> ¶ms, int blocks) {
|
404 |
+
if( params.M * blocks >= 8*LOW_OCCUPANCY_THRESHOLD ) {
|
405 |
+
return lmha_low_occupancy_<E, GO_BACKWARD, 4>(params);
|
406 |
+
} else if( params.M * blocks >= 4*LOW_OCCUPANCY_THRESHOLD ) {
|
407 |
+
return lmha_low_occupancy_<E, GO_BACKWARD, 8>(params);
|
408 |
+
} else {
|
409 |
+
return lmha_low_occupancy_<E, GO_BACKWARD, 16>(params);
|
410 |
+
}
|
411 |
+
return 1;
|
412 |
+
}
|
413 |
+
|
414 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
415 |
+
|
416 |
+
template< int E, typename Params >
|
417 |
+
static inline __device__ __host__ int smem_buffer_elts_(const Params ¶ms) {
|
418 |
+
int M = round_up(params.M, 4);
|
419 |
+
return 2*E + 2*M;
|
420 |
+
}
|
421 |
+
|
422 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
423 |
+
|
424 |
+
template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD >
|
425 |
+
__global__
|
426 |
+
void lmha_kernel(Lmha_params<float> params) {
|
427 |
+
|
428 |
+
// Make sure E is a multiple of 4.
|
429 |
+
static_assert(E % 4 == 0, "");
|
430 |
+
|
431 |
+
// The amount of shared memory per buffer (2 buffers for double-buffering).
|
432 |
+
const int smem_buffer_elts = smem_buffer_elts_<E>(params);
|
433 |
+
// The M dimension for shared memory.
|
434 |
+
const int M = round_up(params.M, 4);
|
435 |
+
|
436 |
+
// Shared memory to store Q, K and V. Size is 2*smem_buffer_elts.
|
437 |
+
extern __shared__ float smem_[];
|
438 |
+
|
439 |
+
// The various shared memory buffers.
|
440 |
+
float *smem_q = &smem_[0*E];
|
441 |
+
float *smem_k = &smem_[1*E];
|
442 |
+
float *smem_v = &smem_[2*E];
|
443 |
+
float *smem_o = &smem_[2*E + M];
|
444 |
+
|
445 |
+
// The index of the shared memory buffer (for double-buffering).
|
446 |
+
int smem_curr = 0;
|
447 |
+
|
448 |
+
// The sequence processed by that block.
|
449 |
+
const int bi = blockIdx.y;
|
450 |
+
// The head processed by that block.
|
451 |
+
const int hi = blockIdx.x;
|
452 |
+
|
453 |
+
// The linear index of the thread.
|
454 |
+
const int tidx = threadIdx.x;
|
455 |
+
|
456 |
+
// The offset to the position loaded by the thread in Q.
|
457 |
+
int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + tidx;
|
458 |
+
// The offset to the position loaded by the thread in K.
|
459 |
+
int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + tidx;
|
460 |
+
|
461 |
+
// If we walk backward, account for the extra offset.
|
462 |
+
if( GO_BACKWARD ) {
|
463 |
+
offset_q += (params.L-1)*params.q_stride_L;
|
464 |
+
offset_k += (params.L-1)*params.k_stride_L;
|
465 |
+
}
|
466 |
+
|
467 |
+
// Determine the base pointers for Q and K.
|
468 |
+
const float *ptr_q = ¶ms.q[offset_q];
|
469 |
+
const float *ptr_k = ¶ms.k[offset_k];
|
470 |
+
|
471 |
+
// The offset to the position loaded by the thread in V and O.
|
472 |
+
int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + tidx;
|
473 |
+
int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + tidx;
|
474 |
+
|
475 |
+
// If we walk backward, account for the extra offset.
|
476 |
+
if( GO_BACKWARD ) {
|
477 |
+
offset_v += (params.L-1)*params.v_stride_L;
|
478 |
+
offset_o += (params.L-1)*params.o_stride_L;
|
479 |
+
}
|
480 |
+
|
481 |
+
// Determine the base pointers for V.
|
482 |
+
const float *ptr_v = ¶ms.v[offset_v];
|
483 |
+
|
484 |
+
// Is it an active Q/K thread?
|
485 |
+
const int active_qk = tidx < params.E;
|
486 |
+
|
487 |
+
// Trigger the memory loads for Q and K.
|
488 |
+
float ldg_q = 0.f, ldg_k = 0.f;
|
489 |
+
if( active_qk ) {
|
490 |
+
ldg_q = *ptr_q;
|
491 |
+
ldg_k = *ptr_k;
|
492 |
+
}
|
493 |
+
|
494 |
+
// Is it an active V thread?
|
495 |
+
const int active_v = tidx < params.M;
|
496 |
+
|
497 |
+
// Trigger the memory loads for V.
|
498 |
+
float ldg_v = 0.f;
|
499 |
+
if( active_v ) {
|
500 |
+
ldg_v = *ptr_v;
|
501 |
+
}
|
502 |
+
|
503 |
+
// Move the load pointers.
|
504 |
+
if( GO_BACKWARD ) {
|
505 |
+
ptr_q -= params.q_stride_L;
|
506 |
+
ptr_k -= params.k_stride_L;
|
507 |
+
ptr_v -= params.v_stride_L;
|
508 |
+
} else {
|
509 |
+
ptr_q += params.q_stride_L;
|
510 |
+
ptr_k += params.k_stride_L;
|
511 |
+
ptr_v += params.v_stride_L;
|
512 |
+
}
|
513 |
+
|
514 |
+
// The number of FLOAT4s per head.
|
515 |
+
constexpr int FLOAT4s_PER_HEAD = E / 4;
|
516 |
+
// The number of FLOAT4s per thread.
|
517 |
+
constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD;
|
518 |
+
|
519 |
+
// The storage for the K*V^T values.
|
520 |
+
float4 kv[FLOAT4s_PER_THREAD];
|
521 |
+
#pragma unroll
|
522 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
523 |
+
kv[ii] = make_float4(0.f, 0.f, 0.f, 0.f);
|
524 |
+
}
|
525 |
+
|
526 |
+
// The output pointer.
|
527 |
+
float *out_ptr = ¶ms.out[offset_o];
|
528 |
+
|
529 |
+
// Store to shared memory Q and K.
|
530 |
+
if( tidx < E ) {
|
531 |
+
smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q;
|
532 |
+
smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k;
|
533 |
+
}
|
534 |
+
|
535 |
+
// Store to shared memory V. All threads store valid values.
|
536 |
+
if( tidx < M ) {
|
537 |
+
smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v;
|
538 |
+
}
|
539 |
+
|
540 |
+
// The position of the thread in the V dimension.
|
541 |
+
int vo = tidx / THREADS_PER_HEAD;
|
542 |
+
int vi = tidx % THREADS_PER_HEAD;
|
543 |
+
|
544 |
+
// Iterate over the timesteps.
|
545 |
+
for( int ti = 0; ti < params.L; ++ti ) {
|
546 |
+
|
547 |
+
// Is it the last iteration?
|
548 |
+
int is_last = ti == params.L - 1;
|
549 |
+
|
550 |
+
// Trigger the next loads for Q and K.
|
551 |
+
if( !is_last && active_qk ) {
|
552 |
+
ldg_q = *ptr_q;
|
553 |
+
ldg_k = *ptr_k;
|
554 |
+
}
|
555 |
+
|
556 |
+
// Trigger the next loads for V.
|
557 |
+
if( !is_last && active_v ) {
|
558 |
+
ldg_v = *ptr_v;
|
559 |
+
}
|
560 |
+
|
561 |
+
// Move the load pointers.
|
562 |
+
if( GO_BACKWARD ) {
|
563 |
+
ptr_q -= params.q_stride_L;
|
564 |
+
ptr_k -= params.k_stride_L;
|
565 |
+
ptr_v -= params.v_stride_L;
|
566 |
+
} else {
|
567 |
+
ptr_q += params.q_stride_L;
|
568 |
+
ptr_k += params.k_stride_L;
|
569 |
+
ptr_v += params.v_stride_L;
|
570 |
+
}
|
571 |
+
|
572 |
+
// Make sure the data is in shared memory.
|
573 |
+
__syncthreads();
|
574 |
+
|
575 |
+
// Each thread loads 4 values from K.
|
576 |
+
float4 k[FLOAT4s_PER_THREAD];
|
577 |
+
#pragma unroll
|
578 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
579 |
+
int ki = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4;
|
580 |
+
k[ii] = *reinterpret_cast<const float4*>(&smem_k[smem_curr*smem_buffer_elts + ki]);
|
581 |
+
}
|
582 |
+
|
583 |
+
// Each thread loads a single V value.
|
584 |
+
float v = 0.f;
|
585 |
+
if( vo < params.M ) {
|
586 |
+
v = *reinterpret_cast<const float *>(&smem_v[smem_curr*smem_buffer_elts + vo]);
|
587 |
+
}
|
588 |
+
|
589 |
+
// Update the K*V^T product.
|
590 |
+
#pragma unroll
|
591 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
592 |
+
kv[ii].x += k[ii].x * v;
|
593 |
+
kv[ii].y += k[ii].y * v;
|
594 |
+
kv[ii].z += k[ii].z * v;
|
595 |
+
kv[ii].w += k[ii].w * v;
|
596 |
+
}
|
597 |
+
|
598 |
+
// Load the Q values from shared memory.
|
599 |
+
float4 q[FLOAT4s_PER_THREAD];
|
600 |
+
#pragma unroll
|
601 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
602 |
+
int qi = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4;
|
603 |
+
q[ii] = *reinterpret_cast<const float4*>(&smem_q[smem_curr*smem_buffer_elts + qi]);
|
604 |
+
}
|
605 |
+
|
606 |
+
// Compute the partial output value for that thread.
|
607 |
+
float sum = 0.f;
|
608 |
+
#pragma unroll
|
609 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
610 |
+
sum += q[ii].x * kv[ii].x;
|
611 |
+
sum += q[ii].y * kv[ii].y;
|
612 |
+
sum += q[ii].z * kv[ii].z;
|
613 |
+
sum += q[ii].w * kv[ii].w;
|
614 |
+
}
|
615 |
+
|
616 |
+
// Finalize the computation of the sum (if we have more than 1 thread per head).
|
617 |
+
if( THREADS_PER_HEAD > 1 ) {
|
618 |
+
|
619 |
+
// Finalize the sum for each head.
|
620 |
+
#pragma unroll
|
621 |
+
for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) {
|
622 |
+
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
623 |
+
}
|
624 |
+
|
625 |
+
// Store to shared memory.
|
626 |
+
if( vo < M && vi == 0 ) {
|
627 |
+
smem_o[smem_curr*smem_buffer_elts + vo] = sum;
|
628 |
+
}
|
629 |
+
|
630 |
+
// Make sure the data is in shared memory.
|
631 |
+
__syncthreads();
|
632 |
+
|
633 |
+
// Active threads read the data to store.
|
634 |
+
if( active_v ) {
|
635 |
+
sum = smem_o[smem_curr*smem_buffer_elts + tidx];
|
636 |
+
}
|
637 |
+
|
638 |
+
} // THREADS_PER_HEAD > 1.
|
639 |
+
|
640 |
+
// Store the output. All the threads are active.
|
641 |
+
if( active_v ) {
|
642 |
+
*out_ptr = sum;
|
643 |
+
}
|
644 |
+
|
645 |
+
// Move to next location.
|
646 |
+
if( GO_BACKWARD ) {
|
647 |
+
out_ptr -= params.o_stride_L;
|
648 |
+
} else {
|
649 |
+
out_ptr += params.o_stride_L;
|
650 |
+
}
|
651 |
+
|
652 |
+
// Move the shared memory buffer.
|
653 |
+
smem_curr = (smem_curr + 1) % 2;
|
654 |
+
|
655 |
+
// Store to shared memory for Q and K.
|
656 |
+
if( !is_last && tidx < E ) {
|
657 |
+
smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q;
|
658 |
+
smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k;
|
659 |
+
}
|
660 |
+
|
661 |
+
// Store to shared memory for V.
|
662 |
+
if( !is_last && tidx < M ) {
|
663 |
+
smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v;
|
664 |
+
}
|
665 |
+
}
|
666 |
+
}
|
667 |
+
|
668 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
669 |
+
|
670 |
+
template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD >
|
671 |
+
int lmha_(const Lmha_params<float> ¶ms) {
|
672 |
+
// The M dimension rounded up to 4.
|
673 |
+
int M = round_up(params.M, 4);
|
674 |
+
|
675 |
+
// The number of threads in the block.
|
676 |
+
int block = round_up(max(E, M*THREADS_PER_HEAD), 32);
|
677 |
+
if( block > 512 || params.B > 65535 ) {
|
678 |
+
return 1;
|
679 |
+
}
|
680 |
+
|
681 |
+
// Prepare the kernel.
|
682 |
+
dim3 grid(params.H, params.B);
|
683 |
+
size_t smem = smem_buffer_elts_<E>(params)*2*sizeof(float);
|
684 |
+
lmha_kernel<E, THREADS_PER_HEAD, GO_BACKWARD><<<grid, block, smem>>>(params);
|
685 |
+
return 0;
|
686 |
+
}
|
687 |
+
|
688 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
689 |
+
|
690 |
+
template< bool GO_BACKWARD >
|
691 |
+
int lmha(const Lmha_params<float> ¶ms) {
|
692 |
+
int blocks = params.B * params.H;
|
693 |
+
int res = 1;
|
694 |
+
if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
|
695 |
+
if( params.E <= 32 ) {
|
696 |
+
res = lmha_low_occupancy_< 32, GO_BACKWARD>(params, blocks);
|
697 |
+
} else if( params.E <= 64 ) {
|
698 |
+
res = lmha_low_occupancy_< 64, GO_BACKWARD>(params, blocks);
|
699 |
+
} else if( params.E <= 128 ) {
|
700 |
+
res = lmha_low_occupancy_<128, GO_BACKWARD>(params, blocks);
|
701 |
+
} else if( params.E <= 256 ) {
|
702 |
+
res = lmha_low_occupancy_<256, GO_BACKWARD>(params, blocks);
|
703 |
+
}
|
704 |
+
} else {
|
705 |
+
if( params.E <= 32 ) {
|
706 |
+
res = lmha_< 32, 1, GO_BACKWARD>(params);
|
707 |
+
} else if( params.E <= 48 ) {
|
708 |
+
res = lmha_< 48, 1, GO_BACKWARD>(params);
|
709 |
+
} else if( params.E <= 64 ) {
|
710 |
+
res = lmha_< 64, 1, GO_BACKWARD>(params);
|
711 |
+
} else if( params.E <= 128 ) {
|
712 |
+
res = lmha_<128, 2, GO_BACKWARD>(params);
|
713 |
+
} else if( params.E <= 256 ) {
|
714 |
+
res = lmha_<256, 4, GO_BACKWARD>(params);
|
715 |
+
}
|
716 |
+
}
|
717 |
+
return res;
|
718 |
+
}
|
719 |
+
|
720 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
721 |
+
|
722 |
+
template< typename T >
|
723 |
+
inline void set_params(Lmha_params<T> ¶ms,
|
724 |
+
const torch::Tensor q,
|
725 |
+
const torch::Tensor k,
|
726 |
+
const torch::Tensor v,
|
727 |
+
torch::Tensor o) {
|
728 |
+
|
729 |
+
// Define the pointers.
|
730 |
+
params.out = o.data_ptr<T>();
|
731 |
+
params.q = q.data_ptr<T>();
|
732 |
+
params.k = k.data_ptr<T>();
|
733 |
+
params.v = v.data_ptr<T>();
|
734 |
+
|
735 |
+
// Define the strides.
|
736 |
+
params.q_stride_B = (int) q.stride(0);
|
737 |
+
params.q_stride_H = (int) q.stride(1);
|
738 |
+
params.q_stride_L = (int) q.stride(2);
|
739 |
+
params.k_stride_B = (int) k.stride(0);
|
740 |
+
params.k_stride_H = (int) k.stride(1);
|
741 |
+
params.k_stride_L = (int) k.stride(2);
|
742 |
+
params.v_stride_B = (int) v.stride(0);
|
743 |
+
params.v_stride_H = (int) v.stride(1);
|
744 |
+
params.v_stride_L = (int) v.stride(2);
|
745 |
+
params.o_stride_B = (int) o.stride(0);
|
746 |
+
params.o_stride_H = (int) o.stride(1);
|
747 |
+
params.o_stride_L = (int) o.stride(2);
|
748 |
+
|
749 |
+
// Extract the dimensions.
|
750 |
+
int N = q.size(0);
|
751 |
+
int H = q.size(1);
|
752 |
+
int L = q.size(2);
|
753 |
+
int E = q.size(3);
|
754 |
+
int M = v.size(3);
|
755 |
+
|
756 |
+
params.B = N;
|
757 |
+
params.L = L;
|
758 |
+
params.H = H;
|
759 |
+
params.E = E;
|
760 |
+
params.M = M;
|
761 |
+
}
|
762 |
+
|
763 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
764 |
+
|
765 |
+
int lmha_fwd(const torch::Tensor queries,
|
766 |
+
const torch::Tensor keys,
|
767 |
+
const torch::Tensor values,
|
768 |
+
torch::Tensor product) {
|
769 |
+
|
770 |
+
// Make sure that we are using the correct GPU device
|
771 |
+
torch::DeviceGuard _guard(queries.device());
|
772 |
+
|
773 |
+
// Make sure the inner-most dimension of the tensors is packed.
|
774 |
+
assert(queries.stride(3) == 1);
|
775 |
+
assert(keys .stride(3) == 1);
|
776 |
+
assert(values .stride(3) == 1);
|
777 |
+
assert(product.stride(3) == 1);
|
778 |
+
|
779 |
+
// Extract the dimensions.
|
780 |
+
int N = queries.size(0);
|
781 |
+
int H = queries.size(1);
|
782 |
+
int L = queries.size(2);
|
783 |
+
int E = queries.size(3);
|
784 |
+
int M = values.size (3);
|
785 |
+
|
786 |
+
// The structure of params.
|
787 |
+
Lmha_params<float> params;
|
788 |
+
set_params(params, queries, keys, values, product);
|
789 |
+
|
790 |
+
// Launch the kernel.
|
791 |
+
return lmha<false>(params);
|
792 |
+
}
|
793 |
+
|
794 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
795 |
+
|
796 |
+
template< typename T >
|
797 |
+
struct Lmha_bwd_params {
|
798 |
+
|
799 |
+
// The output buffer for K. Dimensions [B, H, L, D].
|
800 |
+
T *out_k;
|
801 |
+
// The output buffer for V. Dimensions [B, H, L, D].
|
802 |
+
T *out_v;
|
803 |
+
|
804 |
+
// The input Qs. Dimensions [B, H, L, D].
|
805 |
+
const T *q;
|
806 |
+
// The input Ks. Dimensions [B, H, L, D].
|
807 |
+
const T *k;
|
808 |
+
// The input Vs. Dimensions [B, H, L, D].
|
809 |
+
const T *v;
|
810 |
+
// The input Gs. Dimensions [B, H, L, D].
|
811 |
+
const T *g;
|
812 |
+
|
813 |
+
// The dimensions.
|
814 |
+
int B, L, H, M, E;
|
815 |
+
|
816 |
+
// The strides for the input tensors.
|
817 |
+
int q_stride_B, q_stride_L, q_stride_H;
|
818 |
+
int k_stride_B, k_stride_L, k_stride_H;
|
819 |
+
int v_stride_B, v_stride_L, v_stride_H;
|
820 |
+
int g_stride_B, g_stride_L, g_stride_H;
|
821 |
+
|
822 |
+
// The strides for the outputs.
|
823 |
+
int out_k_stride_B, out_k_stride_L, out_k_stride_H;
|
824 |
+
int out_v_stride_B, out_v_stride_L, out_v_stride_H;
|
825 |
+
};
|
826 |
+
|
827 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
828 |
+
|
829 |
+
template< int D, int THREADS_PER_HEAD >
|
830 |
+
__global__ __launch_bounds__(D*THREADS_PER_HEAD*2)
|
831 |
+
void lmha_bwd_kernel(Lmha_bwd_params<float> params) {
|
832 |
+
|
833 |
+
// Make sure D is a multiple of 4.
|
834 |
+
static_assert(D % 4 == 0, "");
|
835 |
+
|
836 |
+
// The shared memory buffers.
|
837 |
+
__shared__ struct Smem { float qg[2*D], kv[2*D], out_kv[2*D]; } smem_[2];
|
838 |
+
|
839 |
+
// The index of the shared memory buffer (for double-buffering).
|
840 |
+
int smem_curr = 0;
|
841 |
+
|
842 |
+
// The sequence processed by that block.
|
843 |
+
const int bi = blockIdx.y;
|
844 |
+
// The head processed by that block.
|
845 |
+
const int hi = blockIdx.x;
|
846 |
+
|
847 |
+
// The linear index of the thread.
|
848 |
+
const int tidx = threadIdx.x;
|
849 |
+
|
850 |
+
// Split the threads into two slices.
|
851 |
+
int so = tidx / (D*THREADS_PER_HEAD);
|
852 |
+
int si = tidx % (D*THREADS_PER_HEAD);
|
853 |
+
|
854 |
+
// The strides for B/L/H for the Q/G tensors.
|
855 |
+
int qg_stride_B, qg_stride_L, qg_stride_H;
|
856 |
+
if( so == 0 ) {
|
857 |
+
qg_stride_B = params.q_stride_B;
|
858 |
+
qg_stride_L = params.q_stride_L;
|
859 |
+
qg_stride_H = params.q_stride_H;
|
860 |
+
} else {
|
861 |
+
qg_stride_B = params.g_stride_B;
|
862 |
+
qg_stride_L = params.g_stride_L;
|
863 |
+
qg_stride_H = params.g_stride_H;
|
864 |
+
}
|
865 |
+
|
866 |
+
// The strides for B/L/H for the K/V tensors.
|
867 |
+
int kv_stride_B, kv_stride_L, kv_stride_H;
|
868 |
+
if( so == 0 ) {
|
869 |
+
kv_stride_B = params.k_stride_B;
|
870 |
+
kv_stride_L = params.k_stride_L;
|
871 |
+
kv_stride_H = params.k_stride_H;
|
872 |
+
} else {
|
873 |
+
kv_stride_B = params.v_stride_B;
|
874 |
+
kv_stride_L = params.v_stride_L;
|
875 |
+
kv_stride_H = params.v_stride_H;
|
876 |
+
}
|
877 |
+
|
878 |
+
// The hidden size.
|
879 |
+
int hidden_size_per_head = 0;
|
880 |
+
if( so == 0 ) {
|
881 |
+
hidden_size_per_head = params.E;
|
882 |
+
} else {
|
883 |
+
hidden_size_per_head = params.M;
|
884 |
+
}
|
885 |
+
|
886 |
+
// Where to start reading from.
|
887 |
+
int offset_qg = bi*qg_stride_B + hi*qg_stride_H + si;
|
888 |
+
int offset_kv = bi*kv_stride_B + hi*kv_stride_H + si;
|
889 |
+
|
890 |
+
// We walk backward, account for the extra offset.
|
891 |
+
offset_qg += (params.L-1)*qg_stride_L;
|
892 |
+
offset_kv += (params.L-1)*kv_stride_L;
|
893 |
+
|
894 |
+
// Determine the base pointers for Q, K, V and G.
|
895 |
+
const float *ptr_qg = &(so == 0 ? params.q : params.g)[offset_qg];
|
896 |
+
const float *ptr_kv = &(so == 0 ? params.k : params.v)[offset_kv];
|
897 |
+
|
898 |
+
// Is it an active thread?
|
899 |
+
const int active = si < hidden_size_per_head;
|
900 |
+
|
901 |
+
// Trigger the memory loads for Q, K, V and G.
|
902 |
+
float ldg_qg = 0.f, ldg_kv = 0.f;
|
903 |
+
if( active ) {
|
904 |
+
ldg_qg = *ptr_qg;
|
905 |
+
ldg_kv = *ptr_kv;
|
906 |
+
}
|
907 |
+
|
908 |
+
// Move the load pointers (backward).
|
909 |
+
ptr_qg -= qg_stride_L;
|
910 |
+
ptr_kv -= kv_stride_L;
|
911 |
+
|
912 |
+
// The number of FLOAT4s per head.
|
913 |
+
constexpr int FLOAT4s_PER_HEAD = D / 4;
|
914 |
+
// The number of FLOAT4s per thread.
|
915 |
+
constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD;
|
916 |
+
|
917 |
+
// The storage for the G*Q^T or Q^T*G values.
|
918 |
+
float4 gq[FLOAT4s_PER_THREAD];
|
919 |
+
#pragma unroll
|
920 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
921 |
+
gq[ii] = make_float4(0.f, 0.f, 0.f, 0.f);
|
922 |
+
}
|
923 |
+
|
924 |
+
// The strides for B/L/H for the K/V tensors.
|
925 |
+
int out_kv_stride_B, out_kv_stride_L, out_kv_stride_H;
|
926 |
+
if( so == 0 ) {
|
927 |
+
out_kv_stride_B = params.out_k_stride_B;
|
928 |
+
out_kv_stride_L = params.out_k_stride_L;
|
929 |
+
out_kv_stride_H = params.out_k_stride_H;
|
930 |
+
} else {
|
931 |
+
out_kv_stride_B = params.out_v_stride_B;
|
932 |
+
out_kv_stride_L = params.out_v_stride_L;
|
933 |
+
out_kv_stride_H = params.out_v_stride_H;
|
934 |
+
}
|
935 |
+
|
936 |
+
// Where to start reading from.
|
937 |
+
int offset_out_kv = bi*out_kv_stride_B + hi*out_kv_stride_H + si;
|
938 |
+
|
939 |
+
// We walk backward, account for the extra offset.
|
940 |
+
offset_out_kv += (params.L-1)*out_kv_stride_L;
|
941 |
+
|
942 |
+
// The output pointer.
|
943 |
+
float *ptr_out_kv = &(so == 0 ? params.out_k : params.out_v)[offset_out_kv];
|
944 |
+
|
945 |
+
// Store to shared memory.
|
946 |
+
if( si < D ) {
|
947 |
+
smem_[smem_curr].qg[so*D + si] = ldg_qg;
|
948 |
+
smem_[smem_curr].kv[so*D + si] = ldg_kv;
|
949 |
+
}
|
950 |
+
|
951 |
+
// The position of the thread in the output dimension.
|
952 |
+
int oo = si / THREADS_PER_HEAD % D;
|
953 |
+
int oi = si % THREADS_PER_HEAD * 4;
|
954 |
+
|
955 |
+
// Iterate over the timesteps.
|
956 |
+
for( int ti = 0; ti < params.L; ++ti ) {
|
957 |
+
|
958 |
+
// Is it the last iteration?
|
959 |
+
int is_last = ti == params.L - 1;
|
960 |
+
|
961 |
+
// Trigger the next loads.
|
962 |
+
if( !is_last && active ) {
|
963 |
+
ldg_qg = *ptr_qg;
|
964 |
+
ldg_kv = *ptr_kv;
|
965 |
+
}
|
966 |
+
|
967 |
+
// Move the load pointers.
|
968 |
+
ptr_qg -= qg_stride_L;
|
969 |
+
ptr_kv -= kv_stride_L;
|
970 |
+
|
971 |
+
// Make sure the data is in shared memory.
|
972 |
+
__syncthreads();
|
973 |
+
|
974 |
+
// Each thread loads 4 values from G or Q.
|
975 |
+
float4 g[FLOAT4s_PER_THREAD];
|
976 |
+
#pragma unroll
|
977 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
978 |
+
float *smem_ptr = &smem_[smem_curr].qg[(so^1)*D + oi];
|
979 |
+
g[ii] = *reinterpret_cast<const float4*>(&smem_ptr[ii*THREADS_PER_HEAD*4]);
|
980 |
+
}
|
981 |
+
|
982 |
+
// Each thread loads a single from Q or G value.
|
983 |
+
float q = smem_[smem_curr].qg[so*D + oo];
|
984 |
+
|
985 |
+
// Update the G*Q^T or Q*G^T product.
|
986 |
+
#pragma unroll
|
987 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
988 |
+
gq[ii].x += g[ii].x * q;
|
989 |
+
gq[ii].y += g[ii].y * q;
|
990 |
+
gq[ii].z += g[ii].z * q;
|
991 |
+
gq[ii].w += g[ii].w * q;
|
992 |
+
}
|
993 |
+
|
994 |
+
// Load the V or K values from shared memory.
|
995 |
+
float4 v[FLOAT4s_PER_THREAD];
|
996 |
+
#pragma unroll
|
997 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
998 |
+
float *smem_ptr = &smem_[smem_curr].kv[(so^1)*D + oi];
|
999 |
+
v[ii] = *reinterpret_cast<const float4*>(&smem_ptr[ii*THREADS_PER_HEAD*4]);
|
1000 |
+
}
|
1001 |
+
|
1002 |
+
// Compute the partial output value for that thread.
|
1003 |
+
float sum = 0.f;
|
1004 |
+
#pragma unroll
|
1005 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
1006 |
+
sum += v[ii].x * gq[ii].x;
|
1007 |
+
sum += v[ii].y * gq[ii].y;
|
1008 |
+
sum += v[ii].z * gq[ii].z;
|
1009 |
+
sum += v[ii].w * gq[ii].w;
|
1010 |
+
}
|
1011 |
+
|
1012 |
+
// Finalize the computation of the sum (if we have more than 1 thread per head).
|
1013 |
+
if( THREADS_PER_HEAD > 1 ) {
|
1014 |
+
|
1015 |
+
// Finalize the sum for each head.
|
1016 |
+
#pragma unroll
|
1017 |
+
for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) {
|
1018 |
+
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
1019 |
+
}
|
1020 |
+
|
1021 |
+
// Store to shared memory.
|
1022 |
+
if( oi == 0 ) {
|
1023 |
+
smem_[smem_curr].out_kv[so*D + oo] = sum;
|
1024 |
+
}
|
1025 |
+
|
1026 |
+
// Make sure the data is in shared memory.
|
1027 |
+
__syncthreads();
|
1028 |
+
|
1029 |
+
// Active threads read the data to store.
|
1030 |
+
if( si < hidden_size_per_head ) {
|
1031 |
+
sum = smem_[smem_curr].out_kv[so*D + si];
|
1032 |
+
}
|
1033 |
+
|
1034 |
+
} // THREADS_PER_HEAD > 1.
|
1035 |
+
|
1036 |
+
// Store the output. All the threads are active.
|
1037 |
+
if( si < hidden_size_per_head ) {
|
1038 |
+
*ptr_out_kv = sum;
|
1039 |
+
}
|
1040 |
+
|
1041 |
+
// Move to next location.
|
1042 |
+
ptr_out_kv -= out_kv_stride_L;
|
1043 |
+
|
1044 |
+
// Move the shared memory buffer.
|
1045 |
+
smem_curr = (smem_curr + 1) % 2;
|
1046 |
+
|
1047 |
+
// Store to shared memory for Q and K.
|
1048 |
+
if( !is_last && si < D ) {
|
1049 |
+
smem_[smem_curr].qg[so*D + si] = ldg_qg;
|
1050 |
+
smem_[smem_curr].kv[so*D + si] = ldg_kv;
|
1051 |
+
}
|
1052 |
+
}
|
1053 |
+
}
|
1054 |
+
|
1055 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1056 |
+
|
1057 |
+
template< int D, int THREADS_PER_HEAD >
|
1058 |
+
int lmha_bwd_(const Lmha_bwd_params<float> ¶ms) {
|
1059 |
+
int block = D*THREADS_PER_HEAD*2;
|
1060 |
+
if( block >= 1024 || params.B > 65535 ) {
|
1061 |
+
return 1;
|
1062 |
+
}
|
1063 |
+
dim3 grid(params.H, params.B);
|
1064 |
+
lmha_bwd_kernel<D, THREADS_PER_HEAD><<<grid, block>>>(params);
|
1065 |
+
return 0;
|
1066 |
+
}
|
1067 |
+
|
1068 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1069 |
+
|
1070 |
+
int lmha_bwd(const Lmha_bwd_params<float> ¶ms) {
|
1071 |
+
int blocks = params.B * params.H;
|
1072 |
+
if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
|
1073 |
+
return 1;
|
1074 |
+
}
|
1075 |
+
|
1076 |
+
int hidden_size_per_head = max(params.E, params.M);
|
1077 |
+
int res = 1;
|
1078 |
+
if( hidden_size_per_head <= 32 ) {
|
1079 |
+
res = lmha_bwd_< 32, 1>(params);
|
1080 |
+
} else if( hidden_size_per_head <= 64 ) {
|
1081 |
+
res = lmha_bwd_< 64, 1>(params);
|
1082 |
+
} else if( hidden_size_per_head <= 128 ) {
|
1083 |
+
res = lmha_bwd_<128, 2>(params);
|
1084 |
+
} else if( hidden_size_per_head <= 256 ) {
|
1085 |
+
res = lmha_bwd_<256, 4>(params);
|
1086 |
+
}
|
1087 |
+
return res;
|
1088 |
+
}
|
1089 |
+
|
1090 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1091 |
+
|
1092 |
+
int lmha_bwd(const torch::Tensor queries,
|
1093 |
+
const torch::Tensor keys,
|
1094 |
+
const torch::Tensor values,
|
1095 |
+
const torch::Tensor grad_out,
|
1096 |
+
torch::Tensor grad_queries,
|
1097 |
+
torch::Tensor grad_keys,
|
1098 |
+
torch::Tensor grad_values) {
|
1099 |
+
|
1100 |
+
// Make sure that we are using the correct GPU device
|
1101 |
+
torch::DeviceGuard _guard(queries.device());
|
1102 |
+
|
1103 |
+
// Make sure the inner-most dimension of the tensors is packed.
|
1104 |
+
assert(queries .stride(3) == 1);
|
1105 |
+
assert(keys .stride(3) == 1);
|
1106 |
+
assert(values .stride(3) == 1);
|
1107 |
+
assert(grad_out .stride(3) == 1);
|
1108 |
+
assert(grad_queries.stride(3) == 1);
|
1109 |
+
assert(grad_keys .stride(3) == 1);
|
1110 |
+
assert(grad_values .stride(3) == 1);
|
1111 |
+
|
1112 |
+
// Extract the dimensions.
|
1113 |
+
int N = queries.size(0);
|
1114 |
+
int H = queries.size(1);
|
1115 |
+
int L = queries.size(2);
|
1116 |
+
int E = queries.size(3);
|
1117 |
+
int M = values.size (3);
|
1118 |
+
|
1119 |
+
// Gradient on Q.
|
1120 |
+
|
1121 |
+
// The structure of params.
|
1122 |
+
Lmha_params<float> params;
|
1123 |
+
set_params(params, grad_out, values, keys, grad_queries);
|
1124 |
+
|
1125 |
+
// Launch the kernel.
|
1126 |
+
int res = lmha<false>(params);
|
1127 |
+
if( res ) {
|
1128 |
+
return res;
|
1129 |
+
}
|
1130 |
+
|
1131 |
+
// Gradient on K and V together.
|
1132 |
+
|
1133 |
+
Lmha_bwd_params<float> bwd_params;
|
1134 |
+
bwd_params.out_k = grad_keys.data_ptr<float>();
|
1135 |
+
bwd_params.out_v = grad_values.data_ptr<float>();
|
1136 |
+
bwd_params.q = queries.data_ptr<float>();
|
1137 |
+
bwd_params.k = keys.data_ptr<float>();
|
1138 |
+
bwd_params.v = values.data_ptr<float>();
|
1139 |
+
bwd_params.g = grad_out.data_ptr<float>();
|
1140 |
+
|
1141 |
+
bwd_params.B = N;
|
1142 |
+
bwd_params.L = L;
|
1143 |
+
bwd_params.H = H;
|
1144 |
+
bwd_params.E = E;
|
1145 |
+
bwd_params.M = M;
|
1146 |
+
|
1147 |
+
bwd_params.q_stride_B = queries.stride(0);
|
1148 |
+
bwd_params.q_stride_H = queries.stride(1);
|
1149 |
+
bwd_params.q_stride_L = queries.stride(2);
|
1150 |
+
bwd_params.k_stride_B = keys.stride(0);
|
1151 |
+
bwd_params.k_stride_H = keys.stride(1);
|
1152 |
+
bwd_params.k_stride_L = keys.stride(2);
|
1153 |
+
bwd_params.v_stride_B = values.stride(0);
|
1154 |
+
bwd_params.v_stride_H = values.stride(1);
|
1155 |
+
bwd_params.v_stride_L = values.stride(2);
|
1156 |
+
bwd_params.g_stride_B = grad_out.stride(0);
|
1157 |
+
bwd_params.g_stride_H = grad_out.stride(1);
|
1158 |
+
bwd_params.g_stride_L = grad_out.stride(2);
|
1159 |
+
|
1160 |
+
bwd_params.out_k_stride_B = grad_keys.stride(0);
|
1161 |
+
bwd_params.out_k_stride_H = grad_keys.stride(1);
|
1162 |
+
bwd_params.out_k_stride_L = grad_keys.stride(2);
|
1163 |
+
bwd_params.out_v_stride_B = grad_values.stride(0);
|
1164 |
+
bwd_params.out_v_stride_H = grad_values.stride(1);
|
1165 |
+
bwd_params.out_v_stride_L = grad_values.stride(2);
|
1166 |
+
|
1167 |
+
// Try to run the fused kernel.
|
1168 |
+
int fallback = lmha_bwd(bwd_params);
|
1169 |
+
|
1170 |
+
// If it failed, fallback on separate kernels for K and V.
|
1171 |
+
if( fallback ) {
|
1172 |
+
|
1173 |
+
// Gradient on K.
|
1174 |
+
|
1175 |
+
// Launch the kernel.
|
1176 |
+
set_params(params, values, grad_out, queries, grad_keys);
|
1177 |
+
res = lmha<true>(params);
|
1178 |
+
if( res ) {
|
1179 |
+
return res;
|
1180 |
+
}
|
1181 |
+
|
1182 |
+
// Gradient on V.
|
1183 |
+
|
1184 |
+
// Launch the kernel.
|
1185 |
+
set_params(params, keys, queries, grad_out, grad_values);
|
1186 |
+
return lmha<true>(params);
|
1187 |
+
}
|
1188 |
+
|
1189 |
+
// It worked...
|
1190 |
+
return 0;
|
1191 |
+
}
|
1192 |
+
|
1193 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1194 |
+
|
1195 |
+
} // namespace nvidia
|
1196 |
+
#endif // #ifdef ENABLE_NVIDIA_OPTIMIZATIONS
|
1197 |
+
|
1198 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1199 |
+
|
1200 |
+
typedef torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> float_accessor;
|
1201 |
+
|
1202 |
+
#define E_BLOCK_SIZE 8
|
1203 |
+
|
1204 |
+
__global__ void causal_dot_product_kernel(
|
1205 |
+
const float_accessor queries,
|
1206 |
+
const float_accessor keys,
|
1207 |
+
const float_accessor values,
|
1208 |
+
float_accessor result,
|
1209 |
+
const int N,
|
1210 |
+
const int H,
|
1211 |
+
const int L,
|
1212 |
+
const int E,
|
1213 |
+
const int M
|
1214 |
+
) {
|
1215 |
+
int n = blockIdx.y;
|
1216 |
+
int h = blockIdx.z;
|
1217 |
+
|
1218 |
+
int e_start = blockIdx.x * E_BLOCK_SIZE;
|
1219 |
+
int m = threadIdx.x % M;
|
1220 |
+
|
1221 |
+
extern __shared__ float shared_mem[];
|
1222 |
+
float* shared_kv = shared_mem;
|
1223 |
+
|
1224 |
+
for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
|
1225 |
+
shared_kv[m + e_local * M] = 0;
|
1226 |
+
}
|
1227 |
+
|
1228 |
+
for (int t=0; t<L; t++) {
|
1229 |
+
float res = 0;
|
1230 |
+
for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
|
1231 |
+
shared_kv[e_local*M + m] += keys[n][h][t][e_local + e_start] * values[n][h][t][m];
|
1232 |
+
res += queries[n][h][t][e_local + e_start] * shared_kv[e_local*M + m];
|
1233 |
+
}
|
1234 |
+
atomicAdd(
|
1235 |
+
&result[n][h][t][m],
|
1236 |
+
res
|
1237 |
+
);
|
1238 |
+
}
|
1239 |
+
}
|
1240 |
+
|
1241 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1242 |
+
|
1243 |
+
void causal_dot_product_(const torch::Tensor queries,
|
1244 |
+
const torch::Tensor keys,
|
1245 |
+
const torch::Tensor values,
|
1246 |
+
torch::Tensor product) {
|
1247 |
+
// Make sure that we are using the correct GPU device
|
1248 |
+
torch::DeviceGuard _guard(queries.device());
|
1249 |
+
|
1250 |
+
int N = queries.size(0);
|
1251 |
+
int H = queries.size(1);
|
1252 |
+
int L = queries.size(2);
|
1253 |
+
int E = queries.size(3);
|
1254 |
+
int M = values.size(3);
|
1255 |
+
|
1256 |
+
const int blocks_per_sequence = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE;
|
1257 |
+
|
1258 |
+
dim3 blockDim(M, 1, 1);
|
1259 |
+
dim3 gridDim(blocks_per_sequence, N, H);
|
1260 |
+
const int shared_mem_forward = E_BLOCK_SIZE * M * sizeof(float);
|
1261 |
+
|
1262 |
+
causal_dot_product_kernel<<<gridDim, blockDim, shared_mem_forward>>>(
|
1263 |
+
queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1264 |
+
keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1265 |
+
values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1266 |
+
product.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1267 |
+
N, H, L, E, M
|
1268 |
+
);
|
1269 |
+
}
|
1270 |
+
|
1271 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1272 |
+
|
1273 |
+
void causal_dot_product(const torch::Tensor queries,
|
1274 |
+
const torch::Tensor keys,
|
1275 |
+
const torch::Tensor values,
|
1276 |
+
torch::Tensor product) {
|
1277 |
+
#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
|
1278 |
+
int fallback = nvidia::lmha_fwd(queries, keys, values, product);
|
1279 |
+
#else
|
1280 |
+
int fallback = 1;
|
1281 |
+
#endif
|
1282 |
+
if( fallback ) {
|
1283 |
+
causal_dot_product_(queries, keys, values, product);
|
1284 |
+
}
|
1285 |
+
}
|
1286 |
+
|
1287 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1288 |
+
|
1289 |
+
#define M_BLOCK_SIZE 4
|
1290 |
+
|
1291 |
+
// we need shared memory to store
|
1292 |
+
// kv
|
1293 |
+
// Backward direction
|
1294 |
+
// kv_backwards
|
1295 |
+
// Shared memory usage
|
1296 |
+
__global__ void causal_dot_backward_query_key_kernel(
|
1297 |
+
const float_accessor queries,
|
1298 |
+
const float_accessor keys,
|
1299 |
+
const float_accessor values,
|
1300 |
+
const float_accessor grad_out,
|
1301 |
+
float_accessor grad_queries,
|
1302 |
+
float_accessor grad_keys,
|
1303 |
+
int N,
|
1304 |
+
int H,
|
1305 |
+
int L,
|
1306 |
+
int E,
|
1307 |
+
int M
|
1308 |
+
) {
|
1309 |
+
int n = blockIdx.y;
|
1310 |
+
int h = blockIdx.z;
|
1311 |
+
|
1312 |
+
int m_start = blockIdx.x * M_BLOCK_SIZE;
|
1313 |
+
int e = threadIdx.x % E;
|
1314 |
+
|
1315 |
+
extern __shared__ float shared_mem[];
|
1316 |
+
const int shared_kv_size = M_BLOCK_SIZE * E;
|
1317 |
+
float* shared_kv = shared_mem;
|
1318 |
+
float* shared_kv_bw = shared_mem + shared_kv_size;
|
1319 |
+
|
1320 |
+
for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) {
|
1321 |
+
shared_kv[m_local * E + e] = 0;
|
1322 |
+
shared_kv_bw[m_local * E + e] = 0;
|
1323 |
+
}
|
1324 |
+
|
1325 |
+
for (int l=0; l<L; l++) {
|
1326 |
+
float res = 0, res_bw = 0;
|
1327 |
+
int l_b = L - l - 1;
|
1328 |
+
for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) {
|
1329 |
+
shared_kv[m_local*E + e] += keys[n][h][l][e] * values[n][h][l][m_start + m_local];
|
1330 |
+
shared_kv_bw[m_local*E + e] += queries[n][h][l_b][e] * grad_out[n][h][l_b][m_start + m_local];
|
1331 |
+
res += grad_out[n][h][l][m_start + m_local] * shared_kv[m_local*E + e];
|
1332 |
+
res_bw += values[n][h][l_b][m_start + m_local] * shared_kv_bw[m_local*E + e];
|
1333 |
+
}
|
1334 |
+
atomicAdd(
|
1335 |
+
&grad_queries[n][h][l][e],
|
1336 |
+
res
|
1337 |
+
);
|
1338 |
+
atomicAdd(
|
1339 |
+
&grad_keys[n][h][l_b][e],
|
1340 |
+
res_bw
|
1341 |
+
);
|
1342 |
+
}
|
1343 |
+
}
|
1344 |
+
|
1345 |
+
|
1346 |
+
__global__ void causal_dot_backward_value_kernel(
|
1347 |
+
const float_accessor queries,
|
1348 |
+
const float_accessor keys,
|
1349 |
+
const float_accessor values,
|
1350 |
+
const float_accessor grad_out,
|
1351 |
+
float_accessor grad_keys,
|
1352 |
+
float_accessor grad_values,
|
1353 |
+
int N,
|
1354 |
+
int H,
|
1355 |
+
int L,
|
1356 |
+
int E,
|
1357 |
+
int M
|
1358 |
+
) {
|
1359 |
+
int n = blockIdx.y;
|
1360 |
+
int h = blockIdx.z;
|
1361 |
+
|
1362 |
+
int e_start = blockIdx.x * E_BLOCK_SIZE;
|
1363 |
+
int m = threadIdx.x % M;
|
1364 |
+
|
1365 |
+
extern __shared__ float shared_mem[];
|
1366 |
+
float* shared_kv = shared_mem;
|
1367 |
+
for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
|
1368 |
+
shared_kv[m + e_local * M] = 0;
|
1369 |
+
}
|
1370 |
+
|
1371 |
+
for (int l = 0; l < L; l++) {
|
1372 |
+
int l_b = L - l -1;
|
1373 |
+
float res = 0;
|
1374 |
+
for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
|
1375 |
+
shared_kv[e_local*M + m] += queries[n][h][l_b][e_start + e_local] * grad_out[n][h][l_b][m];
|
1376 |
+
res += keys[n][h][l_b][e_start + e_local] * shared_kv[e_local*M + m];
|
1377 |
+
}
|
1378 |
+
atomicAdd(
|
1379 |
+
&grad_values[n][h][l_b][m],
|
1380 |
+
res
|
1381 |
+
);
|
1382 |
+
}
|
1383 |
+
}
|
1384 |
+
|
1385 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1386 |
+
|
1387 |
+
void causal_dot_backward_(const torch::Tensor queries,
|
1388 |
+
const torch::Tensor keys,
|
1389 |
+
const torch::Tensor values,
|
1390 |
+
const torch::Tensor grad_out,
|
1391 |
+
torch::Tensor grad_queries,
|
1392 |
+
torch::Tensor grad_keys,
|
1393 |
+
torch::Tensor grad_values) {
|
1394 |
+
|
1395 |
+
// Make sure that we are using the correct GPU device
|
1396 |
+
torch::DeviceGuard _guard(queries.device());
|
1397 |
+
|
1398 |
+
int N = queries.size(0);
|
1399 |
+
int H = queries.size(1);
|
1400 |
+
int L = queries.size(2);
|
1401 |
+
int E = queries.size(3);
|
1402 |
+
int M = values.size(3);
|
1403 |
+
|
1404 |
+
const int blocks_per_sequence = (M + M_BLOCK_SIZE - 1) / M_BLOCK_SIZE;
|
1405 |
+
|
1406 |
+
dim3 blockDim(E, 1, 1);
|
1407 |
+
dim3 gridDim(blocks_per_sequence, N, H);
|
1408 |
+
const int shared_mem_qk_backward = 2 * M_BLOCK_SIZE * E * sizeof(float);
|
1409 |
+
|
1410 |
+
causal_dot_backward_query_key_kernel<<<gridDim, blockDim, shared_mem_qk_backward>>>(
|
1411 |
+
queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1412 |
+
keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1413 |
+
values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1414 |
+
grad_out.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1415 |
+
grad_queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1416 |
+
grad_keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1417 |
+
N, H, L, E, M
|
1418 |
+
);
|
1419 |
+
|
1420 |
+
const int blocks_per_sequence_value = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE;
|
1421 |
+
|
1422 |
+
dim3 blockDimv(M, 1, 1);
|
1423 |
+
dim3 gridDimv(blocks_per_sequence_value, N, H);
|
1424 |
+
const int shared_mem_v_backward = E_BLOCK_SIZE * M * sizeof(float);
|
1425 |
+
causal_dot_backward_value_kernel<<<gridDimv, blockDimv, shared_mem_v_backward>>>(
|
1426 |
+
queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1427 |
+
keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1428 |
+
values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1429 |
+
grad_out.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1430 |
+
grad_keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1431 |
+
grad_values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1432 |
+
N, H, L, E, M
|
1433 |
+
);
|
1434 |
+
}
|
1435 |
+
|
1436 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1437 |
+
|
1438 |
+
void causal_dot_backward(const torch::Tensor queries,
|
1439 |
+
const torch::Tensor keys,
|
1440 |
+
const torch::Tensor values,
|
1441 |
+
const torch::Tensor grad_out,
|
1442 |
+
torch::Tensor grad_queries,
|
1443 |
+
torch::Tensor grad_keys,
|
1444 |
+
torch::Tensor grad_values) {
|
1445 |
+
#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
|
1446 |
+
int fallback = nvidia::lmha_bwd(queries,
|
1447 |
+
keys,
|
1448 |
+
values,
|
1449 |
+
grad_out,
|
1450 |
+
grad_queries,
|
1451 |
+
grad_keys,
|
1452 |
+
grad_values);
|
1453 |
+
#else
|
1454 |
+
int fallback = 1;
|
1455 |
+
#endif
|
1456 |
+
if( fallback ) {
|
1457 |
+
// Make sure that the gradient tensors are 0. This is needed because the
|
1458 |
+
// bwd pass might have partially executed and filled in some values in
|
1459 |
+
// grad_queries or grad_keys.
|
1460 |
+
//
|
1461 |
+
// This adds a small overhead every time we have to fall back to the old
|
1462 |
+
// kernel for the backward pass.
|
1463 |
+
grad_queries.zero_();
|
1464 |
+
grad_keys.zero_();
|
1465 |
+
causal_dot_backward_(queries, keys, values, grad_out, grad_queries, grad_keys, grad_values);
|
1466 |
+
}
|
1467 |
+
}
|
1468 |
+
|
1469 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1470 |
+
|
1471 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
1472 |
+
m.def(
|
1473 |
+
"causal_dot_product",
|
1474 |
+
&causal_dot_product,
|
1475 |
+
"Compute the weighted sum of values but attending only to previous "
|
1476 |
+
"values."
|
1477 |
+
);
|
1478 |
+
m.def(
|
1479 |
+
"causal_dot_backward",
|
1480 |
+
&causal_dot_backward,
|
1481 |
+
"Compute the gradients for the causal dot product."
|
1482 |
+
);
|
1483 |
+
}
|
csrc/causal_attention_kv_cuda.cu
ADDED
@@ -0,0 +1,1483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
//
|
2 |
+
// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
3 |
+
// Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
4 |
+
// Apoorv Vyas <avyas@idiap.ch>
|
5 |
+
//
|
6 |
+
|
7 |
+
//
|
8 |
+
// For modifications made inside namespace nvidia (authored by jdemouth):
|
9 |
+
//
|
10 |
+
// Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
|
11 |
+
//
|
12 |
+
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
13 |
+
// this software and associated documentation files (the "Software"), to deal in
|
14 |
+
// the Software without restriction, including without limitation the rights to
|
15 |
+
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
16 |
+
// the Software, and to permit persons to whom the Software is furnished to do so,
|
17 |
+
// subject to the following conditions:
|
18 |
+
//
|
19 |
+
// The above copyright notice and this permission notice shall be included in all
|
20 |
+
// copies or substantial portions of the Software.
|
21 |
+
//
|
22 |
+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
23 |
+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
24 |
+
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
25 |
+
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
26 |
+
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
27 |
+
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
28 |
+
//
|
29 |
+
|
30 |
+
#include <torch/extension.h>
|
31 |
+
#include <assert.h>
|
32 |
+
#include <stdio.h>
|
33 |
+
|
34 |
+
#define ENABLE_NVIDIA_OPTIMIZATIONS
|
35 |
+
|
36 |
+
#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
|
37 |
+
namespace nvidia {
|
38 |
+
|
39 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
40 |
+
|
41 |
+
constexpr int THREADS_PER_WARP = 32;
|
42 |
+
|
43 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
44 |
+
|
45 |
+
constexpr int LOW_OCCUPANCY_THRESHOLD = 40; // TODO: Make it HW specific (like 1/2 SMs).
|
46 |
+
|
47 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
48 |
+
|
49 |
+
static inline __device__ __host__ int div_up(int m, int n) {
|
50 |
+
return (m + n-1) / n;
|
51 |
+
}
|
52 |
+
|
53 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
54 |
+
|
55 |
+
static inline __device__ __host__ int round_up(int m, int n) {
|
56 |
+
return div_up(m, n) * n;
|
57 |
+
}
|
58 |
+
|
59 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
60 |
+
|
61 |
+
template< typename T >
|
62 |
+
struct Lmha_params {
|
63 |
+
|
64 |
+
// The output buffer. Dimensions [B, H, L, M].
|
65 |
+
T *out;
|
66 |
+
|
67 |
+
// The input Qs. Dimensions [B, H, L, E].
|
68 |
+
const T *q;
|
69 |
+
// The input Ks. Dimensions [B, H, L, E].
|
70 |
+
const T *k;
|
71 |
+
// The input Vs. Dimensions [B, H, L, M].
|
72 |
+
const T *v;
|
73 |
+
|
74 |
+
// The different dimensions.
|
75 |
+
int B, L, H, E, M;
|
76 |
+
|
77 |
+
// The strides for the different tensors.
|
78 |
+
int q_stride_B, q_stride_H, q_stride_L;
|
79 |
+
int k_stride_B, k_stride_H, k_stride_L;
|
80 |
+
int v_stride_B, v_stride_H, v_stride_L;
|
81 |
+
int o_stride_B, o_stride_H, o_stride_L;
|
82 |
+
};
|
83 |
+
|
84 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
85 |
+
|
86 |
+
template< int E, bool GO_BACKWARD, int WARPS, int COLS_PER_THREAD = 4 >
|
87 |
+
__global__ __launch_bounds__(WARPS * THREADS_PER_WARP)
|
88 |
+
void lmha_low_occupancy_kernel(Lmha_params<float> params) {
|
89 |
+
|
90 |
+
// The number of threads per block.
|
91 |
+
constexpr int THREADS_PER_BLOCK = WARPS * THREADS_PER_WARP;
|
92 |
+
// The number of rows per thread.
|
93 |
+
constexpr int ROWS_PER_THREAD = E / THREADS_PER_WARP;
|
94 |
+
// The number of steps per iteration.
|
95 |
+
constexpr int COLS_PER_ITER = WARPS * COLS_PER_THREAD;
|
96 |
+
|
97 |
+
// Make sure E is a multiple of the warp size.
|
98 |
+
static_assert(E % THREADS_PER_WARP == 0, "");
|
99 |
+
|
100 |
+
// Shared memory to store V/O.
|
101 |
+
__shared__ float smem_v[COLS_PER_ITER], smem_o[COLS_PER_ITER];
|
102 |
+
// Shared memory buffer to performance the reductions.
|
103 |
+
__shared__ float smem_reds[E * WARPS];
|
104 |
+
|
105 |
+
// The sequence processed by that block.
|
106 |
+
const int bi = blockIdx.z;
|
107 |
+
// The head processed by that block.
|
108 |
+
const int hi = blockIdx.y;
|
109 |
+
// The hidden cell in the V/output buffers.
|
110 |
+
const int vi = blockIdx.x;
|
111 |
+
|
112 |
+
// The linear index of the thread.
|
113 |
+
const int tidx = threadIdx.x;
|
114 |
+
|
115 |
+
// Decompose the block in warp/lane.
|
116 |
+
const int warp = tidx / THREADS_PER_WARP;
|
117 |
+
const int lane = tidx % THREADS_PER_WARP;
|
118 |
+
|
119 |
+
// The base offset loaded by the thread in Q and K.
|
120 |
+
int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + lane;
|
121 |
+
int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + lane;
|
122 |
+
|
123 |
+
// If we walk backward, account for the extra offset.
|
124 |
+
if( GO_BACKWARD ) {
|
125 |
+
offset_q += (params.L-1)*params.q_stride_L;
|
126 |
+
offset_k += (params.L-1)*params.k_stride_L;
|
127 |
+
}
|
128 |
+
|
129 |
+
// Position the warp at the beginning of the proper timestep.
|
130 |
+
if( GO_BACKWARD ) {
|
131 |
+
offset_q -= warp*COLS_PER_THREAD*params.q_stride_L;
|
132 |
+
offset_k -= warp*COLS_PER_THREAD*params.k_stride_L;
|
133 |
+
} else {
|
134 |
+
offset_q += warp*COLS_PER_THREAD*params.q_stride_L;
|
135 |
+
offset_k += warp*COLS_PER_THREAD*params.k_stride_L;
|
136 |
+
}
|
137 |
+
|
138 |
+
// Determine the base pointers for Q and K.
|
139 |
+
const float *ptr_q = ¶ms.q[offset_q];
|
140 |
+
const float *ptr_k = ¶ms.k[offset_k];
|
141 |
+
|
142 |
+
// Is a given row valid?
|
143 |
+
int valid_qk[ROWS_PER_THREAD];
|
144 |
+
#pragma unroll
|
145 |
+
for( int ii = 0; ii < ROWS_PER_THREAD; ++ii ) {
|
146 |
+
valid_qk[ii] = lane + ii*THREADS_PER_WARP < params.E;
|
147 |
+
}
|
148 |
+
|
149 |
+
// The offset to the position loaded by the thread in V.
|
150 |
+
int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + vi;
|
151 |
+
int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + vi;
|
152 |
+
|
153 |
+
// If we walk backward, account for the extra offset.
|
154 |
+
if( GO_BACKWARD ) {
|
155 |
+
offset_v += (params.L-1)*params.v_stride_L;
|
156 |
+
offset_o += (params.L-1)*params.o_stride_L;
|
157 |
+
}
|
158 |
+
|
159 |
+
// We load/store a strided matrix of COLS_PER_ITER x OUTPUTS_PER_BLOCK.
|
160 |
+
if( GO_BACKWARD ) {
|
161 |
+
offset_v -= tidx*params.v_stride_L;
|
162 |
+
offset_o -= tidx*params.o_stride_L;
|
163 |
+
} else {
|
164 |
+
offset_v += tidx*params.v_stride_L;
|
165 |
+
offset_o += tidx*params.o_stride_L;
|
166 |
+
}
|
167 |
+
|
168 |
+
// Determine the base pointer for V.
|
169 |
+
const float *ptr_v = ¶ms.v[offset_v];
|
170 |
+
// The output pointer.
|
171 |
+
float *ptr_o = ¶ms.out[offset_o];
|
172 |
+
|
173 |
+
// The running KVs.
|
174 |
+
float running_kv[ROWS_PER_THREAD];
|
175 |
+
#pragma unroll
|
176 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
177 |
+
running_kv[ri] = 0.f;
|
178 |
+
}
|
179 |
+
|
180 |
+
// Iterate over the timesteps. TODO: Use params.loop_count!!!
|
181 |
+
for( int iter = 0; iter < params.L; iter += COLS_PER_ITER ) {
|
182 |
+
|
183 |
+
// Each thread loads a matrix of elements.
|
184 |
+
float q[ROWS_PER_THREAD][COLS_PER_THREAD], k[ROWS_PER_THREAD][COLS_PER_THREAD];
|
185 |
+
|
186 |
+
// Trigger the memory loads for Q and K.
|
187 |
+
#pragma unroll
|
188 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
189 |
+
#pragma unroll
|
190 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
191 |
+
|
192 |
+
// For Q/K, each warp loads from various timesteps.
|
193 |
+
int ti = iter + warp*COLS_PER_THREAD;
|
194 |
+
if( GO_BACKWARD ) {
|
195 |
+
ti = params.L - 1 - ti;
|
196 |
+
}
|
197 |
+
|
198 |
+
// Is it a valid access?
|
199 |
+
int valid;
|
200 |
+
if( GO_BACKWARD ) {
|
201 |
+
valid = valid_qk[ri] && ti - ci >= 0;
|
202 |
+
} else {
|
203 |
+
valid = valid_qk[ri] && ti + ci < params.L;
|
204 |
+
}
|
205 |
+
|
206 |
+
// The extra offset to add.
|
207 |
+
if( GO_BACKWARD ) {
|
208 |
+
offset_q = ri*THREADS_PER_WARP - ci*params.q_stride_L;
|
209 |
+
offset_k = ri*THREADS_PER_WARP - ci*params.k_stride_L;
|
210 |
+
} else {
|
211 |
+
offset_q = ri*THREADS_PER_WARP + ci*params.q_stride_L;
|
212 |
+
offset_k = ri*THREADS_PER_WARP + ci*params.k_stride_L;
|
213 |
+
}
|
214 |
+
|
215 |
+
// Load Q/K if they are valid.
|
216 |
+
q[ri][ci] = valid ? ptr_q[offset_q] : 0.f;
|
217 |
+
k[ri][ci] = valid ? ptr_k[offset_k] : 0.f;
|
218 |
+
}
|
219 |
+
}
|
220 |
+
|
221 |
+
// For the V tensor, we assign contiguous thread to different loads. So, ti is different.
|
222 |
+
int ti = iter + tidx;
|
223 |
+
if( GO_BACKWARD ) {
|
224 |
+
ti = params.L - 1 - ti;
|
225 |
+
}
|
226 |
+
|
227 |
+
// Is it a valid access?
|
228 |
+
int valid_vo = tidx < COLS_PER_ITER;
|
229 |
+
if( GO_BACKWARD ) {
|
230 |
+
valid_vo &= ti >= 0;
|
231 |
+
} else {
|
232 |
+
valid_vo &= ti < params.L;
|
233 |
+
}
|
234 |
+
|
235 |
+
// Trigger the loads for V.
|
236 |
+
float ldg_v = valid_vo ? *ptr_v : 0.f;
|
237 |
+
|
238 |
+
// Move the load pointers.
|
239 |
+
if( GO_BACKWARD ) {
|
240 |
+
ptr_q -= COLS_PER_ITER*params.q_stride_L;
|
241 |
+
ptr_k -= COLS_PER_ITER*params.k_stride_L;
|
242 |
+
ptr_v -= COLS_PER_ITER*params.v_stride_L;
|
243 |
+
} else {
|
244 |
+
ptr_q += COLS_PER_ITER*params.q_stride_L;
|
245 |
+
ptr_k += COLS_PER_ITER*params.k_stride_L;
|
246 |
+
ptr_v += COLS_PER_ITER*params.v_stride_L;
|
247 |
+
}
|
248 |
+
|
249 |
+
// Store to shared memory.
|
250 |
+
if( tidx < COLS_PER_ITER ) {
|
251 |
+
smem_v[tidx] = ldg_v;
|
252 |
+
}
|
253 |
+
|
254 |
+
// Make sure V is in shared memory.
|
255 |
+
__syncthreads();
|
256 |
+
|
257 |
+
// Read V from shared memory.
|
258 |
+
float v[COLS_PER_THREAD];
|
259 |
+
#pragma unroll
|
260 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
261 |
+
v[ci] = smem_v[warp*COLS_PER_THREAD + ci];
|
262 |
+
}
|
263 |
+
|
264 |
+
// Each thread computes local K*V products.
|
265 |
+
float kv[ROWS_PER_THREAD][COLS_PER_THREAD];
|
266 |
+
#pragma unroll
|
267 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
268 |
+
#pragma unroll
|
269 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
270 |
+
kv[ri][ci] = 0.f;
|
271 |
+
}
|
272 |
+
}
|
273 |
+
|
274 |
+
// Update the K*V^T product.
|
275 |
+
#pragma unroll
|
276 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
277 |
+
#pragma unroll
|
278 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
279 |
+
kv[ri][ci] += k[ri][ci] * v[ci];
|
280 |
+
}
|
281 |
+
}
|
282 |
+
|
283 |
+
// We must perform the prefix sums within the thread-block. Start with the thread.
|
284 |
+
#pragma unroll
|
285 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
286 |
+
#pragma unroll
|
287 |
+
for( int ci = 1; ci < COLS_PER_THREAD; ++ci ) {
|
288 |
+
kv[ri][ci] += kv[ri][ci-1];
|
289 |
+
}
|
290 |
+
}
|
291 |
+
|
292 |
+
// Store the partial sums to shared memory. Unless we have no inter-warp reduction to perform.
|
293 |
+
#pragma unroll
|
294 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
295 |
+
smem_reds[warp*E + ri*THREADS_PER_WARP + lane] = kv[ri][COLS_PER_THREAD-1];
|
296 |
+
}
|
297 |
+
|
298 |
+
// Make sure the data is in shared memory.
|
299 |
+
__syncthreads();
|
300 |
+
|
301 |
+
// Each thread deals with one or more column(s) of the matrix.
|
302 |
+
constexpr int SUMS_PER_THREAD = (E + THREADS_PER_BLOCK-1) / THREADS_PER_BLOCK;
|
303 |
+
#pragma unroll
|
304 |
+
for( int ii = 0, idx = tidx; ii < SUMS_PER_THREAD; ++ii, idx += THREADS_PER_BLOCK ) {
|
305 |
+
if( idx < E ) {
|
306 |
+
float sum = smem_reds[idx];
|
307 |
+
#pragma unroll
|
308 |
+
for( int jj = 1; jj < WARPS; ++jj ) {
|
309 |
+
smem_reds[idx + jj*E] = sum += smem_reds[idx + jj*E];
|
310 |
+
}
|
311 |
+
}
|
312 |
+
}
|
313 |
+
|
314 |
+
// Make sure the reductions are stored in shared memory.
|
315 |
+
__syncthreads();
|
316 |
+
|
317 |
+
// Each thread updates his partial products.
|
318 |
+
#pragma unroll
|
319 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
320 |
+
float sum = running_kv[ri];
|
321 |
+
if( warp > 0 ) {
|
322 |
+
sum += smem_reds[(warp-1)*E + lane + ri*THREADS_PER_WARP];
|
323 |
+
}
|
324 |
+
#pragma unroll
|
325 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
326 |
+
kv[ri][ci] += sum;
|
327 |
+
}
|
328 |
+
}
|
329 |
+
|
330 |
+
// Compute the partial output values for that thread.
|
331 |
+
float sum[COLS_PER_THREAD];
|
332 |
+
#pragma unroll
|
333 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
334 |
+
sum[ci] = q[0][ci] * kv[0][ci];
|
335 |
+
#pragma unroll
|
336 |
+
for( int ri = 1; ri < ROWS_PER_THREAD; ++ri ) {
|
337 |
+
sum[ci] += q[ri][ci] * kv[ri][ci];
|
338 |
+
}
|
339 |
+
}
|
340 |
+
|
341 |
+
// Run the parallel reductions inside the warp.
|
342 |
+
#pragma unroll
|
343 |
+
for( int mask = THREADS_PER_WARP / 2; mask >= 1; mask /= 2 ) {
|
344 |
+
#pragma unroll
|
345 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
346 |
+
sum[ci] += __shfl_xor_sync(uint32_t(-1), sum[ci], mask);
|
347 |
+
}
|
348 |
+
}
|
349 |
+
|
350 |
+
// Store the final output to shared memory.
|
351 |
+
if( lane == 0 ) {
|
352 |
+
#pragma unroll
|
353 |
+
for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
|
354 |
+
smem_o[warp*COLS_PER_THREAD + ci] = sum[ci];
|
355 |
+
}
|
356 |
+
}
|
357 |
+
|
358 |
+
// Make sure the data is in shared memory.
|
359 |
+
__syncthreads();
|
360 |
+
|
361 |
+
// Store the output.
|
362 |
+
if( valid_vo ) {
|
363 |
+
*ptr_o = smem_o[tidx];
|
364 |
+
}
|
365 |
+
|
366 |
+
// Each thread updates his running kv.
|
367 |
+
#pragma unroll
|
368 |
+
for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
|
369 |
+
running_kv[ri] += smem_reds[(WARPS-1)*E + lane + ri*THREADS_PER_WARP];
|
370 |
+
}
|
371 |
+
|
372 |
+
// Move to next location.
|
373 |
+
if( GO_BACKWARD ) {
|
374 |
+
ptr_o -= COLS_PER_ITER*params.o_stride_L;
|
375 |
+
} else {
|
376 |
+
ptr_o += COLS_PER_ITER*params.o_stride_L;
|
377 |
+
}
|
378 |
+
}
|
379 |
+
}
|
380 |
+
|
381 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
382 |
+
|
383 |
+
template< int E, bool GO_BACKWARD, int WARPS >
|
384 |
+
int lmha_low_occupancy_(const Lmha_params<float> ¶ms) {
|
385 |
+
|
386 |
+
// Make sure we are not going to launch an invalid grid.
|
387 |
+
if( params.H > 65535 || params.B > 65535 ) {
|
388 |
+
return 1;
|
389 |
+
}
|
390 |
+
|
391 |
+
// Prepare the grid and trigger the CUDA kernel.
|
392 |
+
dim3 grid;
|
393 |
+
grid.x = params.M;
|
394 |
+
grid.y = params.H;
|
395 |
+
grid.z = params.B;
|
396 |
+
lmha_low_occupancy_kernel<E, GO_BACKWARD, WARPS><<<grid, WARPS*THREADS_PER_WARP>>>(params);
|
397 |
+
return 0;
|
398 |
+
}
|
399 |
+
|
400 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
401 |
+
|
402 |
+
template< int E, bool GO_BACKWARD >
|
403 |
+
int lmha_low_occupancy_(const Lmha_params<float> ¶ms, int blocks) {
|
404 |
+
if( params.M * blocks >= 8*LOW_OCCUPANCY_THRESHOLD ) {
|
405 |
+
return lmha_low_occupancy_<E, GO_BACKWARD, 4>(params);
|
406 |
+
} else if( params.M * blocks >= 4*LOW_OCCUPANCY_THRESHOLD ) {
|
407 |
+
return lmha_low_occupancy_<E, GO_BACKWARD, 8>(params);
|
408 |
+
} else {
|
409 |
+
return lmha_low_occupancy_<E, GO_BACKWARD, 16>(params);
|
410 |
+
}
|
411 |
+
return 1;
|
412 |
+
}
|
413 |
+
|
414 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
415 |
+
|
416 |
+
template< int E, typename Params >
|
417 |
+
static inline __device__ __host__ int smem_buffer_elts_(const Params ¶ms) {
|
418 |
+
int M = round_up(params.M, 4);
|
419 |
+
return 2*E + 2*M;
|
420 |
+
}
|
421 |
+
|
422 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
423 |
+
|
424 |
+
template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD >
|
425 |
+
__global__
|
426 |
+
void lmha_kernel(Lmha_params<float> params) {
|
427 |
+
|
428 |
+
// Make sure E is a multiple of 4.
|
429 |
+
static_assert(E % 4 == 0, "");
|
430 |
+
|
431 |
+
// The amount of shared memory per buffer (2 buffers for double-buffering).
|
432 |
+
const int smem_buffer_elts = smem_buffer_elts_<E>(params);
|
433 |
+
// The M dimension for shared memory.
|
434 |
+
const int M = round_up(params.M, 4);
|
435 |
+
|
436 |
+
// Shared memory to store Q, K and V. Size is 2*smem_buffer_elts.
|
437 |
+
extern __shared__ float smem_[];
|
438 |
+
|
439 |
+
// The various shared memory buffers.
|
440 |
+
float *smem_q = &smem_[0*E];
|
441 |
+
float *smem_k = &smem_[1*E];
|
442 |
+
float *smem_v = &smem_[2*E];
|
443 |
+
float *smem_o = &smem_[2*E + M];
|
444 |
+
|
445 |
+
// The index of the shared memory buffer (for double-buffering).
|
446 |
+
int smem_curr = 0;
|
447 |
+
|
448 |
+
// The sequence processed by that block.
|
449 |
+
const int bi = blockIdx.y;
|
450 |
+
// The head processed by that block.
|
451 |
+
const int hi = blockIdx.x;
|
452 |
+
|
453 |
+
// The linear index of the thread.
|
454 |
+
const int tidx = threadIdx.x;
|
455 |
+
|
456 |
+
// The offset to the position loaded by the thread in Q.
|
457 |
+
int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + tidx;
|
458 |
+
// The offset to the position loaded by the thread in K.
|
459 |
+
int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + tidx;
|
460 |
+
|
461 |
+
// If we walk backward, account for the extra offset.
|
462 |
+
if( GO_BACKWARD ) {
|
463 |
+
offset_q += (params.L-1)*params.q_stride_L;
|
464 |
+
offset_k += (params.L-1)*params.k_stride_L;
|
465 |
+
}
|
466 |
+
|
467 |
+
// Determine the base pointers for Q and K.
|
468 |
+
const float *ptr_q = ¶ms.q[offset_q];
|
469 |
+
const float *ptr_k = ¶ms.k[offset_k];
|
470 |
+
|
471 |
+
// The offset to the position loaded by the thread in V and O.
|
472 |
+
int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + tidx;
|
473 |
+
int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + tidx;
|
474 |
+
|
475 |
+
// If we walk backward, account for the extra offset.
|
476 |
+
if( GO_BACKWARD ) {
|
477 |
+
offset_v += (params.L-1)*params.v_stride_L;
|
478 |
+
offset_o += (params.L-1)*params.o_stride_L;
|
479 |
+
}
|
480 |
+
|
481 |
+
// Determine the base pointers for V.
|
482 |
+
const float *ptr_v = ¶ms.v[offset_v];
|
483 |
+
|
484 |
+
// Is it an active Q/K thread?
|
485 |
+
const int active_qk = tidx < params.E;
|
486 |
+
|
487 |
+
// Trigger the memory loads for Q and K.
|
488 |
+
float ldg_q = 0.f, ldg_k = 0.f;
|
489 |
+
if( active_qk ) {
|
490 |
+
ldg_q = *ptr_q;
|
491 |
+
ldg_k = *ptr_k;
|
492 |
+
}
|
493 |
+
|
494 |
+
// Is it an active V thread?
|
495 |
+
const int active_v = tidx < params.M;
|
496 |
+
|
497 |
+
// Trigger the memory loads for V.
|
498 |
+
float ldg_v = 0.f;
|
499 |
+
if( active_v ) {
|
500 |
+
ldg_v = *ptr_v;
|
501 |
+
}
|
502 |
+
|
503 |
+
// Move the load pointers.
|
504 |
+
if( GO_BACKWARD ) {
|
505 |
+
ptr_q -= params.q_stride_L;
|
506 |
+
ptr_k -= params.k_stride_L;
|
507 |
+
ptr_v -= params.v_stride_L;
|
508 |
+
} else {
|
509 |
+
ptr_q += params.q_stride_L;
|
510 |
+
ptr_k += params.k_stride_L;
|
511 |
+
ptr_v += params.v_stride_L;
|
512 |
+
}
|
513 |
+
|
514 |
+
// The number of FLOAT4s per head.
|
515 |
+
constexpr int FLOAT4s_PER_HEAD = E / 4;
|
516 |
+
// The number of FLOAT4s per thread.
|
517 |
+
constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD;
|
518 |
+
|
519 |
+
// The storage for the K*V^T values.
|
520 |
+
float4 kv[FLOAT4s_PER_THREAD];
|
521 |
+
#pragma unroll
|
522 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
523 |
+
kv[ii] = make_float4(0.f, 0.f, 0.f, 0.f);
|
524 |
+
}
|
525 |
+
|
526 |
+
// The output pointer.
|
527 |
+
float *out_ptr = ¶ms.out[offset_o];
|
528 |
+
|
529 |
+
// Store to shared memory Q and K.
|
530 |
+
if( tidx < E ) {
|
531 |
+
smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q;
|
532 |
+
smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k;
|
533 |
+
}
|
534 |
+
|
535 |
+
// Store to shared memory V. All threads store valid values.
|
536 |
+
if( tidx < M ) {
|
537 |
+
smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v;
|
538 |
+
}
|
539 |
+
|
540 |
+
// The position of the thread in the V dimension.
|
541 |
+
int vo = tidx / THREADS_PER_HEAD;
|
542 |
+
int vi = tidx % THREADS_PER_HEAD;
|
543 |
+
|
544 |
+
// Iterate over the timesteps.
|
545 |
+
for( int ti = 0; ti < params.L; ++ti ) {
|
546 |
+
|
547 |
+
// Is it the last iteration?
|
548 |
+
int is_last = ti == params.L - 1;
|
549 |
+
|
550 |
+
// Trigger the next loads for Q and K.
|
551 |
+
if( !is_last && active_qk ) {
|
552 |
+
ldg_q = *ptr_q;
|
553 |
+
ldg_k = *ptr_k;
|
554 |
+
}
|
555 |
+
|
556 |
+
// Trigger the next loads for V.
|
557 |
+
if( !is_last && active_v ) {
|
558 |
+
ldg_v = *ptr_v;
|
559 |
+
}
|
560 |
+
|
561 |
+
// Move the load pointers.
|
562 |
+
if( GO_BACKWARD ) {
|
563 |
+
ptr_q -= params.q_stride_L;
|
564 |
+
ptr_k -= params.k_stride_L;
|
565 |
+
ptr_v -= params.v_stride_L;
|
566 |
+
} else {
|
567 |
+
ptr_q += params.q_stride_L;
|
568 |
+
ptr_k += params.k_stride_L;
|
569 |
+
ptr_v += params.v_stride_L;
|
570 |
+
}
|
571 |
+
|
572 |
+
// Make sure the data is in shared memory.
|
573 |
+
__syncthreads();
|
574 |
+
|
575 |
+
// Each thread loads 4 values from K.
|
576 |
+
float4 k[FLOAT4s_PER_THREAD];
|
577 |
+
#pragma unroll
|
578 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
579 |
+
int ki = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4;
|
580 |
+
k[ii] = *reinterpret_cast<const float4*>(&smem_k[smem_curr*smem_buffer_elts + ki]);
|
581 |
+
}
|
582 |
+
|
583 |
+
// Each thread loads a single V value.
|
584 |
+
float v = 0.f;
|
585 |
+
if( vo < params.M ) {
|
586 |
+
v = *reinterpret_cast<const float *>(&smem_v[smem_curr*smem_buffer_elts + vo]);
|
587 |
+
}
|
588 |
+
|
589 |
+
// Update the K*V^T product.
|
590 |
+
#pragma unroll
|
591 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
592 |
+
kv[ii].x += k[ii].x * v;
|
593 |
+
kv[ii].y += k[ii].y * v;
|
594 |
+
kv[ii].z += k[ii].z * v;
|
595 |
+
kv[ii].w += k[ii].w * v;
|
596 |
+
}
|
597 |
+
|
598 |
+
// Load the Q values from shared memory.
|
599 |
+
float4 q[FLOAT4s_PER_THREAD];
|
600 |
+
#pragma unroll
|
601 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
602 |
+
int qi = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4;
|
603 |
+
q[ii] = *reinterpret_cast<const float4*>(&smem_q[smem_curr*smem_buffer_elts + qi]);
|
604 |
+
}
|
605 |
+
|
606 |
+
// Compute the partial output value for that thread.
|
607 |
+
float sum = 0.f;
|
608 |
+
#pragma unroll
|
609 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
610 |
+
sum += q[ii].x * kv[ii].x;
|
611 |
+
sum += q[ii].y * kv[ii].y;
|
612 |
+
sum += q[ii].z * kv[ii].z;
|
613 |
+
sum += q[ii].w * kv[ii].w;
|
614 |
+
}
|
615 |
+
|
616 |
+
// Finalize the computation of the sum (if we have more than 1 thread per head).
|
617 |
+
if( THREADS_PER_HEAD > 1 ) {
|
618 |
+
|
619 |
+
// Finalize the sum for each head.
|
620 |
+
#pragma unroll
|
621 |
+
for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) {
|
622 |
+
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
623 |
+
}
|
624 |
+
|
625 |
+
// Store to shared memory.
|
626 |
+
if( vo < M && vi == 0 ) {
|
627 |
+
smem_o[smem_curr*smem_buffer_elts + vo] = sum;
|
628 |
+
}
|
629 |
+
|
630 |
+
// Make sure the data is in shared memory.
|
631 |
+
__syncthreads();
|
632 |
+
|
633 |
+
// Active threads read the data to store.
|
634 |
+
if( active_v ) {
|
635 |
+
sum = smem_o[smem_curr*smem_buffer_elts + tidx];
|
636 |
+
}
|
637 |
+
|
638 |
+
} // THREADS_PER_HEAD > 1.
|
639 |
+
|
640 |
+
// Store the output. All the threads are active.
|
641 |
+
if( active_v ) {
|
642 |
+
*out_ptr = sum;
|
643 |
+
}
|
644 |
+
|
645 |
+
// Move to next location.
|
646 |
+
if( GO_BACKWARD ) {
|
647 |
+
out_ptr -= params.o_stride_L;
|
648 |
+
} else {
|
649 |
+
out_ptr += params.o_stride_L;
|
650 |
+
}
|
651 |
+
|
652 |
+
// Move the shared memory buffer.
|
653 |
+
smem_curr = (smem_curr + 1) % 2;
|
654 |
+
|
655 |
+
// Store to shared memory for Q and K.
|
656 |
+
if( !is_last && tidx < E ) {
|
657 |
+
smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q;
|
658 |
+
smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k;
|
659 |
+
}
|
660 |
+
|
661 |
+
// Store to shared memory for V.
|
662 |
+
if( !is_last && tidx < M ) {
|
663 |
+
smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v;
|
664 |
+
}
|
665 |
+
}
|
666 |
+
}
|
667 |
+
|
668 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
669 |
+
|
670 |
+
template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD >
|
671 |
+
int lmha_(const Lmha_params<float> ¶ms) {
|
672 |
+
// The M dimension rounded up to 4.
|
673 |
+
int M = round_up(params.M, 4);
|
674 |
+
|
675 |
+
// The number of threads in the block.
|
676 |
+
int block = round_up(max(E, M*THREADS_PER_HEAD), 32);
|
677 |
+
if( block > 512 || params.B > 65535 ) {
|
678 |
+
return 1;
|
679 |
+
}
|
680 |
+
|
681 |
+
// Prepare the kernel.
|
682 |
+
dim3 grid(params.H, params.B);
|
683 |
+
size_t smem = smem_buffer_elts_<E>(params)*2*sizeof(float);
|
684 |
+
lmha_kernel<E, THREADS_PER_HEAD, GO_BACKWARD><<<grid, block, smem>>>(params);
|
685 |
+
return 0;
|
686 |
+
}
|
687 |
+
|
688 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
689 |
+
|
690 |
+
template< bool GO_BACKWARD >
|
691 |
+
int lmha(const Lmha_params<float> ¶ms) {
|
692 |
+
int blocks = params.B * params.H;
|
693 |
+
int res = 1;
|
694 |
+
if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
|
695 |
+
if( params.E <= 32 ) {
|
696 |
+
res = lmha_low_occupancy_< 32, GO_BACKWARD>(params, blocks);
|
697 |
+
} else if( params.E <= 64 ) {
|
698 |
+
res = lmha_low_occupancy_< 64, GO_BACKWARD>(params, blocks);
|
699 |
+
} else if( params.E <= 128 ) {
|
700 |
+
res = lmha_low_occupancy_<128, GO_BACKWARD>(params, blocks);
|
701 |
+
} else if( params.E <= 256 ) {
|
702 |
+
res = lmha_low_occupancy_<256, GO_BACKWARD>(params, blocks);
|
703 |
+
}
|
704 |
+
} else {
|
705 |
+
if( params.E <= 32 ) {
|
706 |
+
res = lmha_< 32, 1, GO_BACKWARD>(params);
|
707 |
+
} else if( params.E <= 48 ) {
|
708 |
+
res = lmha_< 48, 1, GO_BACKWARD>(params);
|
709 |
+
} else if( params.E <= 64 ) {
|
710 |
+
res = lmha_< 64, 1, GO_BACKWARD>(params);
|
711 |
+
} else if( params.E <= 128 ) {
|
712 |
+
res = lmha_<128, 2, GO_BACKWARD>(params);
|
713 |
+
} else if( params.E <= 256 ) {
|
714 |
+
res = lmha_<256, 4, GO_BACKWARD>(params);
|
715 |
+
}
|
716 |
+
}
|
717 |
+
return res;
|
718 |
+
}
|
719 |
+
|
720 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
721 |
+
|
722 |
+
template< typename T >
|
723 |
+
inline void set_params(Lmha_params<T> ¶ms,
|
724 |
+
const torch::Tensor q,
|
725 |
+
const torch::Tensor k,
|
726 |
+
const torch::Tensor v,
|
727 |
+
torch::Tensor o) {
|
728 |
+
|
729 |
+
// Define the pointers.
|
730 |
+
params.out = o.data_ptr<T>();
|
731 |
+
params.q = q.data_ptr<T>();
|
732 |
+
params.k = k.data_ptr<T>();
|
733 |
+
params.v = v.data_ptr<T>();
|
734 |
+
|
735 |
+
// Define the strides.
|
736 |
+
params.q_stride_B = (int) q.stride(0);
|
737 |
+
params.q_stride_H = (int) q.stride(1);
|
738 |
+
params.q_stride_L = (int) q.stride(2);
|
739 |
+
params.k_stride_B = (int) k.stride(0);
|
740 |
+
params.k_stride_H = (int) k.stride(1);
|
741 |
+
params.k_stride_L = (int) k.stride(2);
|
742 |
+
params.v_stride_B = (int) v.stride(0);
|
743 |
+
params.v_stride_H = (int) v.stride(1);
|
744 |
+
params.v_stride_L = (int) v.stride(2);
|
745 |
+
params.o_stride_B = (int) o.stride(0);
|
746 |
+
params.o_stride_H = (int) o.stride(1);
|
747 |
+
params.o_stride_L = (int) o.stride(2);
|
748 |
+
|
749 |
+
// Extract the dimensions.
|
750 |
+
int N = q.size(0);
|
751 |
+
int H = q.size(1);
|
752 |
+
int L = q.size(2);
|
753 |
+
int E = q.size(3);
|
754 |
+
int M = v.size(3);
|
755 |
+
|
756 |
+
params.B = N;
|
757 |
+
params.L = L;
|
758 |
+
params.H = H;
|
759 |
+
params.E = E;
|
760 |
+
params.M = M;
|
761 |
+
}
|
762 |
+
|
763 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
764 |
+
|
765 |
+
int lmha_fwd(const torch::Tensor queries,
|
766 |
+
const torch::Tensor keys,
|
767 |
+
const torch::Tensor values,
|
768 |
+
torch::Tensor product) {
|
769 |
+
|
770 |
+
// Make sure that we are using the correct GPU device
|
771 |
+
torch::DeviceGuard _guard(queries.device());
|
772 |
+
|
773 |
+
// Make sure the inner-most dimension of the tensors is packed.
|
774 |
+
assert(queries.stride(3) == 1);
|
775 |
+
assert(keys .stride(3) == 1);
|
776 |
+
assert(values .stride(3) == 1);
|
777 |
+
assert(product.stride(3) == 1);
|
778 |
+
|
779 |
+
// Extract the dimensions.
|
780 |
+
int N = queries.size(0);
|
781 |
+
int H = queries.size(1);
|
782 |
+
int L = queries.size(2);
|
783 |
+
int E = queries.size(3);
|
784 |
+
int M = values.size (3);
|
785 |
+
|
786 |
+
// The structure of params.
|
787 |
+
Lmha_params<float> params;
|
788 |
+
set_params(params, queries, keys, values, product);
|
789 |
+
|
790 |
+
// Launch the kernel.
|
791 |
+
return lmha<false>(params);
|
792 |
+
}
|
793 |
+
|
794 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
795 |
+
|
796 |
+
template< typename T >
|
797 |
+
struct Lmha_bwd_params {
|
798 |
+
|
799 |
+
// The output buffer for K. Dimensions [B, H, L, D].
|
800 |
+
T *out_k;
|
801 |
+
// The output buffer for V. Dimensions [B, H, L, D].
|
802 |
+
T *out_v;
|
803 |
+
|
804 |
+
// The input Qs. Dimensions [B, H, L, D].
|
805 |
+
const T *q;
|
806 |
+
// The input Ks. Dimensions [B, H, L, D].
|
807 |
+
const T *k;
|
808 |
+
// The input Vs. Dimensions [B, H, L, D].
|
809 |
+
const T *v;
|
810 |
+
// The input Gs. Dimensions [B, H, L, D].
|
811 |
+
const T *g;
|
812 |
+
|
813 |
+
// The dimensions.
|
814 |
+
int B, L, H, M, E;
|
815 |
+
|
816 |
+
// The strides for the input tensors.
|
817 |
+
int q_stride_B, q_stride_L, q_stride_H;
|
818 |
+
int k_stride_B, k_stride_L, k_stride_H;
|
819 |
+
int v_stride_B, v_stride_L, v_stride_H;
|
820 |
+
int g_stride_B, g_stride_L, g_stride_H;
|
821 |
+
|
822 |
+
// The strides for the outputs.
|
823 |
+
int out_k_stride_B, out_k_stride_L, out_k_stride_H;
|
824 |
+
int out_v_stride_B, out_v_stride_L, out_v_stride_H;
|
825 |
+
};
|
826 |
+
|
827 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
828 |
+
|
829 |
+
template< int D, int THREADS_PER_HEAD >
|
830 |
+
__global__ __launch_bounds__(D*THREADS_PER_HEAD*2)
|
831 |
+
void lmha_bwd_kernel(Lmha_bwd_params<float> params) {
|
832 |
+
|
833 |
+
// Make sure D is a multiple of 4.
|
834 |
+
static_assert(D % 4 == 0, "");
|
835 |
+
|
836 |
+
// The shared memory buffers.
|
837 |
+
__shared__ struct Smem { float qg[2*D], kv[2*D], out_kv[2*D]; } smem_[2];
|
838 |
+
|
839 |
+
// The index of the shared memory buffer (for double-buffering).
|
840 |
+
int smem_curr = 0;
|
841 |
+
|
842 |
+
// The sequence processed by that block.
|
843 |
+
const int bi = blockIdx.y;
|
844 |
+
// The head processed by that block.
|
845 |
+
const int hi = blockIdx.x;
|
846 |
+
|
847 |
+
// The linear index of the thread.
|
848 |
+
const int tidx = threadIdx.x;
|
849 |
+
|
850 |
+
// Split the threads into two slices.
|
851 |
+
int so = tidx / (D*THREADS_PER_HEAD);
|
852 |
+
int si = tidx % (D*THREADS_PER_HEAD);
|
853 |
+
|
854 |
+
// The strides for B/L/H for the Q/G tensors.
|
855 |
+
int qg_stride_B, qg_stride_L, qg_stride_H;
|
856 |
+
if( so == 0 ) {
|
857 |
+
qg_stride_B = params.q_stride_B;
|
858 |
+
qg_stride_L = params.q_stride_L;
|
859 |
+
qg_stride_H = params.q_stride_H;
|
860 |
+
} else {
|
861 |
+
qg_stride_B = params.g_stride_B;
|
862 |
+
qg_stride_L = params.g_stride_L;
|
863 |
+
qg_stride_H = params.g_stride_H;
|
864 |
+
}
|
865 |
+
|
866 |
+
// The strides for B/L/H for the K/V tensors.
|
867 |
+
int kv_stride_B, kv_stride_L, kv_stride_H;
|
868 |
+
if( so == 0 ) {
|
869 |
+
kv_stride_B = params.k_stride_B;
|
870 |
+
kv_stride_L = params.k_stride_L;
|
871 |
+
kv_stride_H = params.k_stride_H;
|
872 |
+
} else {
|
873 |
+
kv_stride_B = params.v_stride_B;
|
874 |
+
kv_stride_L = params.v_stride_L;
|
875 |
+
kv_stride_H = params.v_stride_H;
|
876 |
+
}
|
877 |
+
|
878 |
+
// The hidden size.
|
879 |
+
int hidden_size_per_head = 0;
|
880 |
+
if( so == 0 ) {
|
881 |
+
hidden_size_per_head = params.E;
|
882 |
+
} else {
|
883 |
+
hidden_size_per_head = params.M;
|
884 |
+
}
|
885 |
+
|
886 |
+
// Where to start reading from.
|
887 |
+
int offset_qg = bi*qg_stride_B + hi*qg_stride_H + si;
|
888 |
+
int offset_kv = bi*kv_stride_B + hi*kv_stride_H + si;
|
889 |
+
|
890 |
+
// We walk backward, account for the extra offset.
|
891 |
+
offset_qg += (params.L-1)*qg_stride_L;
|
892 |
+
offset_kv += (params.L-1)*kv_stride_L;
|
893 |
+
|
894 |
+
// Determine the base pointers for Q, K, V and G.
|
895 |
+
const float *ptr_qg = &(so == 0 ? params.q : params.g)[offset_qg];
|
896 |
+
const float *ptr_kv = &(so == 0 ? params.k : params.v)[offset_kv];
|
897 |
+
|
898 |
+
// Is it an active thread?
|
899 |
+
const int active = si < hidden_size_per_head;
|
900 |
+
|
901 |
+
// Trigger the memory loads for Q, K, V and G.
|
902 |
+
float ldg_qg = 0.f, ldg_kv = 0.f;
|
903 |
+
if( active ) {
|
904 |
+
ldg_qg = *ptr_qg;
|
905 |
+
ldg_kv = *ptr_kv;
|
906 |
+
}
|
907 |
+
|
908 |
+
// Move the load pointers (backward).
|
909 |
+
ptr_qg -= qg_stride_L;
|
910 |
+
ptr_kv -= kv_stride_L;
|
911 |
+
|
912 |
+
// The number of FLOAT4s per head.
|
913 |
+
constexpr int FLOAT4s_PER_HEAD = D / 4;
|
914 |
+
// The number of FLOAT4s per thread.
|
915 |
+
constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD;
|
916 |
+
|
917 |
+
// The storage for the G*Q^T or Q^T*G values.
|
918 |
+
float4 gq[FLOAT4s_PER_THREAD];
|
919 |
+
#pragma unroll
|
920 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
921 |
+
gq[ii] = make_float4(0.f, 0.f, 0.f, 0.f);
|
922 |
+
}
|
923 |
+
|
924 |
+
// The strides for B/L/H for the K/V tensors.
|
925 |
+
int out_kv_stride_B, out_kv_stride_L, out_kv_stride_H;
|
926 |
+
if( so == 0 ) {
|
927 |
+
out_kv_stride_B = params.out_k_stride_B;
|
928 |
+
out_kv_stride_L = params.out_k_stride_L;
|
929 |
+
out_kv_stride_H = params.out_k_stride_H;
|
930 |
+
} else {
|
931 |
+
out_kv_stride_B = params.out_v_stride_B;
|
932 |
+
out_kv_stride_L = params.out_v_stride_L;
|
933 |
+
out_kv_stride_H = params.out_v_stride_H;
|
934 |
+
}
|
935 |
+
|
936 |
+
// Where to start reading from.
|
937 |
+
int offset_out_kv = bi*out_kv_stride_B + hi*out_kv_stride_H + si;
|
938 |
+
|
939 |
+
// We walk backward, account for the extra offset.
|
940 |
+
offset_out_kv += (params.L-1)*out_kv_stride_L;
|
941 |
+
|
942 |
+
// The output pointer.
|
943 |
+
float *ptr_out_kv = &(so == 0 ? params.out_k : params.out_v)[offset_out_kv];
|
944 |
+
|
945 |
+
// Store to shared memory.
|
946 |
+
if( si < D ) {
|
947 |
+
smem_[smem_curr].qg[so*D + si] = ldg_qg;
|
948 |
+
smem_[smem_curr].kv[so*D + si] = ldg_kv;
|
949 |
+
}
|
950 |
+
|
951 |
+
// The position of the thread in the output dimension.
|
952 |
+
int oo = si / THREADS_PER_HEAD % D;
|
953 |
+
int oi = si % THREADS_PER_HEAD * 4;
|
954 |
+
|
955 |
+
// Iterate over the timesteps.
|
956 |
+
for( int ti = 0; ti < params.L; ++ti ) {
|
957 |
+
|
958 |
+
// Is it the last iteration?
|
959 |
+
int is_last = ti == params.L - 1;
|
960 |
+
|
961 |
+
// Trigger the next loads.
|
962 |
+
if( !is_last && active ) {
|
963 |
+
ldg_qg = *ptr_qg;
|
964 |
+
ldg_kv = *ptr_kv;
|
965 |
+
}
|
966 |
+
|
967 |
+
// Move the load pointers.
|
968 |
+
ptr_qg -= qg_stride_L;
|
969 |
+
ptr_kv -= kv_stride_L;
|
970 |
+
|
971 |
+
// Make sure the data is in shared memory.
|
972 |
+
__syncthreads();
|
973 |
+
|
974 |
+
// Each thread loads 4 values from G or Q.
|
975 |
+
float4 g[FLOAT4s_PER_THREAD];
|
976 |
+
#pragma unroll
|
977 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
978 |
+
float *smem_ptr = &smem_[smem_curr].qg[(so^1)*D + oi];
|
979 |
+
g[ii] = *reinterpret_cast<const float4*>(&smem_ptr[ii*THREADS_PER_HEAD*4]);
|
980 |
+
}
|
981 |
+
|
982 |
+
// Each thread loads a single from Q or G value.
|
983 |
+
float q = smem_[smem_curr].qg[so*D + oo];
|
984 |
+
|
985 |
+
// Update the G*Q^T or Q*G^T product.
|
986 |
+
#pragma unroll
|
987 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
988 |
+
gq[ii].x += g[ii].x * q;
|
989 |
+
gq[ii].y += g[ii].y * q;
|
990 |
+
gq[ii].z += g[ii].z * q;
|
991 |
+
gq[ii].w += g[ii].w * q;
|
992 |
+
}
|
993 |
+
|
994 |
+
// Load the V or K values from shared memory.
|
995 |
+
float4 v[FLOAT4s_PER_THREAD];
|
996 |
+
#pragma unroll
|
997 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
998 |
+
float *smem_ptr = &smem_[smem_curr].kv[(so^1)*D + oi];
|
999 |
+
v[ii] = *reinterpret_cast<const float4*>(&smem_ptr[ii*THREADS_PER_HEAD*4]);
|
1000 |
+
}
|
1001 |
+
|
1002 |
+
// Compute the partial output value for that thread.
|
1003 |
+
float sum = 0.f;
|
1004 |
+
#pragma unroll
|
1005 |
+
for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
|
1006 |
+
sum += v[ii].x * gq[ii].x;
|
1007 |
+
sum += v[ii].y * gq[ii].y;
|
1008 |
+
sum += v[ii].z * gq[ii].z;
|
1009 |
+
sum += v[ii].w * gq[ii].w;
|
1010 |
+
}
|
1011 |
+
|
1012 |
+
// Finalize the computation of the sum (if we have more than 1 thread per head).
|
1013 |
+
if( THREADS_PER_HEAD > 1 ) {
|
1014 |
+
|
1015 |
+
// Finalize the sum for each head.
|
1016 |
+
#pragma unroll
|
1017 |
+
for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) {
|
1018 |
+
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
1019 |
+
}
|
1020 |
+
|
1021 |
+
// Store to shared memory.
|
1022 |
+
if( oi == 0 ) {
|
1023 |
+
smem_[smem_curr].out_kv[so*D + oo] = sum;
|
1024 |
+
}
|
1025 |
+
|
1026 |
+
// Make sure the data is in shared memory.
|
1027 |
+
__syncthreads();
|
1028 |
+
|
1029 |
+
// Active threads read the data to store.
|
1030 |
+
if( si < hidden_size_per_head ) {
|
1031 |
+
sum = smem_[smem_curr].out_kv[so*D + si];
|
1032 |
+
}
|
1033 |
+
|
1034 |
+
} // THREADS_PER_HEAD > 1.
|
1035 |
+
|
1036 |
+
// Store the output. All the threads are active.
|
1037 |
+
if( si < hidden_size_per_head ) {
|
1038 |
+
*ptr_out_kv = sum;
|
1039 |
+
}
|
1040 |
+
|
1041 |
+
// Move to next location.
|
1042 |
+
ptr_out_kv -= out_kv_stride_L;
|
1043 |
+
|
1044 |
+
// Move the shared memory buffer.
|
1045 |
+
smem_curr = (smem_curr + 1) % 2;
|
1046 |
+
|
1047 |
+
// Store to shared memory for Q and K.
|
1048 |
+
if( !is_last && si < D ) {
|
1049 |
+
smem_[smem_curr].qg[so*D + si] = ldg_qg;
|
1050 |
+
smem_[smem_curr].kv[so*D + si] = ldg_kv;
|
1051 |
+
}
|
1052 |
+
}
|
1053 |
+
}
|
1054 |
+
|
1055 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1056 |
+
|
1057 |
+
template< int D, int THREADS_PER_HEAD >
|
1058 |
+
int lmha_bwd_(const Lmha_bwd_params<float> ¶ms) {
|
1059 |
+
int block = D*THREADS_PER_HEAD*2;
|
1060 |
+
if( block >= 1024 || params.B > 65535 ) {
|
1061 |
+
return 1;
|
1062 |
+
}
|
1063 |
+
dim3 grid(params.H, params.B);
|
1064 |
+
lmha_bwd_kernel<D, THREADS_PER_HEAD><<<grid, block>>>(params);
|
1065 |
+
return 0;
|
1066 |
+
}
|
1067 |
+
|
1068 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1069 |
+
|
1070 |
+
int lmha_bwd(const Lmha_bwd_params<float> ¶ms) {
|
1071 |
+
int blocks = params.B * params.H;
|
1072 |
+
if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
|
1073 |
+
return 1;
|
1074 |
+
}
|
1075 |
+
|
1076 |
+
int hidden_size_per_head = max(params.E, params.M);
|
1077 |
+
int res = 1;
|
1078 |
+
if( hidden_size_per_head <= 32 ) {
|
1079 |
+
res = lmha_bwd_< 32, 1>(params);
|
1080 |
+
} else if( hidden_size_per_head <= 64 ) {
|
1081 |
+
res = lmha_bwd_< 64, 1>(params);
|
1082 |
+
} else if( hidden_size_per_head <= 128 ) {
|
1083 |
+
res = lmha_bwd_<128, 2>(params);
|
1084 |
+
} else if( hidden_size_per_head <= 256 ) {
|
1085 |
+
res = lmha_bwd_<256, 4>(params);
|
1086 |
+
}
|
1087 |
+
return res;
|
1088 |
+
}
|
1089 |
+
|
1090 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1091 |
+
|
1092 |
+
int lmha_bwd(const torch::Tensor queries,
|
1093 |
+
const torch::Tensor keys,
|
1094 |
+
const torch::Tensor values,
|
1095 |
+
const torch::Tensor grad_out,
|
1096 |
+
torch::Tensor grad_queries,
|
1097 |
+
torch::Tensor grad_keys,
|
1098 |
+
torch::Tensor grad_values) {
|
1099 |
+
|
1100 |
+
// Make sure that we are using the correct GPU device
|
1101 |
+
torch::DeviceGuard _guard(queries.device());
|
1102 |
+
|
1103 |
+
// Make sure the inner-most dimension of the tensors is packed.
|
1104 |
+
assert(queries .stride(3) == 1);
|
1105 |
+
assert(keys .stride(3) == 1);
|
1106 |
+
assert(values .stride(3) == 1);
|
1107 |
+
assert(grad_out .stride(3) == 1);
|
1108 |
+
assert(grad_queries.stride(3) == 1);
|
1109 |
+
assert(grad_keys .stride(3) == 1);
|
1110 |
+
assert(grad_values .stride(3) == 1);
|
1111 |
+
|
1112 |
+
// Extract the dimensions.
|
1113 |
+
int N = queries.size(0);
|
1114 |
+
int H = queries.size(1);
|
1115 |
+
int L = queries.size(2);
|
1116 |
+
int E = queries.size(3);
|
1117 |
+
int M = values.size (3);
|
1118 |
+
|
1119 |
+
// Gradient on Q.
|
1120 |
+
|
1121 |
+
// The structure of params.
|
1122 |
+
Lmha_params<float> params;
|
1123 |
+
set_params(params, grad_out, values, keys, grad_queries);
|
1124 |
+
|
1125 |
+
// Launch the kernel.
|
1126 |
+
int res = lmha<false>(params);
|
1127 |
+
if( res ) {
|
1128 |
+
return res;
|
1129 |
+
}
|
1130 |
+
|
1131 |
+
// Gradient on K and V together.
|
1132 |
+
|
1133 |
+
Lmha_bwd_params<float> bwd_params;
|
1134 |
+
bwd_params.out_k = grad_keys.data_ptr<float>();
|
1135 |
+
bwd_params.out_v = grad_values.data_ptr<float>();
|
1136 |
+
bwd_params.q = queries.data_ptr<float>();
|
1137 |
+
bwd_params.k = keys.data_ptr<float>();
|
1138 |
+
bwd_params.v = values.data_ptr<float>();
|
1139 |
+
bwd_params.g = grad_out.data_ptr<float>();
|
1140 |
+
|
1141 |
+
bwd_params.B = N;
|
1142 |
+
bwd_params.L = L;
|
1143 |
+
bwd_params.H = H;
|
1144 |
+
bwd_params.E = E;
|
1145 |
+
bwd_params.M = M;
|
1146 |
+
|
1147 |
+
bwd_params.q_stride_B = queries.stride(0);
|
1148 |
+
bwd_params.q_stride_H = queries.stride(1);
|
1149 |
+
bwd_params.q_stride_L = queries.stride(2);
|
1150 |
+
bwd_params.k_stride_B = keys.stride(0);
|
1151 |
+
bwd_params.k_stride_H = keys.stride(1);
|
1152 |
+
bwd_params.k_stride_L = keys.stride(2);
|
1153 |
+
bwd_params.v_stride_B = values.stride(0);
|
1154 |
+
bwd_params.v_stride_H = values.stride(1);
|
1155 |
+
bwd_params.v_stride_L = values.stride(2);
|
1156 |
+
bwd_params.g_stride_B = grad_out.stride(0);
|
1157 |
+
bwd_params.g_stride_H = grad_out.stride(1);
|
1158 |
+
bwd_params.g_stride_L = grad_out.stride(2);
|
1159 |
+
|
1160 |
+
bwd_params.out_k_stride_B = grad_keys.stride(0);
|
1161 |
+
bwd_params.out_k_stride_H = grad_keys.stride(1);
|
1162 |
+
bwd_params.out_k_stride_L = grad_keys.stride(2);
|
1163 |
+
bwd_params.out_v_stride_B = grad_values.stride(0);
|
1164 |
+
bwd_params.out_v_stride_H = grad_values.stride(1);
|
1165 |
+
bwd_params.out_v_stride_L = grad_values.stride(2);
|
1166 |
+
|
1167 |
+
// Try to run the fused kernel.
|
1168 |
+
int fallback = lmha_bwd(bwd_params);
|
1169 |
+
|
1170 |
+
// If it failed, fallback on separate kernels for K and V.
|
1171 |
+
if( fallback ) {
|
1172 |
+
|
1173 |
+
// Gradient on K.
|
1174 |
+
|
1175 |
+
// Launch the kernel.
|
1176 |
+
set_params(params, values, grad_out, queries, grad_keys);
|
1177 |
+
res = lmha<true>(params);
|
1178 |
+
if( res ) {
|
1179 |
+
return res;
|
1180 |
+
}
|
1181 |
+
|
1182 |
+
// Gradient on V.
|
1183 |
+
|
1184 |
+
// Launch the kernel.
|
1185 |
+
set_params(params, keys, queries, grad_out, grad_values);
|
1186 |
+
return lmha<true>(params);
|
1187 |
+
}
|
1188 |
+
|
1189 |
+
// It worked...
|
1190 |
+
return 0;
|
1191 |
+
}
|
1192 |
+
|
1193 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1194 |
+
|
1195 |
+
} // namespace nvidia
|
1196 |
+
#endif // #ifdef ENABLE_NVIDIA_OPTIMIZATIONS
|
1197 |
+
|
1198 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1199 |
+
|
1200 |
+
typedef torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> float_accessor;
|
1201 |
+
|
1202 |
+
#define E_BLOCK_SIZE 8
|
1203 |
+
|
1204 |
+
__global__ void causal_dot_product_kernel(
|
1205 |
+
const float_accessor queries,
|
1206 |
+
const float_accessor keys,
|
1207 |
+
const float_accessor values,
|
1208 |
+
float_accessor result,
|
1209 |
+
const int N,
|
1210 |
+
const int H,
|
1211 |
+
const int L,
|
1212 |
+
const int E,
|
1213 |
+
const int M
|
1214 |
+
) {
|
1215 |
+
int n = blockIdx.y;
|
1216 |
+
int h = blockIdx.z;
|
1217 |
+
|
1218 |
+
int e_start = blockIdx.x * E_BLOCK_SIZE;
|
1219 |
+
int m = threadIdx.x % M;
|
1220 |
+
|
1221 |
+
extern __shared__ float shared_mem[];
|
1222 |
+
float* shared_kv = shared_mem;
|
1223 |
+
|
1224 |
+
for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
|
1225 |
+
shared_kv[m + e_local * M] = 0;
|
1226 |
+
}
|
1227 |
+
|
1228 |
+
for (int t=0; t<L; t++) {
|
1229 |
+
float res = 0;
|
1230 |
+
for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
|
1231 |
+
shared_kv[e_local*M + m] += keys[n][h][t][e_local + e_start] * values[n][h][t][m];
|
1232 |
+
res += queries[n][h][t][e_local + e_start] * shared_kv[e_local*M + m];
|
1233 |
+
}
|
1234 |
+
atomicAdd(
|
1235 |
+
&result[n][h][t][m],
|
1236 |
+
res
|
1237 |
+
);
|
1238 |
+
}
|
1239 |
+
}
|
1240 |
+
|
1241 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1242 |
+
|
1243 |
+
void causal_dot_product_(const torch::Tensor queries,
|
1244 |
+
const torch::Tensor keys,
|
1245 |
+
const torch::Tensor values,
|
1246 |
+
torch::Tensor product) {
|
1247 |
+
// Make sure that we are using the correct GPU device
|
1248 |
+
torch::DeviceGuard _guard(queries.device());
|
1249 |
+
|
1250 |
+
int N = queries.size(0);
|
1251 |
+
int H = queries.size(1);
|
1252 |
+
int L = queries.size(2);
|
1253 |
+
int E = queries.size(3);
|
1254 |
+
int M = values.size(3);
|
1255 |
+
|
1256 |
+
const int blocks_per_sequence = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE;
|
1257 |
+
|
1258 |
+
dim3 blockDim(M, 1, 1);
|
1259 |
+
dim3 gridDim(blocks_per_sequence, N, H);
|
1260 |
+
const int shared_mem_forward = E_BLOCK_SIZE * M * sizeof(float);
|
1261 |
+
|
1262 |
+
causal_dot_product_kernel<<<gridDim, blockDim, shared_mem_forward>>>(
|
1263 |
+
queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1264 |
+
keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1265 |
+
values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1266 |
+
product.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1267 |
+
N, H, L, E, M
|
1268 |
+
);
|
1269 |
+
}
|
1270 |
+
|
1271 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1272 |
+
|
1273 |
+
void causal_dot_product(const torch::Tensor queries,
|
1274 |
+
const torch::Tensor keys,
|
1275 |
+
const torch::Tensor values,
|
1276 |
+
torch::Tensor product) {
|
1277 |
+
#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
|
1278 |
+
int fallback = nvidia::lmha_fwd(queries, keys, values, product);
|
1279 |
+
#else
|
1280 |
+
int fallback = 1;
|
1281 |
+
#endif
|
1282 |
+
if( fallback ) {
|
1283 |
+
causal_dot_product_(queries, keys, values, product);
|
1284 |
+
}
|
1285 |
+
}
|
1286 |
+
|
1287 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1288 |
+
|
1289 |
+
#define M_BLOCK_SIZE 4
|
1290 |
+
|
1291 |
+
// we need shared memory to store
|
1292 |
+
// kv
|
1293 |
+
// Backward direction
|
1294 |
+
// kv_backwards
|
1295 |
+
// Shared memory usage
|
1296 |
+
__global__ void causal_dot_backward_query_key_kernel(
|
1297 |
+
const float_accessor queries,
|
1298 |
+
const float_accessor keys,
|
1299 |
+
const float_accessor values,
|
1300 |
+
const float_accessor grad_out,
|
1301 |
+
float_accessor grad_queries,
|
1302 |
+
float_accessor grad_keys,
|
1303 |
+
int N,
|
1304 |
+
int H,
|
1305 |
+
int L,
|
1306 |
+
int E,
|
1307 |
+
int M
|
1308 |
+
) {
|
1309 |
+
int n = blockIdx.y;
|
1310 |
+
int h = blockIdx.z;
|
1311 |
+
|
1312 |
+
int m_start = blockIdx.x * M_BLOCK_SIZE;
|
1313 |
+
int e = threadIdx.x % E;
|
1314 |
+
|
1315 |
+
extern __shared__ float shared_mem[];
|
1316 |
+
const int shared_kv_size = M_BLOCK_SIZE * E;
|
1317 |
+
float* shared_kv = shared_mem;
|
1318 |
+
float* shared_kv_bw = shared_mem + shared_kv_size;
|
1319 |
+
|
1320 |
+
for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) {
|
1321 |
+
shared_kv[m_local * E + e] = 0;
|
1322 |
+
shared_kv_bw[m_local * E + e] = 0;
|
1323 |
+
}
|
1324 |
+
|
1325 |
+
for (int l=0; l<L; l++) {
|
1326 |
+
float res = 0, res_bw = 0;
|
1327 |
+
int l_b = L - l - 1;
|
1328 |
+
for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) {
|
1329 |
+
shared_kv[m_local*E + e] += keys[n][h][l][e] * values[n][h][l][m_start + m_local];
|
1330 |
+
shared_kv_bw[m_local*E + e] += queries[n][h][l_b][e] * grad_out[n][h][l_b][m_start + m_local];
|
1331 |
+
res += grad_out[n][h][l][m_start + m_local] * shared_kv[m_local*E + e];
|
1332 |
+
res_bw += values[n][h][l_b][m_start + m_local] * shared_kv_bw[m_local*E + e];
|
1333 |
+
}
|
1334 |
+
atomicAdd(
|
1335 |
+
&grad_queries[n][h][l][e],
|
1336 |
+
res
|
1337 |
+
);
|
1338 |
+
atomicAdd(
|
1339 |
+
&grad_keys[n][h][l_b][e],
|
1340 |
+
res_bw
|
1341 |
+
);
|
1342 |
+
}
|
1343 |
+
}
|
1344 |
+
|
1345 |
+
|
1346 |
+
__global__ void causal_dot_backward_value_kernel(
|
1347 |
+
const float_accessor queries,
|
1348 |
+
const float_accessor keys,
|
1349 |
+
const float_accessor values,
|
1350 |
+
const float_accessor grad_out,
|
1351 |
+
float_accessor grad_keys,
|
1352 |
+
float_accessor grad_values,
|
1353 |
+
int N,
|
1354 |
+
int H,
|
1355 |
+
int L,
|
1356 |
+
int E,
|
1357 |
+
int M
|
1358 |
+
) {
|
1359 |
+
int n = blockIdx.y;
|
1360 |
+
int h = blockIdx.z;
|
1361 |
+
|
1362 |
+
int e_start = blockIdx.x * E_BLOCK_SIZE;
|
1363 |
+
int m = threadIdx.x % M;
|
1364 |
+
|
1365 |
+
extern __shared__ float shared_mem[];
|
1366 |
+
float* shared_kv = shared_mem;
|
1367 |
+
for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
|
1368 |
+
shared_kv[m + e_local * M] = 0;
|
1369 |
+
}
|
1370 |
+
|
1371 |
+
for (int l = 0; l < L; l++) {
|
1372 |
+
int l_b = L - l -1;
|
1373 |
+
float res = 0;
|
1374 |
+
for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
|
1375 |
+
shared_kv[e_local*M + m] += queries[n][h][l_b][e_start + e_local] * grad_out[n][h][l_b][m];
|
1376 |
+
res += keys[n][h][l_b][e_start + e_local] * shared_kv[e_local*M + m];
|
1377 |
+
}
|
1378 |
+
atomicAdd(
|
1379 |
+
&grad_values[n][h][l_b][m],
|
1380 |
+
res
|
1381 |
+
);
|
1382 |
+
}
|
1383 |
+
}
|
1384 |
+
|
1385 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1386 |
+
|
1387 |
+
void causal_dot_backward_(const torch::Tensor queries,
|
1388 |
+
const torch::Tensor keys,
|
1389 |
+
const torch::Tensor values,
|
1390 |
+
const torch::Tensor grad_out,
|
1391 |
+
torch::Tensor grad_queries,
|
1392 |
+
torch::Tensor grad_keys,
|
1393 |
+
torch::Tensor grad_values) {
|
1394 |
+
|
1395 |
+
// Make sure that we are using the correct GPU device
|
1396 |
+
torch::DeviceGuard _guard(queries.device());
|
1397 |
+
|
1398 |
+
int N = queries.size(0);
|
1399 |
+
int H = queries.size(1);
|
1400 |
+
int L = queries.size(2);
|
1401 |
+
int E = queries.size(3);
|
1402 |
+
int M = values.size(3);
|
1403 |
+
|
1404 |
+
const int blocks_per_sequence = (M + M_BLOCK_SIZE - 1) / M_BLOCK_SIZE;
|
1405 |
+
|
1406 |
+
dim3 blockDim(E, 1, 1);
|
1407 |
+
dim3 gridDim(blocks_per_sequence, N, H);
|
1408 |
+
const int shared_mem_qk_backward = 2 * M_BLOCK_SIZE * E * sizeof(float);
|
1409 |
+
|
1410 |
+
causal_dot_backward_query_key_kernel<<<gridDim, blockDim, shared_mem_qk_backward>>>(
|
1411 |
+
queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1412 |
+
keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1413 |
+
values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1414 |
+
grad_out.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1415 |
+
grad_queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1416 |
+
grad_keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1417 |
+
N, H, L, E, M
|
1418 |
+
);
|
1419 |
+
|
1420 |
+
const int blocks_per_sequence_value = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE;
|
1421 |
+
|
1422 |
+
dim3 blockDimv(M, 1, 1);
|
1423 |
+
dim3 gridDimv(blocks_per_sequence_value, N, H);
|
1424 |
+
const int shared_mem_v_backward = E_BLOCK_SIZE * M * sizeof(float);
|
1425 |
+
causal_dot_backward_value_kernel<<<gridDimv, blockDimv, shared_mem_v_backward>>>(
|
1426 |
+
queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1427 |
+
keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1428 |
+
values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1429 |
+
grad_out.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1430 |
+
grad_keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1431 |
+
grad_values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
|
1432 |
+
N, H, L, E, M
|
1433 |
+
);
|
1434 |
+
}
|
1435 |
+
|
1436 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1437 |
+
|
1438 |
+
void causal_dot_backward(const torch::Tensor queries,
|
1439 |
+
const torch::Tensor keys,
|
1440 |
+
const torch::Tensor values,
|
1441 |
+
const torch::Tensor grad_out,
|
1442 |
+
torch::Tensor grad_queries,
|
1443 |
+
torch::Tensor grad_keys,
|
1444 |
+
torch::Tensor grad_values) {
|
1445 |
+
#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
|
1446 |
+
int fallback = nvidia::lmha_bwd(queries,
|
1447 |
+
keys,
|
1448 |
+
values,
|
1449 |
+
grad_out,
|
1450 |
+
grad_queries,
|
1451 |
+
grad_keys,
|
1452 |
+
grad_values);
|
1453 |
+
#else
|
1454 |
+
int fallback = 1;
|
1455 |
+
#endif
|
1456 |
+
if( fallback ) {
|
1457 |
+
// Make sure that the gradient tensors are 0. This is needed because the
|
1458 |
+
// bwd pass might have partially executed and filled in some values in
|
1459 |
+
// grad_queries or grad_keys.
|
1460 |
+
//
|
1461 |
+
// This adds a small overhead every time we have to fall back to the old
|
1462 |
+
// kernel for the backward pass.
|
1463 |
+
grad_queries.zero_();
|
1464 |
+
grad_keys.zero_();
|
1465 |
+
causal_dot_backward_(queries, keys, values, grad_out, grad_queries, grad_keys, grad_values);
|
1466 |
+
}
|
1467 |
+
}
|
1468 |
+
|
1469 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
1470 |
+
|
1471 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
1472 |
+
m.def(
|
1473 |
+
"causal_dot_product",
|
1474 |
+
&causal_dot_product,
|
1475 |
+
"Compute the weighted sum of values but attending only to previous "
|
1476 |
+
"values."
|
1477 |
+
);
|
1478 |
+
m.def(
|
1479 |
+
"causal_dot_backward",
|
1480 |
+
&causal_dot_backward,
|
1481 |
+
"Compute the gradients for the causal dot product."
|
1482 |
+
);
|
1483 |
+
}
|
csrc/setup.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
3 |
+
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
4 |
+
# Apoorv Vyas <avyas@idiap.ch>
|
5 |
+
#
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from setuptools import setup
|
9 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
10 |
+
import subprocess
|
11 |
+
|
12 |
+
def get_last_arch_torch():
|
13 |
+
arch = torch.cuda.get_arch_list()[-1]
|
14 |
+
print(f"Found arch: {arch} from existing torch installation")
|
15 |
+
return arch
|
16 |
+
|
17 |
+
def get_cuda_bare_metal_version(cuda_dir):
|
18 |
+
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
|
19 |
+
output = raw_output.split()
|
20 |
+
release_idx = output.index("release") + 1
|
21 |
+
release = output[release_idx].split(".")
|
22 |
+
bare_metal_major = release[0]
|
23 |
+
bare_metal_minor = release[1][0]
|
24 |
+
return raw_output, bare_metal_major, bare_metal_minor
|
25 |
+
|
26 |
+
def append_nvcc_threads(nvcc_extra_args):
|
27 |
+
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
|
28 |
+
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
|
29 |
+
return nvcc_extra_args + ["--threads", "4"]
|
30 |
+
return nvcc_extra_args
|
31 |
+
|
32 |
+
arch = get_last_arch_torch()
|
33 |
+
sm_num = arch[-2:]
|
34 |
+
cc_flag = ['--generate-code=arch=compute_90,code=compute_90'] # for H100
|
35 |
+
# cc_flag = ['--generate-code=arch=compute_80,code=compute_80'] # for A100
|
36 |
+
# cc_flag = ['--generate-code=arch=compute_89,code=compute_89'] # for RTX 6000, 4090
|
37 |
+
# cc_flag = ['--generate-code=arch=compute_86,code=compute_86'] # for A6000, 3090
|
38 |
+
# cc_flag = ['--generate-code=arch=compute_75,code=compute_75']
|
39 |
+
|
40 |
+
setup(
|
41 |
+
name='causal_attention_cuda_cpp',
|
42 |
+
ext_modules=[
|
43 |
+
CUDAExtension('causal_attention_cuda', [
|
44 |
+
# 'causal_attention.cpp',
|
45 |
+
'causal_attention_cuda.cu',
|
46 |
+
],
|
47 |
+
extra_compile_args={'cxx': ['-O3'],
|
48 |
+
'nvcc': append_nvcc_threads(['-O3', '-lineinfo', '--use_fast_math', '-std=c++17'] + cc_flag)
|
49 |
+
})
|
50 |
+
],
|
51 |
+
cmdclass={
|
52 |
+
'build_ext': BuildExtension
|
53 |
+
})
|
src/__init__.py
ADDED
File without changes
|
src/dataloaders/__init__.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Load dataloaders
|
3 |
+
"""
|
4 |
+
import importlib
|
5 |
+
|
6 |
+
|
7 |
+
def load_data(dataset_config: dict, dataloader_config: dict):
|
8 |
+
"""Return dataloaders from dataset_config"""
|
9 |
+
try:
|
10 |
+
dataset_module = importlib.import_module(f'dataloaders.{dataset_config["name"]}')
|
11 |
+
except Exception:
|
12 |
+
try:
|
13 |
+
dataset_module = importlib.import_module(f'src.dataloaders.{dataset_config["name"]}')
|
14 |
+
except Exception as e2:
|
15 |
+
print(e2)
|
16 |
+
try: # e.g., tasks like GLUE where name is benchmark and path specifies the dataset / task
|
17 |
+
dataset_module = importlib.import_module(f'dataloaders.{dataset_config["path"]}')
|
18 |
+
except Exception as e3:
|
19 |
+
print(f'Error from {dataset_config}')
|
20 |
+
raise e3
|
21 |
+
_load_data = getattr(dataset_module, 'load_data')
|
22 |
+
return _load_data(**dataset_config, **dataloader_config)
|
src/dataloaders/alpaca_clean.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Alpaca training dataloaders
|
3 |
+
|
4 |
+
We adopt the original prompt template; goes something like:
|
5 |
+
```
|
6 |
+
Below is an instruction that describes a task.
|
7 |
+
Write a response that appropriately completes the request.
|
8 |
+
### Instruction:
|
9 |
+
{instruction}
|
10 |
+
|
11 |
+
### Response:
|
12 |
+
{response}
|
13 |
+
```
|
14 |
+
See `PROMPT_DICT` for more.
|
15 |
+
"""
|
16 |
+
from functools import partial
|
17 |
+
from os.path import join
|
18 |
+
|
19 |
+
from datasets import load_metric, load_dataset
|
20 |
+
|
21 |
+
from .utils import (
|
22 |
+
get_lm_loader, get_seq2seq_loader,
|
23 |
+
convert_to_hf_dataset,
|
24 |
+
get_tokenizer_from_config,
|
25 |
+
download_scrolls_metric as download_metric
|
26 |
+
)
|
27 |
+
from .utils.packing import ConcatDataset
|
28 |
+
|
29 |
+
|
30 |
+
PROMPT_DICT = {
|
31 |
+
"prompt_input": (
|
32 |
+
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
33 |
+
"Write a response that appropriately completes the request.\n\n"
|
34 |
+
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
35 |
+
),
|
36 |
+
"prompt_no_input": (
|
37 |
+
"Below is an instruction that describes a task. "
|
38 |
+
"Write a response that appropriately completes the request.\n\n"
|
39 |
+
"### Instruction:\n{instruction}\n\n### Response:\n"
|
40 |
+
),
|
41 |
+
}
|
42 |
+
|
43 |
+
|
44 |
+
def load_data(name: str, dataset_config: dict, pretrained_model_config: dict,
|
45 |
+
preprocess_config: dict, **loader_kwargs: any):
|
46 |
+
"""
|
47 |
+
Shared function to load dataset from experiment config
|
48 |
+
-> e.g., see configs/experiments/distill_alpaca_clean_lr1e-2.yaml
|
49 |
+
"""
|
50 |
+
# Misc. setup
|
51 |
+
cache_dir = dataset_config['cache_dir']
|
52 |
+
input_len = dataset_config['chunk_size']
|
53 |
+
concat_data = dataset_config['concat_data']
|
54 |
+
|
55 |
+
tokenizer_name = pretrained_model_config['pretrained_model_name_or_path']
|
56 |
+
tokenizer_name = tokenizer_name.split('/')[-1]
|
57 |
+
# save_path = join(cache_dir, f'{name}_{tokenizer_name}')
|
58 |
+
|
59 |
+
# Setup tokenizer
|
60 |
+
tokenizer = get_tokenizer_from_config(pretrained_model_config)
|
61 |
+
if tokenizer.pad_token is None:
|
62 |
+
tokenizer.pad_token = tokenizer.eos_token
|
63 |
+
print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}')
|
64 |
+
|
65 |
+
tokenizer.padding_side = 'left' # for decoder-only generation
|
66 |
+
# Get initial data
|
67 |
+
ignore_kwargs = ['concat_data', 'chunk_size', 'pose_kwargs']
|
68 |
+
dataset = load_dataset(
|
69 |
+
**{k: v for k, v in dataset_config.items() if k not in ignore_kwargs}
|
70 |
+
)
|
71 |
+
if dataset_config['name'] == 'samsum': # hack
|
72 |
+
dataset = dataset.rename_column('dialogue', 'input')
|
73 |
+
dataset = dataset.rename_column('summary', 'output')
|
74 |
+
_instruction = 'Summarize this dialogue.'
|
75 |
+
for split in dataset.keys():
|
76 |
+
dataset[split] = dataset[split].add_column(
|
77 |
+
'instruction', [_instruction] * len(dataset[split])
|
78 |
+
)
|
79 |
+
train_set, val_set, test_set = dataset['train'], dataset['validation'], dataset['test']
|
80 |
+
dataset = train_set # hack to work with below code
|
81 |
+
else:
|
82 |
+
dataset = dataset['train']
|
83 |
+
train_set = convert_to_hf_dataset([dataset[ix] for ix in range(200, len(dataset))], cache_dir)
|
84 |
+
val_set = convert_to_hf_dataset([dataset[ix] for ix in range(200)], cache_dir)
|
85 |
+
test_set = convert_to_hf_dataset([dataset[ix] for ix in range(200)], cache_dir)
|
86 |
+
|
87 |
+
# Convert to dicts of {input_ids, attention_mask, labels}
|
88 |
+
train_set = train_set.map(
|
89 |
+
partial(template_and_tokenize, tokenizer=tokenizer, include_label=True),
|
90 |
+
remove_columns=list(dataset.features),) # load_from_cache_file=False)
|
91 |
+
val_set = val_set.map(
|
92 |
+
partial(template_and_tokenize, tokenizer=tokenizer, include_label=True),
|
93 |
+
remove_columns=list(dataset.features),) # load_from_cache_file=False)
|
94 |
+
test_set = test_set.map(
|
95 |
+
partial(template_and_tokenize, tokenizer=tokenizer, include_label=False),
|
96 |
+
remove_columns=list(dataset.features),) # load_from_cache_file=False)
|
97 |
+
|
98 |
+
# Chunk together train and val sets
|
99 |
+
if concat_data:
|
100 |
+
train_set = ConcatDataset(train_set, chunk_size=input_len)
|
101 |
+
val_set = ConcatDataset(val_set, chunk_size=input_len)
|
102 |
+
|
103 |
+
# Get dataloaders
|
104 |
+
dataloaders = {
|
105 |
+
'train': get_lm_loader(train_set, tokenizer, 'train', input_len, **loader_kwargs),
|
106 |
+
'validation': get_lm_loader(val_set, tokenizer, 'validation', input_len, **loader_kwargs),
|
107 |
+
'test': get_seq2seq_loader(test_set, tokenizer, 'test', **loader_kwargs),
|
108 |
+
}
|
109 |
+
# Evaluation metric
|
110 |
+
try:
|
111 |
+
metric = load_metric(download_metric(), 'gov_report') # hack but we want rouge
|
112 |
+
except Exception as e:
|
113 |
+
print(f'Error loading metric: {e}')
|
114 |
+
metric = None
|
115 |
+
|
116 |
+
# Finishing touches
|
117 |
+
for k, v in dataloaders.items(): # Make tokenizer accessible
|
118 |
+
dataloaders[k].dataset.tokenizer = tokenizer
|
119 |
+
dataloaders[k].dataset.metric = metric
|
120 |
+
return dataloaders
|
121 |
+
|
122 |
+
|
123 |
+
def template_and_tokenize(sample, tokenizer, include_label: bool = True):
|
124 |
+
"""
|
125 |
+
Format dataset context and answers into single-sequence prompts
|
126 |
+
"""
|
127 |
+
if sample.get('input', '') == '':
|
128 |
+
prompt = PROMPT_DICT["prompt_no_input"].format_map(sample)
|
129 |
+
else:
|
130 |
+
prompt = PROMPT_DICT["prompt_input"].format_map(sample)
|
131 |
+
|
132 |
+
prompt = tokenizer.encode(prompt, add_special_tokens=True)
|
133 |
+
if include_label:
|
134 |
+
answer = tokenizer.encode(f'{sample["output"]}{tokenizer.eos_token}',
|
135 |
+
add_special_tokens=False)
|
136 |
+
target = None
|
137 |
+
else:
|
138 |
+
answer = []
|
139 |
+
target = tokenizer.encode(f'{sample["output"]}{tokenizer.eos_token}',
|
140 |
+
add_special_tokens=False)
|
141 |
+
input_ids = prompt + answer
|
142 |
+
attn_mask = [1] * len(input_ids)
|
143 |
+
|
144 |
+
sample = {
|
145 |
+
"input_ids": input_ids,
|
146 |
+
"attention_mask" : attn_mask,
|
147 |
+
"labels": [-100] * len(prompt) + answer if include_label else target,
|
148 |
+
}
|
149 |
+
return sample
|
src/dataloaders/alpaca_clean_instruct.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Alpaca Clean dataset with Llama3-Instruct prompt formatting
|
3 |
+
"""
|
4 |
+
|
5 |
+
from functools import partial
|
6 |
+
from os.path import join
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch.utils.data import Dataset, DataLoader
|
13 |
+
|
14 |
+
from datasets import load_metric, load_dataset
|
15 |
+
from transformers import AutoTokenizer
|
16 |
+
from transformers import DataCollatorForSeq2Seq, DefaultDataCollator, DataCollatorWithPadding
|
17 |
+
|
18 |
+
from .utils import (
|
19 |
+
get_lm_loader, get_seq2seq_loader,
|
20 |
+
convert_to_hf_dataset,
|
21 |
+
get_tokenizer_from_config,
|
22 |
+
download_scrolls_metric as download_metric
|
23 |
+
)
|
24 |
+
from .utils.packing import ConcatDataset
|
25 |
+
|
26 |
+
|
27 |
+
SYSTEM_PROMPT = "You are a helpful AI assistant who always responds to appropriately complete a user's request."
|
28 |
+
|
29 |
+
|
30 |
+
def encode_response(response: str, tokenizer) -> list[int]:
|
31 |
+
tokens = tokenizer.encode(response.strip(), add_special_tokens=False)
|
32 |
+
# For Llama 3 Instruct: tokens.append(tokenizer.get_added_vocab()["<|eot_id|>"])
|
33 |
+
tokens.append(tokenizer.eos_token_id)
|
34 |
+
try: # Llama 3 Instruct
|
35 |
+
tokens.append(tokenizer.get_added_vocab()["<|end_of_text|>"])
|
36 |
+
except KeyError:
|
37 |
+
pass
|
38 |
+
return tokens
|
39 |
+
|
40 |
+
|
41 |
+
def load_data(name: str, dataset_config: dict, pretrained_model_config: dict,
|
42 |
+
preprocess_config: dict, **loader_kwargs: any):
|
43 |
+
|
44 |
+
# Misc. setup
|
45 |
+
cache_dir = dataset_config['cache_dir']
|
46 |
+
input_len = dataset_config['chunk_size']
|
47 |
+
concat_data = dataset_config['concat_data']
|
48 |
+
load_from_cache_file = False # False if want to retokenize dataset
|
49 |
+
|
50 |
+
# Hard-code system prompt handling
|
51 |
+
if 'istral' in pretrained_model_config['pretrained_model_name_or_path']:
|
52 |
+
system_prompt = ''
|
53 |
+
else:
|
54 |
+
system_prompt = SYSTEM_PROMPT
|
55 |
+
|
56 |
+
tokenizer_name = pretrained_model_config['pretrained_model_name_or_path']
|
57 |
+
tokenizer_name = tokenizer_name.split('/')[-1]
|
58 |
+
save_path = join(cache_dir, f'{name}_{tokenizer_name}')
|
59 |
+
|
60 |
+
# Setup tokenizer
|
61 |
+
tokenizer = get_tokenizer_from_config(pretrained_model_config)
|
62 |
+
if tokenizer.pad_token is None:
|
63 |
+
tokenizer.pad_token = tokenizer.eos_token
|
64 |
+
print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}')
|
65 |
+
|
66 |
+
tokenizer.padding_side = 'left' # for decoder-only generation
|
67 |
+
|
68 |
+
# Get initial data
|
69 |
+
ignore_kwargs = ['concat_data', 'chunk_size', 'pose_kwargs', 'system_prompt', 'name']
|
70 |
+
train_set = load_dataset(
|
71 |
+
**{k: v for k, v in dataset_config.items() if k not in ignore_kwargs},
|
72 |
+
split='train[100:-100]',
|
73 |
+
)
|
74 |
+
val_set = load_dataset( # we just use this dataset as a validation set
|
75 |
+
**{k: v for k, v in dataset_config.items() if k not in ignore_kwargs},
|
76 |
+
split='train[:100]+train[-100:]',
|
77 |
+
)
|
78 |
+
test_set = load_dataset(
|
79 |
+
**{k: v for k, v in dataset_config.items() if k not in ignore_kwargs},
|
80 |
+
split='train[:100]+train[-100:]',
|
81 |
+
)
|
82 |
+
|
83 |
+
# Convert to dicts of {input_ids, attention_mask, labels}
|
84 |
+
train_set = train_set.map(partial(template_and_tokenize, tokenizer=tokenizer,
|
85 |
+
include_label=True, system_prompt=system_prompt),
|
86 |
+
remove_columns=list(train_set.features),
|
87 |
+
load_from_cache_file=load_from_cache_file)
|
88 |
+
val_set = val_set.map(partial(template_and_tokenize, tokenizer=tokenizer,
|
89 |
+
include_label=True, system_prompt=system_prompt),
|
90 |
+
remove_columns=list(val_set.features),
|
91 |
+
load_from_cache_file=load_from_cache_file)
|
92 |
+
test_set = test_set.map(partial(template_and_tokenize, tokenizer=tokenizer,
|
93 |
+
include_label=False, system_prompt=system_prompt),
|
94 |
+
remove_columns=list(test_set.features),
|
95 |
+
load_from_cache_file=load_from_cache_file)
|
96 |
+
|
97 |
+
# Chunk together train and val sets
|
98 |
+
if concat_data:
|
99 |
+
train_set = ConcatDataset(train_set, chunk_size=input_len)
|
100 |
+
val_set = ConcatDataset(val_set, chunk_size=input_len)
|
101 |
+
|
102 |
+
# Get dataloaders
|
103 |
+
dataloaders = {
|
104 |
+
'train': get_lm_loader(train_set, tokenizer, 'train', input_len, **loader_kwargs),
|
105 |
+
'validation': get_lm_loader(val_set, tokenizer, 'validation', input_len, **loader_kwargs),
|
106 |
+
'test': get_seq2seq_loader(test_set, tokenizer, 'test', **loader_kwargs),
|
107 |
+
}
|
108 |
+
# Evaluation metric
|
109 |
+
metric = load_metric(download_metric(), 'gov_report') # hack but we want rouge
|
110 |
+
|
111 |
+
# Finishing touches
|
112 |
+
for k, v in dataloaders.items(): # Make tokenizer accessible
|
113 |
+
dataloaders[k].dataset.tokenizer = tokenizer
|
114 |
+
dataloaders[k].dataset.metric = metric
|
115 |
+
return dataloaders
|
116 |
+
|
117 |
+
|
118 |
+
def template_and_tokenize(sample, tokenizer, include_label: bool = True,
|
119 |
+
system_prompt: str = None):
|
120 |
+
if system_prompt is None:
|
121 |
+
system_prompt = SYSTEM_PROMPT
|
122 |
+
|
123 |
+
prompt = sample['instruction']
|
124 |
+
if sample['input'] != '':
|
125 |
+
prompt += f"\n\n{sample['input']}"
|
126 |
+
|
127 |
+
messages = [
|
128 |
+
{"role": "system", "content": system_prompt},
|
129 |
+
] if system_prompt != '' else []
|
130 |
+
messages.append({"role": "user", "content": prompt})
|
131 |
+
prompt_ids = tokenizer.apply_chat_template(
|
132 |
+
messages, tokenize=True, add_generation_prompt=True,
|
133 |
+
)
|
134 |
+
if include_label:
|
135 |
+
answer = encode_response(sample['output'], tokenizer)
|
136 |
+
else:
|
137 |
+
answer = []
|
138 |
+
target = encode_response(sample['output'], tokenizer)
|
139 |
+
|
140 |
+
input_ids = prompt_ids + answer
|
141 |
+
attn_mask = [1] * len(input_ids)
|
142 |
+
sample = {
|
143 |
+
"input_ids": input_ids,
|
144 |
+
"attention_mask" : attn_mask,
|
145 |
+
"labels": [-100] * len(prompt_ids) + answer if include_label else target,
|
146 |
+
}
|
147 |
+
return sample
|
148 |
+
|
src/dataloaders/utils/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Helper functions dataset setup and loading
|
3 |
+
"""
|
4 |
+
from .setup import *
|
src/dataloaders/utils/llama3.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Data utils for Llama3
|
3 |
+
"""
|
4 |
+
|
5 |
+
def encode_header(message: str, tokenizer) -> list[int]:
|
6 |
+
tokens = []
|
7 |
+
tokens.append(tokenizer.get_added_vocab()["<|start_header_id|>"])
|
8 |
+
tokens.extend(tokenizer.encode(message["role"], add_special_tokens=False))
|
9 |
+
tokens.append(tokenizer.get_added_vocab()["<|end_header_id|>"])
|
10 |
+
tokens.extend(tokenizer.encode("\n\n", add_special_tokens=False))
|
11 |
+
return tokens
|
12 |
+
|
13 |
+
|
14 |
+
def encode_message(message: str, tokenizer, include_header: bool = True) -> list[int]:
|
15 |
+
tokens = encode_header(message, tokenizer) if include_header else []
|
16 |
+
tokens.extend(
|
17 |
+
tokenizer.encode(message["content"].strip(), add_special_tokens=False)
|
18 |
+
)
|
19 |
+
tokens.append(tokenizer.get_added_vocab()["<|eot_id|>"])
|
20 |
+
return tokens
|
21 |
+
|
22 |
+
|
23 |
+
def template_and_tokenize(sample, tokenizer, include_label: bool = True,
|
24 |
+
system_prompt: str = None):
|
25 |
+
if system_prompt is not None:
|
26 |
+
dialog = [{'role': 'system', 'content': system_prompt}]
|
27 |
+
else:
|
28 |
+
dialog = []
|
29 |
+
|
30 |
+
chat = []
|
31 |
+
instruction = sample['instruction']
|
32 |
+
if sample['input'] != '':
|
33 |
+
instruction += f"\n\n{sample['input']}"
|
34 |
+
dialog.extend([
|
35 |
+
{'role': 'user', 'content': instruction},
|
36 |
+
{'role': 'assistant', 'content': sample['output']},
|
37 |
+
])
|
38 |
+
|
39 |
+
prompt = []
|
40 |
+
prompt.append(tokenizer.get_added_vocab()["<|begin_of_text|>"])
|
41 |
+
for message in dialog[:-1]:
|
42 |
+
prompt.extend(encode_message(message, tokenizer))
|
43 |
+
|
44 |
+
if include_label:
|
45 |
+
answer = encode_message(dialog[-1], tokenizer)
|
46 |
+
answer.append(tokenizer.get_added_vocab()["<|end_of_text|>"])
|
47 |
+
else:
|
48 |
+
answer = []
|
49 |
+
target = encode_message(dialog[-1], tokenizer, include_header=False)
|
50 |
+
target.append(tokenizer.get_added_vocab()["<|end_of_text|>"])
|
51 |
+
# Add the start of an assistant message for the model to complete.
|
52 |
+
prompt.extend(encode_header({"role": "assistant", "content": ""}, tokenizer))
|
53 |
+
|
54 |
+
input_ids = prompt + answer
|
55 |
+
attn_mask = [1] * len(input_ids)
|
56 |
+
|
57 |
+
sample = {
|
58 |
+
"input_ids": input_ids,
|
59 |
+
"attention_mask" : attn_mask,
|
60 |
+
"labels": [-100] * len(prompt) + answer if include_label else target,
|
61 |
+
}
|
62 |
+
return sample
|
src/dataloaders/utils/packing.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
3 |
+
"""
|
4 |
+
Copied from https://github.com/meta-llama/llama-recipes/blob/9b3dabcaac78980eae40005bbc8b1a8276c82af3/src/llama_recipes/data/concatenator.py#L1
|
5 |
+
"""
|
6 |
+
import random
|
7 |
+
from itertools import chain
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
|
11 |
+
from torch.utils.data import Dataset
|
12 |
+
|
13 |
+
|
14 |
+
class Concatenator(object):
|
15 |
+
def __init__(self, chunk_size=2048):
|
16 |
+
self.chunk_size=chunk_size
|
17 |
+
self.residual = {"input_ids": [], "attention_mask": []}
|
18 |
+
|
19 |
+
def __call__(self, batch):
|
20 |
+
concatenated_samples = {
|
21 |
+
k: v + list(chain(*batch[k])) for k, v in self.residual.items()
|
22 |
+
}
|
23 |
+
|
24 |
+
total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]])
|
25 |
+
|
26 |
+
if total_length >= self.chunk_size:
|
27 |
+
chunk_num = total_length // self.chunk_size
|
28 |
+
result = {
|
29 |
+
k: [
|
30 |
+
v[i : i + self.chunk_size]
|
31 |
+
for i in range(0, chunk_num * self.chunk_size, self.chunk_size)
|
32 |
+
]
|
33 |
+
for k, v in concatenated_samples.items()
|
34 |
+
}
|
35 |
+
self.residual = {
|
36 |
+
k: v[(chunk_num * self.chunk_size) :]
|
37 |
+
for k, v in concatenated_samples.items()
|
38 |
+
}
|
39 |
+
else:
|
40 |
+
result = concatenated_samples
|
41 |
+
self.residual = {k: [] for k in concatenated_samples.keys()}
|
42 |
+
|
43 |
+
result["labels"] = result["input_ids"].copy()
|
44 |
+
|
45 |
+
return result
|
46 |
+
|
47 |
+
class ConcatDataset(Dataset):
|
48 |
+
"""
|
49 |
+
Concatenates or packs samples of a dataset into chunks of size `chunk_size`
|
50 |
+
"""
|
51 |
+
def __init__(self, dataset, chunk_size: int = 1024, seed: int = 42,) -> None:
|
52 |
+
self.dataset = dataset
|
53 |
+
self.chunk_size = chunk_size
|
54 |
+
self.samples = []
|
55 |
+
buffer = {
|
56 |
+
"input_ids": [],
|
57 |
+
"attention_mask": [],
|
58 |
+
"labels": [],
|
59 |
+
}
|
60 |
+
random.seed(seed)
|
61 |
+
for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
|
62 |
+
buffer = {k: v + sample[k] for k,v in buffer.items()}
|
63 |
+
|
64 |
+
while len(next(iter(buffer.values()))) > self.chunk_size:
|
65 |
+
self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
|
66 |
+
buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
|
67 |
+
# Slow hack, but filter out any samples without valid labels (all -100)
|
68 |
+
self.filtered_samples = []
|
69 |
+
for s in self.samples:
|
70 |
+
if sum(s['labels']) != chunk_size * -100:
|
71 |
+
self.filtered_samples.append(s)
|
72 |
+
if len(self.filtered_samples) < len(self.samples):
|
73 |
+
print(f'OG dataset: {len(self.samples)} samples -> Filtered dataset: {len(self.filtered_samples)}')
|
74 |
+
print(f'-> Filtered out {len(self.samples) - len(self.filtered_samples)} samples')
|
75 |
+
|
76 |
+
def __getitem__(self, idx):
|
77 |
+
return self.filtered_samples[idx]
|
78 |
+
|
79 |
+
def __len__(self):
|
80 |
+
return len(self.filtered_samples)
|
src/dataloaders/utils/setup.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Helper functions dataset setup and loading
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
from os.path import join
|
6 |
+
import shutil
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from torch.utils.data import Dataset, DataLoader
|
10 |
+
|
11 |
+
from datasets import Dataset as HFDataset
|
12 |
+
from huggingface_hub import hf_hub_download
|
13 |
+
from transformers import AutoTokenizer, LlamaTokenizer
|
14 |
+
from transformers import DataCollatorForSeq2Seq
|
15 |
+
# from transformers import DefaultDataCollator, DataCollatorWithPadding
|
16 |
+
|
17 |
+
|
18 |
+
def get_seq2seq_loader(dataset: Dataset, tokenizer: AutoTokenizer,
|
19 |
+
split: str, **loader_kwargs: any):
|
20 |
+
"""
|
21 |
+
Get dataloader for seq2seq tasks (evaluation)
|
22 |
+
"""
|
23 |
+
tokenizer.padding_side = 'right'
|
24 |
+
collate_fn = DataCollatorForSeq2Seq(
|
25 |
+
tokenizer, label_pad_token_id=-100, return_tensors='pt')
|
26 |
+
return DataLoader(
|
27 |
+
dataset, shuffle='train' in split, collate_fn=collate_fn, **loader_kwargs)
|
28 |
+
|
29 |
+
|
30 |
+
def get_lm_loader(dataset: Dataset, tokenizer: AutoTokenizer,
|
31 |
+
split: str, max_length: int = None, **loader_kwargs: any):
|
32 |
+
"""
|
33 |
+
Get dataloader for language modeling (training)
|
34 |
+
-> Currently this ends up being the same as get_seq2seq_loader
|
35 |
+
"""
|
36 |
+
# collate_fn = DefaultDataCollator(return_tensors='pt')
|
37 |
+
# collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, padding=True,
|
38 |
+
# max_length=max_length, return_tensors='pt')
|
39 |
+
collate_fn = DataCollatorForSeq2Seq(
|
40 |
+
tokenizer, label_pad_token_id=-100, return_tensors='pt')
|
41 |
+
return DataLoader(
|
42 |
+
dataset, shuffle='train' in split, collate_fn=collate_fn, **loader_kwargs)
|
43 |
+
|
44 |
+
|
45 |
+
def convert_to_hf_dataset(dataset, cache_dir: str):
|
46 |
+
"""
|
47 |
+
Convert iterable dataset to HuggingFace HFDataset object
|
48 |
+
"""
|
49 |
+
def gen():
|
50 |
+
for _, sample in enumerate(dataset):
|
51 |
+
yield sample # dataset[idx]
|
52 |
+
return HFDataset.from_generator(gen, cache_dir=cache_dir)
|
53 |
+
|
54 |
+
|
55 |
+
def get_tokenizer_from_config(model_config):
|
56 |
+
"""
|
57 |
+
Get pretrained tokenizer based on (pretrained) model config
|
58 |
+
"""
|
59 |
+
# Get tokenizer
|
60 |
+
if 'llama' in model_config['pretrained_model_name_or_path']:
|
61 |
+
try: # if we store locally
|
62 |
+
model_path = join(model_config['cache_dir'],
|
63 |
+
model_config['pretrained_model_name_or_path'])
|
64 |
+
tokenizer = LlamaTokenizer.from_pretrained(model_path)
|
65 |
+
except Exception as e:
|
66 |
+
try:
|
67 |
+
tokenizer = AutoTokenizer.from_pretrained(**model_config)
|
68 |
+
print("-> Bad LlamaTokenizer.from_pretrained(model_path)", e)
|
69 |
+
print("-> But resolved with: AutoTokenizer.from_pretrained(**model_config)")
|
70 |
+
except Exception as e2:
|
71 |
+
print("-> Error with AutoTokenizer.from_pretrained(**model_config)", e2)
|
72 |
+
# tokenizer = LlamaTokenizer.from_pretrained(**model_config) # v4.43 errors with `*** TypeError: not a string`
|
73 |
+
elif 'Mistral-7B-Instruct-v0.3' in model_config['pretrained_model_name_or_path']:
|
74 |
+
tokenizer = LlamaTokenizer.from_pretrained(**model_config) # hack where AutoTokenizer doesn't recognize
|
75 |
+
elif 'Mistral-7B' in model_config['pretrained_model_name_or_path']:
|
76 |
+
tokenizer = AutoTokenizer.from_pretrained(**model_config)
|
77 |
+
else:
|
78 |
+
tokenizer = AutoTokenizer.from_pretrained(**model_config)
|
79 |
+
return tokenizer
|
80 |
+
|
81 |
+
|
82 |
+
def add_special_tokens_to_dataset(dataset, tokenizer):
|
83 |
+
"""
|
84 |
+
Add special tokens as attributes to a dataset object
|
85 |
+
"""
|
86 |
+
token_map = {k: v for k, v in tokenizer.special_tokens_map.items()}
|
87 |
+
special_ids = tokenizer.all_special_ids
|
88 |
+
for idx, k in enumerate(tokenizer.special_tokens_map.keys()):
|
89 |
+
token_map[f'{k}_id'] = special_ids[idx]
|
90 |
+
for k, v in token_map.items():
|
91 |
+
setattr(dataset, k, v)
|
92 |
+
return dataset
|
93 |
+
|
94 |
+
|
95 |
+
def train_test_split(samples: any, train_size: int, test_size: int, seed: int):
|
96 |
+
"""
|
97 |
+
Split samples into train and test sets
|
98 |
+
"""
|
99 |
+
try:
|
100 |
+
assert len(samples) == train_size + test_size
|
101 |
+
except Exception as e:
|
102 |
+
print(len(samples), train_size + test_size)
|
103 |
+
raise e
|
104 |
+
arange = np.arange(len(samples))
|
105 |
+
np.random.seed(seed)
|
106 |
+
test_idx = np.random.choice(arange, size=test_size, replace=False)
|
107 |
+
train_idx = np.setdiff1d(arange, test_idx)
|
108 |
+
return samples[train_idx], samples[test_idx]
|
109 |
+
|
110 |
+
|
111 |
+
def download_scrolls_metric():
|
112 |
+
"""
|
113 |
+
Download ROUGE, F1, and other accuracy metrics included in the SCROLLS dataset
|
114 |
+
"""
|
115 |
+
scrolls_metric_path = hf_hub_download(
|
116 |
+
repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset"
|
117 |
+
)
|
118 |
+
updated_scrolls_metric_path = (
|
119 |
+
os.path.dirname(scrolls_metric_path) +
|
120 |
+
os.path.basename(scrolls_metric_path).replace(".", "_") + ".py"
|
121 |
+
)
|
122 |
+
shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
|
123 |
+
return updated_scrolls_metric_path
|
src/finetune.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Finetuning functions to do post-distillation
|
3 |
+
"""
|
4 |
+
from os.path import join
|
5 |
+
from omegaconf import OmegaConf
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch.nn import Module
|
9 |
+
|
10 |
+
from src.utils.setup import update_config_from_args
|
11 |
+
from src.dataloaders import load_data
|
12 |
+
from src.trainer import get_trainer, get_optimizer, get_scheduler
|
13 |
+
|
14 |
+
|
15 |
+
def prepare_finetune_configs(args, model_config: dict,
|
16 |
+
finetune_config_name: str = None,
|
17 |
+
finetune_checkpoint_name: str = None,
|
18 |
+
config_dir='./configs/experiment'):
|
19 |
+
"""
|
20 |
+
Prepare finetuning configs
|
21 |
+
"""
|
22 |
+
# Load finetuning config
|
23 |
+
finetune_config = (finetune_config_name if finetune_config_name is not None else
|
24 |
+
finetune_checkpoint_name.split('-f=')[-1].split('-')[0])
|
25 |
+
finetune_config_path = join(config_dir, f'{finetune_config}.yaml')
|
26 |
+
finetune_config = OmegaConf.load(finetune_config_path)
|
27 |
+
finetune_config = update_config_from_args(finetune_config, args,
|
28 |
+
ignore_args=['lr', 'weight_decay'])
|
29 |
+
# Update data tokenizer to match model
|
30 |
+
if getattr(finetune_config.dataset, 'pretrained_model_config', None) is not None:
|
31 |
+
for k in ['pretrained_model_name_or_path', 'cache_dir']:
|
32 |
+
finetune_config.dataset.pretrained_model_config[k] = model_config['model'][k]
|
33 |
+
# Set finetuning args
|
34 |
+
for arg, argv in finetune_config.trainer.items():
|
35 |
+
if arg != 'name':
|
36 |
+
setattr(args, arg, argv)
|
37 |
+
for _config in ['dataloader', 'optimizer', 'lr_scheduler']:
|
38 |
+
setattr(args, _config, OmegaConf.to_container(getattr(finetune_config, _config)))
|
39 |
+
return finetune_config, args
|
40 |
+
|
41 |
+
|
42 |
+
def get_finetuner(model: Module, finetune_config: dict, device: torch.device,
|
43 |
+
args: any, wandb: any, initial_eval: bool = False):
|
44 |
+
"""
|
45 |
+
Initialize finetuning trainer
|
46 |
+
"""
|
47 |
+
model.to(device) # if using a fused optimizer
|
48 |
+
model.train()
|
49 |
+
|
50 |
+
# Initialize optimizer and scheduler
|
51 |
+
optimizer = get_optimizer(model=model, **finetune_config.optimizer)
|
52 |
+
scheduler = get_scheduler(optimizer=optimizer, **finetune_config.lr_scheduler)
|
53 |
+
|
54 |
+
dataloaders = load_data(finetune_config.dataset, finetune_config.dataloader)
|
55 |
+
train_loader = dataloaders[finetune_config.trainer.train_split]
|
56 |
+
eval_loader = dataloaders[finetune_config.trainer.val_split]
|
57 |
+
|
58 |
+
OurTrainer = get_trainer(finetune_config.trainer.name)
|
59 |
+
trainer = OurTrainer(model=model,
|
60 |
+
args=args,
|
61 |
+
train_loader=train_loader,
|
62 |
+
eval_loader=eval_loader,
|
63 |
+
optimizer_and_scheduler=(optimizer, scheduler),
|
64 |
+
device=device,
|
65 |
+
wandb=wandb,
|
66 |
+
checkpoint_suffix='_ft',
|
67 |
+
**finetune_config.trainer)
|
68 |
+
return trainer
|
src/model/__init__.py
ADDED
File without changes
|
src/model/convert_model.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Attention conversion helpers
|
3 |
+
"""
|
4 |
+
from functools import partial
|
5 |
+
from tqdm import tqdm
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
|
9 |
+
def convert_attention(model: nn.Module,
|
10 |
+
attention_config: dict,
|
11 |
+
train_attention: bool = False,
|
12 |
+
remove_base_attn: bool = True,):
|
13 |
+
"""
|
14 |
+
Call to convert all attention layers
|
15 |
+
"""
|
16 |
+
softmax_attns = []
|
17 |
+
if 'softmax_attentions' in attention_config:
|
18 |
+
softmax_attns = attention_config['softmax_attentions']
|
19 |
+
if attention_config.attention_type != 'softmax':
|
20 |
+
layers = traverse_layers(model)
|
21 |
+
for layer_idx, layer in enumerate(tqdm(layers, desc='Converting attentions...')):
|
22 |
+
if layer_idx not in softmax_attns:
|
23 |
+
layer.self_attn = convert_llama_attention(
|
24 |
+
layer, attention_config, layers, train_attention, remove_base_attn,
|
25 |
+
)
|
26 |
+
layer.self_attn.converted = True
|
27 |
+
else: # Freeze any preserved softmax attention layers
|
28 |
+
for p in layer.parameters():
|
29 |
+
p.requires_grad = False
|
30 |
+
else:
|
31 |
+
print(f'-> attention_config.attention_type is {attention_config.attention_type}; not converting attentions')
|
32 |
+
return model
|
33 |
+
|
34 |
+
|
35 |
+
def toggle_attention(llama_model: nn.Module, train: bool = False):
|
36 |
+
"""
|
37 |
+
Make attentions trainable if train is True
|
38 |
+
-> Set train_attention = False when finetuning
|
39 |
+
"""
|
40 |
+
for layer in traverse_layers(llama_model):
|
41 |
+
layer.self_attn.train_attention = train
|
42 |
+
return llama_model
|
43 |
+
|
44 |
+
|
45 |
+
def remove_base_attention(llama_model: nn.Module):
|
46 |
+
"""
|
47 |
+
Remove teacher attention after distillation (if we keep it)
|
48 |
+
"""
|
49 |
+
for layer in traverse_layers(llama_model):
|
50 |
+
if getattr(layer.self_attn, 'base_attn', False):
|
51 |
+
del layer.self_attn.base_attn
|
52 |
+
return llama_model
|
53 |
+
|
54 |
+
|
55 |
+
def traverse_layers(model: nn.Module, verbose: bool = False):
|
56 |
+
"""
|
57 |
+
Return list of model layers
|
58 |
+
"""
|
59 |
+
try:
|
60 |
+
layers = model.model.layers
|
61 |
+
if verbose:
|
62 |
+
print('-> Loading from model.model.layers')
|
63 |
+
except AttributeError as e: # if base model
|
64 |
+
if verbose:
|
65 |
+
print(e)
|
66 |
+
try:
|
67 |
+
layers = model.layers
|
68 |
+
if verbose:
|
69 |
+
print('-> Loading from model.layers')
|
70 |
+
except AttributeError as e1: # If we make a PEFT model
|
71 |
+
if verbose:
|
72 |
+
print(e1)
|
73 |
+
layers = model.base_model.model.model.layers
|
74 |
+
if verbose:
|
75 |
+
print('-> Loading from model.base_model.model.model.layers')
|
76 |
+
return layers
|
77 |
+
|
78 |
+
|
79 |
+
def convert_llama_attention(layer: nn.Module,
|
80 |
+
attention_config: dict,
|
81 |
+
layers: list[nn.Module], # list of layers
|
82 |
+
train_attention: bool = False,
|
83 |
+
remove_base_attn: bool = True):
|
84 |
+
"""
|
85 |
+
Converts a single layer's attention layer as specified by attention_config
|
86 |
+
"""
|
87 |
+
return get_attention(**attention_config)(
|
88 |
+
base_attn=layer.self_attn,
|
89 |
+
layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
|
90 |
+
max_layer_idx=len(layers) - 1,
|
91 |
+
train_attention=train_attention,
|
92 |
+
remove_base_attn=remove_base_attn,
|
93 |
+
)
|
94 |
+
|
95 |
+
|
96 |
+
def get_attention(attention_type: str, **kwargs: any):
|
97 |
+
"""
|
98 |
+
Get the linear attention class; either purely linear or linear with sliding window
|
99 |
+
-> 'linear' == 'lolcats_llama'
|
100 |
+
-> 'linear and sliding_window' == 'lolcats_llama_window_*'
|
101 |
+
"""
|
102 |
+
kwargs['attention_type'] = attention_type
|
103 |
+
|
104 |
+
if attention_type == 'lolcats_llama':
|
105 |
+
from .linear_attention import LolcatsLinearAttention
|
106 |
+
return partial(LolcatsLinearAttention, **kwargs)
|
107 |
+
|
108 |
+
elif attention_type == 'lolcats_llama_window_tk':
|
109 |
+
from .linear_attention import LolcatsTKWindowAttention
|
110 |
+
return partial(LolcatsTKWindowAttention, **kwargs)
|
111 |
+
|
112 |
+
elif attention_type == 'lolcats_llama_window_sw':
|
113 |
+
from .linear_attention import LolcatsSlidingWindowAttention
|
114 |
+
return partial(LolcatsSlidingWindowAttention, **kwargs)
|
115 |
+
|
116 |
+
elif attention_type == 'lolcats_llama_window_sw_linear':
|
117 |
+
from .linear_attention.linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention
|
118 |
+
return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
|
119 |
+
|
120 |
+
## Experimental chunked linear attentions below
|
121 |
+
elif attention_type == 'lolcats_long_llama_window_tk':
|
122 |
+
from .linear_attention import LolcatsTKWindowLongAttention
|
123 |
+
return partial(LolcatsTKWindowLongAttention, **kwargs)
|
124 |
+
|
125 |
+
elif attention_type == 'lolcats_long_llama_window_sw':
|
126 |
+
from .linear_attention import LolcatsSlidingWindowLongAttention
|
127 |
+
return partial(LolcatsSlidingWindowLongAttention, **kwargs)
|
128 |
+
|
129 |
+
## TK generation build (requires Thunderkittens)
|
130 |
+
elif attention_type == 'lolcats_llama_window_tk_gen':
|
131 |
+
from .linear_attention import LolcatsWindowAttentionTKGen
|
132 |
+
return partial(LolcatsWindowAttentionTKGen, **kwargs)
|
133 |
+
|
134 |
+
else:
|
135 |
+
print(f'-> attention_type {attention_type} not handled... returning None')
|
136 |
+
return None
|
137 |
+
|
138 |
+
|
139 |
+
def get_attention_cache(attention_type: str, past_key_values: any = None):
|
140 |
+
"""
|
141 |
+
Determine how we store past keys and values when generating
|
142 |
+
"""
|
143 |
+
if attention_type is None:
|
144 |
+
return past_key_values
|
145 |
+
|
146 |
+
# print(f'Returning attention cache based on attention_type == {attention_type}')
|
147 |
+
elif 'lolcats_llama_window_tk_gen' in attention_type:
|
148 |
+
from .linear_attention import LinearAttentionTKWindowGenerationCache
|
149 |
+
return LinearAttentionTKWindowGenerationCache()
|
150 |
+
|
151 |
+
elif 'llama_window_tk' in attention_type:
|
152 |
+
from .linear_attention import LinearAttentionTKWindowCache
|
153 |
+
return LinearAttentionTKWindowCache()
|
154 |
+
|
155 |
+
elif 'llama_window_sw' in attention_type:
|
156 |
+
from .linear_attention import LinearAttentionSlidingWindowCache
|
157 |
+
return LinearAttentionSlidingWindowCache()
|
158 |
+
|
159 |
+
elif 'llama_window_sw_linear' in attention_type:
|
160 |
+
from .linear_attention import LinearAttentionSlidingWindowCache
|
161 |
+
return LinearAttentionSlidingWindowCache()
|
162 |
+
|
163 |
+
## TK generation build (requires Thunderkittens)
|
164 |
+
elif attention_type == 'lolcats_llama_window_tk_gen':
|
165 |
+
from .linear_attention.linear_window_attention_tk_gen import LinearAttentionTKWindowGenerationCache
|
166 |
+
return LinearAttentionTKWindowGenerationCache()
|
167 |
+
|
168 |
+
elif 'softmax' in attention_type:
|
169 |
+
return past_key_values
|
170 |
+
|
171 |
+
else:
|
172 |
+
from .linear_attention import LinearAttentionState
|
173 |
+
return LinearAttentionState()
|
src/model/feature_map.py
ADDED
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Learnable linear attention feature map classes and functions
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
def init_feature_map(name: str, mlp: nn.Module, **kwargs: dict):
|
10 |
+
"""
|
11 |
+
Initialize feature map final activation for linear attention
|
12 |
+
"""
|
13 |
+
return FeatureMap(activation_name=name, mlp=mlp, **kwargs)
|
14 |
+
|
15 |
+
|
16 |
+
def init_feature_map_act(name: str, fullspace: bool = True, **kwargs):
|
17 |
+
"""
|
18 |
+
Initialize feature map final activation for linear attention
|
19 |
+
"""
|
20 |
+
if name == 'softmax_dim' and fullspace:
|
21 |
+
return SoftmaxDim(**kwargs)
|
22 |
+
elif name == 'softmax_dim' and not fullspace:
|
23 |
+
return SoftmaxDimHalfspace(**kwargs)
|
24 |
+
elif name == 'exp_dim' and fullspace:
|
25 |
+
return Exp(**kwargs)
|
26 |
+
elif name == 'exp_dim' and not fullspace:
|
27 |
+
return ExpHalfspace(**kwargs)
|
28 |
+
elif name == 'pos_elu':
|
29 |
+
return PosELU(**kwargs)
|
30 |
+
elif name == 'relu':
|
31 |
+
return ReLU(**kwargs)
|
32 |
+
|
33 |
+
else:
|
34 |
+
raise NotImplementedError
|
35 |
+
|
36 |
+
|
37 |
+
def init_learned_kernel(name: str, **kwargs: any):
|
38 |
+
"""
|
39 |
+
Initialize feature map MLP for linear attention
|
40 |
+
"""
|
41 |
+
if name == 'untied_head_einsum':
|
42 |
+
return FeatureMapMLP(**kwargs)
|
43 |
+
elif name == 'untied_head_adapter':
|
44 |
+
return FeatureMapAdapter(**kwargs)
|
45 |
+
else:
|
46 |
+
raise NotImplementedError
|
47 |
+
|
48 |
+
|
49 |
+
class FeatureMap(nn.Module):
|
50 |
+
"""
|
51 |
+
Final 'activation' of feature map. Can probably be combined with
|
52 |
+
`FeatureMapMLP` below
|
53 |
+
|
54 |
+
Full feature map is like f(xW + b)
|
55 |
+
-> This is the `f` part
|
56 |
+
"""
|
57 |
+
def __init__(self,
|
58 |
+
activation_name: str,
|
59 |
+
head_dim_idx: int = -1,
|
60 |
+
eps: float = 1e-12,
|
61 |
+
mlp: nn.Module = None,
|
62 |
+
fullspace: bool = True,):
|
63 |
+
super().__init__()
|
64 |
+
self.head_dim_idx = head_dim_idx
|
65 |
+
self.eps = eps
|
66 |
+
self.mlp = mlp if mlp is not None else nn.Identity()
|
67 |
+
self.activation = init_feature_map_act(activation_name, fullspace, eps=eps)
|
68 |
+
|
69 |
+
def forward(self, x: torch.Tensor, *mlp_args: any, **mlp_kwargs: any):
|
70 |
+
"""
|
71 |
+
Assume x.shape is (batch_size, n_heads, seq_len, head_dim)
|
72 |
+
"""
|
73 |
+
return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x)
|
74 |
+
|
75 |
+
def q_map(self, *args: any, **kwargs: any):
|
76 |
+
"""
|
77 |
+
Use for inference in case q and k feature maps differ
|
78 |
+
"""
|
79 |
+
return self.forward(*args, **kwargs)
|
80 |
+
|
81 |
+
def k_map(self, *args: any, **kwargs: any):
|
82 |
+
"""
|
83 |
+
Use for inference in case q and k feature maps differ
|
84 |
+
"""
|
85 |
+
return self.forward(*args, **kwargs)
|
86 |
+
|
87 |
+
|
88 |
+
# -----------------------
|
89 |
+
# Feature map activations
|
90 |
+
# -----------------------
|
91 |
+
class FeatureMapAct(nn.Module):
|
92 |
+
"""
|
93 |
+
Base class for feature map activations
|
94 |
+
"""
|
95 |
+
def __init__(self, eps: float = 1e-12):
|
96 |
+
super().__init__()
|
97 |
+
self.eps = eps
|
98 |
+
|
99 |
+
def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
|
100 |
+
"""
|
101 |
+
x.shape is (batch_size, n_heads, seq_len, head_dim)
|
102 |
+
"""
|
103 |
+
return x
|
104 |
+
|
105 |
+
|
106 |
+
class PosELU(FeatureMapAct):
|
107 |
+
"""
|
108 |
+
1 + ELU activation as in https://arxiv.org/abs/2006.16236
|
109 |
+
"""
|
110 |
+
def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
|
111 |
+
return (1 + F.elu(x)).clamp(min=self.eps)
|
112 |
+
|
113 |
+
|
114 |
+
class ReLU(FeatureMapAct):
|
115 |
+
"""
|
116 |
+
ReLU activation as in https://arxiv.org/abs/2103.13076
|
117 |
+
"""
|
118 |
+
def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
|
119 |
+
return F.relu(x).clamp(min=self.eps)
|
120 |
+
|
121 |
+
|
122 |
+
class SoftmaxDim(FeatureMapAct):
|
123 |
+
"""
|
124 |
+
Softmax activation as in https://arxiv.org/abs/2402.04347
|
125 |
+
"""
|
126 |
+
def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
|
127 |
+
return torch.cat([
|
128 |
+
torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1)
|
129 |
+
], dim=-1).clamp(min=self.eps)
|
130 |
+
|
131 |
+
|
132 |
+
class SoftmaxDimHalfspace(FeatureMapAct):
|
133 |
+
"""
|
134 |
+
Softmax activation as in https://arxiv.org/abs/2402.04347
|
135 |
+
"""
|
136 |
+
def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
|
137 |
+
return torch.softmax(x, dim=-1).clamp(min=self.eps)
|
138 |
+
|
139 |
+
|
140 |
+
class Exp(FeatureMapAct):
|
141 |
+
"""
|
142 |
+
Exp activation as in https://arxiv.org/abs/2402.04347
|
143 |
+
"""
|
144 |
+
def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
|
145 |
+
x_max = torch.amax(x, dim=-1, keepdim=True)
|
146 |
+
x_min = torch.amin(x, dim=-1, keepdim=True)
|
147 |
+
return torch.cat([
|
148 |
+
torch.exp(x - x_max), torch.exp(-x + x_min)
|
149 |
+
], dim=-1).clamp(min=self.eps)
|
150 |
+
|
151 |
+
|
152 |
+
class ExpHalfspace(FeatureMapAct):
|
153 |
+
"""
|
154 |
+
Exp activation as in https://arxiv.org/abs/2402.04347
|
155 |
+
"""
|
156 |
+
def forward(self, x: torch.Tensor, *args: any, **kwargs: any):
|
157 |
+
x_max = torch.amax(x, dim=-1, keepdim=True)
|
158 |
+
return torch.exp(x - x_max).clamp(min=self.eps)
|
159 |
+
|
160 |
+
|
161 |
+
# ----------------
|
162 |
+
# Feature map MLPs
|
163 |
+
# ----------------
|
164 |
+
|
165 |
+
class FeatureMapMLP(nn.Module):
|
166 |
+
"""
|
167 |
+
Learnable MLP in feature map.
|
168 |
+
|
169 |
+
Full feature map is like f(xW + b)
|
170 |
+
-> This is the `W` and (optional) `b` part
|
171 |
+
"""
|
172 |
+
def __init__(self,
|
173 |
+
num_heads: int,
|
174 |
+
head_dim: int, # input dim
|
175 |
+
feature_dim: int, # output dim
|
176 |
+
dtype: torch.dtype,
|
177 |
+
device: torch.device,
|
178 |
+
skip_connection: bool = False,
|
179 |
+
bias: bool = False,
|
180 |
+
zero_init: bool = False,
|
181 |
+
normal_init: bool = False,):
|
182 |
+
super().__init__()
|
183 |
+
self.num_heads = num_heads
|
184 |
+
self.head_dim = head_dim
|
185 |
+
self.feature_dim = feature_dim
|
186 |
+
self.dtype = dtype
|
187 |
+
self.device = device
|
188 |
+
self.skip_connection = skip_connection
|
189 |
+
self.bias = bias
|
190 |
+
self.zero_init = zero_init
|
191 |
+
self.normal_init = normal_init
|
192 |
+
self.init_weights_()
|
193 |
+
|
194 |
+
if self.zero_init: # Zero-out weights or set as identity post-initialization
|
195 |
+
self.zero_init_with_skip_() if self.skip_connection else self.zero_init_()
|
196 |
+
|
197 |
+
if self.normal_init:
|
198 |
+
with torch.no_grad():
|
199 |
+
nn.init.normal_(self.layer)
|
200 |
+
|
201 |
+
if self.skip_connection:
|
202 |
+
assertion_fail = f'If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}'
|
203 |
+
assert self.head_dim == self.feature_dim, assertion_fail
|
204 |
+
|
205 |
+
def init_weights_(self):
|
206 |
+
"""
|
207 |
+
Initialize (W)eights and (b)iases
|
208 |
+
"""
|
209 |
+
self.layer = nn.Parameter(torch.zeros(
|
210 |
+
(self.num_heads, self.head_dim, self.feature_dim),
|
211 |
+
dtype=self.dtype, device=self.device,
|
212 |
+
))
|
213 |
+
nn.init.kaiming_uniform_(self.layer)
|
214 |
+
|
215 |
+
if self.bias:
|
216 |
+
self.bias = nn.Parameter(torch.zeros(
|
217 |
+
(1, self.num_heads, 1, 1), # self.feature_dim),
|
218 |
+
dtype=self.dtype, device=self.device,
|
219 |
+
))
|
220 |
+
nn.init.kaiming_uniform_(self.bias)
|
221 |
+
else:
|
222 |
+
self.bias = 0. # hack
|
223 |
+
|
224 |
+
def zero_init_with_skip_(self):
|
225 |
+
"""
|
226 |
+
Initialize weights to zero matrix if skip connection
|
227 |
+
"""
|
228 |
+
with torch.no_grad():
|
229 |
+
nn.init.zeros_(self.layer)
|
230 |
+
|
231 |
+
def zero_init_(self):
|
232 |
+
"""
|
233 |
+
Initialize weights to identity matrix if no skip connection
|
234 |
+
"""
|
235 |
+
with torch.no_grad():
|
236 |
+
for i in range(self.layer.shape[0]):
|
237 |
+
try:
|
238 |
+
nn.init.eye_(self.layer[i])
|
239 |
+
except RuntimeError:
|
240 |
+
with torch.no_grad():
|
241 |
+
dtype = self.layer[i].dtype
|
242 |
+
weight = torch.eye(*self.layer[i].shape,
|
243 |
+
requires_grad=self.layer[i].requires_grad,
|
244 |
+
device=self.layer[i].device)
|
245 |
+
self.layer[i] = weight.to(dtype=dtype)
|
246 |
+
|
247 |
+
def forward(self, x: torch.Tensor):
|
248 |
+
"""
|
249 |
+
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
|
250 |
+
"""
|
251 |
+
_x = torch.einsum('hdf,bhld->bhlf', self.layer, x) + self.bias
|
252 |
+
return x + _x if self.skip_connection else _x
|
253 |
+
|
254 |
+
|
255 |
+
class FeatureMapAdapter(FeatureMapMLP):
|
256 |
+
"""
|
257 |
+
Learnable Feature map with bottleneck adapter
|
258 |
+
as in https://arxiv.org/abs/1902.00751
|
259 |
+
|
260 |
+
We don't use but could be fun to try
|
261 |
+
"""
|
262 |
+
def __init__(self, hidden_dim: int, *args, **kwargs):
|
263 |
+
kwargs['skip_connection'] = True
|
264 |
+
kwargs['bias'] = True
|
265 |
+
kwargs['zero_init'] = True
|
266 |
+
self.hidden_dim = hidden_dim
|
267 |
+
super().__init__(*args, **kwargs)
|
268 |
+
|
269 |
+
def init_weights_(self):
|
270 |
+
"""
|
271 |
+
Initialize (W)eights and (b)iases
|
272 |
+
"""
|
273 |
+
kwargs = {'dtype': self.dtype, 'device': self.device}
|
274 |
+
self.layer0 = nn.Parameter(
|
275 |
+
torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
|
276 |
+
)
|
277 |
+
self.layer1 = nn.Parameter(
|
278 |
+
torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
|
279 |
+
)
|
280 |
+
nn.init.kaiming_uniform_(self.layer0)
|
281 |
+
nn.init.kaiming_uniform_(self.layer1)
|
282 |
+
|
283 |
+
self.bias0 = nn.Parameter(torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs))
|
284 |
+
self.bias1 = nn.Parameter(torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs))
|
285 |
+
nn.init.kaiming_uniform_(self.bias0)
|
286 |
+
nn.init.kaiming_uniform_(self.bias1)
|
287 |
+
|
288 |
+
def zero_init_with_skip_(self):
|
289 |
+
with torch.no_grad():
|
290 |
+
nn.init.zeros_(self.layer0)
|
291 |
+
nn.init.zeros_(self.layer1)
|
292 |
+
nn.init.zeros_(self.bias0)
|
293 |
+
nn.init.zeros_(self.bias1)
|
294 |
+
|
295 |
+
def zero_init_(self):
|
296 |
+
assert NotImplementedError
|
297 |
+
|
298 |
+
def forward(self, x: torch.Tensor):
|
299 |
+
"""
|
300 |
+
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
|
301 |
+
-> Down-project, apply nonlinearity, up-project; add skip connection
|
302 |
+
"""
|
303 |
+
_x = torch.einsum('hde,bhld->bhle', self.layer0, x) + self.bias0
|
304 |
+
_x = F.relu(_x)
|
305 |
+
_x = torch.einsum('hef,bhle->bhlf', self.layer1, _x) + self.bias1
|
306 |
+
return x + _x if self.skip_connection else _x
|
src/model/linear_attention/__init__.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Linear and linear attention + sliding window classes
|
3 |
+
"""
|
4 |
+
from .linear_attention import (
|
5 |
+
LolcatsLinearAttention, LinearAttentionState
|
6 |
+
)
|
7 |
+
from .linear_window_attention_tk import (
|
8 |
+
LolcatsTKWindowAttention, LinearAttentionTKWindowCache
|
9 |
+
)
|
10 |
+
from .linear_window_attention_sw import (
|
11 |
+
LolcatsSlidingWindowAttention, LinearAttentionSlidingWindowCache
|
12 |
+
)
|
13 |
+
# Experimental chunk linear attentions
|
14 |
+
from .linear_window_attention_tk_long import (
|
15 |
+
LolcatsTKWindowLongAttention,
|
16 |
+
)
|
17 |
+
from .linear_window_attention_sw_long import (
|
18 |
+
LolcatsSlidingWindowLongAttention,
|
19 |
+
)
|
20 |
+
from .linear_window_attention_tk_gen import (
|
21 |
+
LolcatsWindowAttentionTKGen,
|
22 |
+
LinearAttentionTKWindowGenerationCache
|
23 |
+
)
|
src/model/linear_attention/linear_attention.py
ADDED
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Linear attention classes
|
3 |
+
"""
|
4 |
+
from typing import List, Tuple, Optional
|
5 |
+
import copy
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from omegaconf import OmegaConf, DictConfig
|
9 |
+
|
10 |
+
from transformers.cache_utils import Cache # starting at Transformers v4.36
|
11 |
+
|
12 |
+
# Causal linear attention dot product CUDA kernel from fast-transformers
|
13 |
+
try:
|
14 |
+
from csrc import causal_dot_product as fast_causal_dot_product
|
15 |
+
except ImportError:
|
16 |
+
fast_causal_dot_product = None
|
17 |
+
|
18 |
+
from src.model.feature_map import init_feature_map, init_learned_kernel
|
19 |
+
from src.model.rotary import get_rotary_embeddings, apply_rotary_pos_emb
|
20 |
+
from .utils import repeat_kv
|
21 |
+
|
22 |
+
|
23 |
+
# -------------------
|
24 |
+
# Attention functions
|
25 |
+
# -------------------
|
26 |
+
|
27 |
+
def causal_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
28 |
+
"""
|
29 |
+
Causal linear attention dot product
|
30 |
+
- If available, use CUDA kernel from fast-transformers
|
31 |
+
"""
|
32 |
+
if fast_causal_dot_product is None:
|
33 |
+
kv = torch.einsum('bhlf,bhld->bhlfd', k, v)
|
34 |
+
return torch.einsum('bhlf,bhlfd->bhld', q, kv.cumsum(dim=2))
|
35 |
+
return fast_causal_dot_product(q, k, v)
|
36 |
+
|
37 |
+
def linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
38 |
+
fp32_attention: bool = False, eps: float = 1e-12,
|
39 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
40 |
+
"""
|
41 |
+
Compute linear attention with CUDA kernel implementation from fast-transformers
|
42 |
+
- https://github.com/idiap/fast-transformers
|
43 |
+
- Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim);
|
44 |
+
v is shape (b, h, l, head_dim)
|
45 |
+
"""
|
46 |
+
dtype = q.dtype
|
47 |
+
# Causal mask already applied
|
48 |
+
y = causal_dot_product(q.contiguous().to(dtype=torch.float32),
|
49 |
+
k.contiguous().to(dtype=torch.float32),
|
50 |
+
v.contiguous().to(dtype=torch.float32))
|
51 |
+
if fp32_attention:
|
52 |
+
y = (y / (torch.einsum(
|
53 |
+
"bhld,bhld->bhl", q.float(), k.float().cumsum(dim=2)
|
54 |
+
) + eps)[..., None]).to(dtype=dtype)
|
55 |
+
else:
|
56 |
+
y = y.to(dtype=dtype)
|
57 |
+
k = k.float().cumsum(dim=2).to(dtype=dtype)
|
58 |
+
y = y / (torch.einsum("bhld,bhld->bhl", q, k) + eps)[..., None]
|
59 |
+
return y, None, None
|
60 |
+
|
61 |
+
|
62 |
+
def softmax_attention(q: torch.Tensor, k: torch.Tensor, v: Optional[torch.Tensor] = None,
|
63 |
+
causal: bool = True, fp32_attention: bool = True,
|
64 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
65 |
+
"""
|
66 |
+
Standard softmax attention; only compute outputs if v is not None
|
67 |
+
-> Assume q, k, v are shape (batch_size, num_heads, seq_len, head_dim)
|
68 |
+
"""
|
69 |
+
y = None
|
70 |
+
a = torch.einsum('bhmd,bhnd->bhmn', q, k) * (k.shape[-1] ** -0.5)
|
71 |
+
if causal: # Apply causal mask
|
72 |
+
m, n = a.shape[-2:]
|
73 |
+
causal_mask = torch.ones((m, n), device = a.device, dtype = torch.bool).triu(n - m + 1)
|
74 |
+
a = a.masked_fill(causal_mask, -torch.finfo(a.dtype).max)
|
75 |
+
if fp32_attention:
|
76 |
+
a = torch.softmax(a, dim=-1, dtype=torch.float32).to(q.dtype)
|
77 |
+
else:
|
78 |
+
a = torch.softmax(a, dim=-1)
|
79 |
+
if v is not None:
|
80 |
+
y = torch.einsum('bhmn,bhnd->bhmd', a, v)
|
81 |
+
return y, a, None
|
82 |
+
|
83 |
+
|
84 |
+
def quadratic_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor = None,
|
85 |
+
causal: bool = True, fp32_attention: bool = False, eps: float = 1e-12,
|
86 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
87 |
+
"""
|
88 |
+
Compute attention with feature maps by instantiating L x L matrix of attention weights
|
89 |
+
-> Use for attention distillation
|
90 |
+
-> Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); v is shape (b, h, l, head_dim)
|
91 |
+
"""
|
92 |
+
y = None
|
93 |
+
dtype = q.dtype
|
94 |
+
if fp32_attention:
|
95 |
+
q, k = q.float(), k.float()
|
96 |
+
a = torch.einsum('bhmd,bhnd->bhmn', q, k) # note we don't scale, tho we could
|
97 |
+
if causal: # Apply causal mask
|
98 |
+
m, n = a.shape[-2:]
|
99 |
+
causal_mask = torch.ones((m, n), device = a.device, dtype = torch.bool).triu(n - m + 1)
|
100 |
+
a = a.masked_fill(causal_mask, 0)
|
101 |
+
# Normalize to compute attention
|
102 |
+
a = a / (a.sum(dim=-1, keepdim=True) + eps)
|
103 |
+
a = a.to(dtype=dtype) if fp32_attention else a
|
104 |
+
if torch.isnan(a).sum() > 0:
|
105 |
+
breakpoint()
|
106 |
+
if v is not None:
|
107 |
+
y = torch.einsum('bhmn,bhnd->bhmd', a, v)
|
108 |
+
return y, a, None
|
109 |
+
|
110 |
+
|
111 |
+
# ---------------------
|
112 |
+
# Attention layer class
|
113 |
+
# ---------------------
|
114 |
+
|
115 |
+
class LolcatsLinearAttention(nn.Module):
|
116 |
+
"""
|
117 |
+
LoLCATs attention implementation initialized from a
|
118 |
+
`LlamaAttention` or `MistralAttention` object (base_attn)
|
119 |
+
|
120 |
+
Most of the arguments are directly tied to argparse args
|
121 |
+
- For now we don't support padding.
|
122 |
+
"""
|
123 |
+
def __init__(self,
|
124 |
+
base_attn: nn.Module, # like LlamaAttention
|
125 |
+
feature_map: str,
|
126 |
+
feature_map_kwargs: dict,
|
127 |
+
layer_idx: Optional[int] = None,
|
128 |
+
max_layer_idx: Optional[int] = None,
|
129 |
+
learned_kernel: Optional[str] = None,
|
130 |
+
learned_kernel_kwargs: Optional[dict] = None,
|
131 |
+
tie_qk_kernels: Optional[bool] = False,
|
132 |
+
rotary_config: Optional[dict] = None,
|
133 |
+
train_attention: Optional[bool] = False,
|
134 |
+
remove_base_attn: Optional[bool] = True,
|
135 |
+
attention_type: Optional[str] = 'lolcats_llama',
|
136 |
+
mask_value: int = 0,
|
137 |
+
eps: float = 1e-12,
|
138 |
+
fp32_attention: bool = False,
|
139 |
+
track_state_grads: bool = False,
|
140 |
+
rank: Optional[int] = 0,
|
141 |
+
**kwargs: any) -> None:
|
142 |
+
super().__init__()
|
143 |
+
self.base_config = getattr(base_attn, 'config', None)
|
144 |
+
if self.base_config is not None:
|
145 |
+
self.base_config = self.base_config.to_dict()
|
146 |
+
self.attention_type = attention_type
|
147 |
+
self.mask_value = mask_value
|
148 |
+
self.eps = eps
|
149 |
+
self.layer_idx = (layer_idx if layer_idx is not None else base_attn.layer_idx)
|
150 |
+
self.max_layer_idx = max_layer_idx
|
151 |
+
self.tie_qk_kernels = tie_qk_kernels
|
152 |
+
self.train_attention = train_attention
|
153 |
+
self.base_inference = False
|
154 |
+
self.fp32_attention = fp32_attention
|
155 |
+
self.track_state_grads = track_state_grads
|
156 |
+
if rank == 0: # multi-gpu
|
157 |
+
if fp32_attention and layer_idx == 0:
|
158 |
+
print(f'-> fp32_attention is {fp32_attention}')
|
159 |
+
if layer_idx == 0 and feature_map_kwargs is not None:
|
160 |
+
for k, v in feature_map_kwargs.items():
|
161 |
+
print(f'-> {k}: {v}')
|
162 |
+
if layer_idx == 0 and learned_kernel_kwargs is not None:
|
163 |
+
for k, v in learned_kernel_kwargs.items():
|
164 |
+
print(f'-> {k}: {v}')
|
165 |
+
|
166 |
+
self.remove_base_attn = remove_base_attn
|
167 |
+
|
168 |
+
# Rotary embeddings (patch for Llama 3.1, Transformer v4.43.0)
|
169 |
+
self.rotary_config = rotary_config
|
170 |
+
if isinstance(self.rotary_config, DictConfig): # ensure dict
|
171 |
+
self.rotary_config = OmegaConf.to_container(self.rotary_config)
|
172 |
+
|
173 |
+
self.rotary_emb = None
|
174 |
+
if self.base_config is not None and self.rotary_config is None:
|
175 |
+
self.rotary_emb = base_attn.rotary_emb
|
176 |
+
|
177 |
+
self.init_weights_(base_attn, remove_base_attn)
|
178 |
+
self.init_feature_map_(feature_map, feature_map_kwargs,
|
179 |
+
learned_kernel, learned_kernel_kwargs)
|
180 |
+
|
181 |
+
def init_feature_map_(self,
|
182 |
+
feature_map: str,
|
183 |
+
feature_map_kwargs: dict,
|
184 |
+
learned_kernel: str = None,
|
185 |
+
learned_kernel_kwargs: dict = None):
|
186 |
+
"""
|
187 |
+
Initialize MLP-based feature map
|
188 |
+
"""
|
189 |
+
self.fmap_gqa = False # Turn True if specified below
|
190 |
+
if learned_kernel is not None:
|
191 |
+
# Ensure dict
|
192 |
+
learned_kernel_kwargs = {k: v for k, v in learned_kernel_kwargs.items()}
|
193 |
+
learned_kernel_kwargs['num_heads'] = self.num_heads
|
194 |
+
learned_kernel_kwargs['head_dim'] = self.head_dim
|
195 |
+
learned_kernel_kwargs['dtype'] = self.q_proj.weight.dtype
|
196 |
+
learned_kernel_kwargs['device'] = self.q_proj.weight.device
|
197 |
+
# Create MLP
|
198 |
+
mlp_learned_kernel = init_learned_kernel(learned_kernel, **learned_kernel_kwargs)
|
199 |
+
# Add "activation"; see src.models.feature_map.py
|
200 |
+
self.feature_map_q = init_feature_map(name=feature_map,
|
201 |
+
mlp=mlp_learned_kernel,
|
202 |
+
**feature_map_kwargs)
|
203 |
+
if self.tie_qk_kernels: # tie mlp weights for query and key feature maps
|
204 |
+
self.feature_map_k = self.feature_map_q
|
205 |
+
else:
|
206 |
+
self.feature_map_k = copy.deepcopy(self.feature_map_q)
|
207 |
+
|
208 |
+
def init_weights_(self, base_attn: nn.Module, remove_base_attn: bool = True):
|
209 |
+
"""
|
210 |
+
Initialize module layers, weights, positional dependencies, etc.
|
211 |
+
from original softmax attention layer (base_attn)
|
212 |
+
"""
|
213 |
+
# Make other attributes accessible
|
214 |
+
self.attention_dropout = 0 # We don't use dropout
|
215 |
+
self.hidden_size = base_attn.hidden_size
|
216 |
+
self.num_heads = base_attn.num_heads
|
217 |
+
self.head_dim = base_attn.head_dim
|
218 |
+
self.num_key_value_heads = base_attn.num_key_value_heads
|
219 |
+
self.num_key_value_groups = base_attn.num_key_value_groups
|
220 |
+
|
221 |
+
self.q_shape = [self.num_heads, self.head_dim]
|
222 |
+
self.k_shape = [self.num_key_value_heads, self.head_dim]
|
223 |
+
self.v_shape = [self.num_key_value_heads, self.head_dim]
|
224 |
+
device = base_attn.q_proj.weight.device
|
225 |
+
# Rotary embeddings
|
226 |
+
if self.rotary_emb is None:
|
227 |
+
self.max_position_embeddings = base_attn.max_position_embeddings
|
228 |
+
scaling_factor = getattr(base_attn.rotary_emb, 'scaling_factor', 1.)
|
229 |
+
if self.rotary_config is None:
|
230 |
+
self.rotary_emb = get_rotary_embeddings(
|
231 |
+
rope_scaling_type=None,
|
232 |
+
head_dim=self.head_dim,
|
233 |
+
max_position_embeddings=self.max_position_embeddings, # base_attn.rotary_emb.max_position_embeddings,
|
234 |
+
rope_theta=base_attn.rotary_emb.base,
|
235 |
+
rope_scaling_factor=scaling_factor, # base_attn.rotary_emb.scaling_factor,
|
236 |
+
device=device,
|
237 |
+
)
|
238 |
+
else:
|
239 |
+
if 'device' not in self.rotary_config:
|
240 |
+
self.rotary_config['device'] = device
|
241 |
+
self.rotary_emb = get_rotary_embeddings(**self.rotary_config)
|
242 |
+
|
243 |
+
# Copy original model projection layers
|
244 |
+
self.q_proj = base_attn.q_proj
|
245 |
+
self.k_proj = base_attn.k_proj
|
246 |
+
self.v_proj = base_attn.v_proj
|
247 |
+
self.o_proj = base_attn.o_proj
|
248 |
+
try: # If wanting to use FA2 for ground-truth inference
|
249 |
+
self._flash_attn_uses_top_left_mask = base_attn._flash_attn_uses_top_left_mask
|
250 |
+
except AttributeError:
|
251 |
+
pass
|
252 |
+
|
253 |
+
if self.remove_base_attn or remove_base_attn:
|
254 |
+
del base_attn # We don't need to keep these around
|
255 |
+
else:
|
256 |
+
self.base_attn = base_attn # For some training runs helpful to just call
|
257 |
+
|
258 |
+
def process_qkv(self,
|
259 |
+
hidden_states: torch.Tensor,
|
260 |
+
attention_mask: Optional[torch.Tensor] = None,
|
261 |
+
position_ids: Optional[torch.LongTensor] = None,
|
262 |
+
past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None,): # "legacy" cache approach
|
263 |
+
"""
|
264 |
+
Compute queries, keys, and values
|
265 |
+
"""
|
266 |
+
b, l, _ = hidden_states.size()
|
267 |
+
q = self.q_proj(hidden_states)
|
268 |
+
k = self.k_proj(hidden_states)
|
269 |
+
v = self.v_proj(hidden_states)
|
270 |
+
kv_seq_len = k.shape[-2]
|
271 |
+
|
272 |
+
# Shape is (batch_size, seq_len, num_heads, head_dim)
|
273 |
+
q = q.view(b, l, *self.q_shape).transpose(1, 2)
|
274 |
+
k = k.view(b, l, *self.k_shape).transpose(1, 2)
|
275 |
+
v = v.view(b, l, *self.v_shape).transpose(1, 2)
|
276 |
+
|
277 |
+
if past_key_value is not None: # and k.shape[2] > q.shape[2]: # e.g., when generating
|
278 |
+
past_key_value.window_size = getattr(self, 'decode_window_size', None) # self.decode_window_size
|
279 |
+
if isinstance(past_key_value, Cache): # In Transformers v4.36+ this is a DynamicCache object
|
280 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
281 |
+
else:
|
282 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
283 |
+
|
284 |
+
# Apply rotary embeddings and repeat for GQA
|
285 |
+
if position_ids is not None and kv_seq_len <= position_ids[0, -1]:
|
286 |
+
kv_seq_len = position_ids[0, -1] + 1 # hack for adjusting position ids
|
287 |
+
try: # As in Transformers v4.36
|
288 |
+
cos, sin = self.rotary_emb(k, seq_len=kv_seq_len)
|
289 |
+
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
|
290 |
+
except TypeError: # As in Transformers v4.39+
|
291 |
+
cos, sin = self.rotary_emb(v, position_ids)
|
292 |
+
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
293 |
+
|
294 |
+
k = repeat_kv(k, self.num_key_value_groups)
|
295 |
+
v = repeat_kv(v, self.num_key_value_groups)
|
296 |
+
return q, k, v, kv_seq_len
|
297 |
+
|
298 |
+
def forward(self,
|
299 |
+
hidden_states: torch.Tensor,
|
300 |
+
attention_mask: Optional[torch.Tensor] = None,
|
301 |
+
position_ids: Optional[torch.LongTensor] = None,
|
302 |
+
past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None, # "legacy" cache approach
|
303 |
+
output_attentions: bool = False,
|
304 |
+
use_cache: bool = False,
|
305 |
+
**kwargs,
|
306 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
307 |
+
"""
|
308 |
+
Forward pass modified from transformers.models.mistral.modeling_mistral (v4.36)
|
309 |
+
- Consistent with HuggingFace Transformers for easy use with their pretrained models
|
310 |
+
"""
|
311 |
+
b, l, _ = hidden_states.size()
|
312 |
+
q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask,
|
313 |
+
position_ids, past_key_value)
|
314 |
+
if self.base_inference:
|
315 |
+
with torch.no_grad():
|
316 |
+
# 1. Compute "ground-truth" attention output and weights
|
317 |
+
y_true, _, _ = softmax_attention(q, k, v, causal=True)
|
318 |
+
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
319 |
+
y_true = self.o_proj(y_true)
|
320 |
+
attn_weights = (None, None)
|
321 |
+
|
322 |
+
elif self.train_attention: # Distilling / learning attentions
|
323 |
+
# Note for now we assume no padding when distilling; attention masks only enforce causality
|
324 |
+
assert output_attentions is True, f'When training feature maps, output_attentions should be True but is {output_attentions}'
|
325 |
+
with torch.no_grad():
|
326 |
+
# 1. Compute "ground-truth" attention output and weights
|
327 |
+
_y_true, attn_true, _ = softmax_attention(q, k, v, causal=True)
|
328 |
+
y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
329 |
+
y_true = self.o_proj(y_true)
|
330 |
+
|
331 |
+
# 2. Compute "predicted" attention (just weights)
|
332 |
+
q, k = self.feature_map_q.q_map(q), self.feature_map_k.k_map(k)
|
333 |
+
y_pred, attn_pred, _ = quadratic_attention(q, k, v, causal=True)
|
334 |
+
attn_weights = ((attn_pred, attn_true), (y_pred, _y_true)) # Save both attention weights so we can supervise.
|
335 |
+
|
336 |
+
else: # Finetuning
|
337 |
+
q, k = self.feature_map_q(q), self.feature_map_k(k)
|
338 |
+
# Apply prefill mask
|
339 |
+
if attention_mask is not None and q.shape[2] > 1:
|
340 |
+
if len(attention_mask.shape) == 4:
|
341 |
+
lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][..., None] # b, 1, k_len, 1
|
342 |
+
else:
|
343 |
+
lin_attn_mask = attention_mask[:, None, :, None] # b, 1, k_len, 1
|
344 |
+
k = k.masked_fill(~lin_attn_mask, 0)
|
345 |
+
|
346 |
+
if past_key_value is not None: # Initialize states
|
347 |
+
if len(past_key_value.kv_states) == self.layer_idx:
|
348 |
+
b, h, _, f = k.shape
|
349 |
+
past_key_value.kv_states.append(
|
350 |
+
torch.zeros(b, h, f, self.head_dim, dtype=q.dtype, device=q.device)
|
351 |
+
)
|
352 |
+
past_key_value.k_states.append(
|
353 |
+
torch.zeros(b, h, 1, f, dtype=q.dtype, device=q.device)
|
354 |
+
)
|
355 |
+
# Generating
|
356 |
+
if q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None:
|
357 |
+
assert use_cache is True
|
358 |
+
kv_state, k_state = past_key_value.update(k, v, self.layer_idx,
|
359 |
+
accumulate_in_fp32=self.fp32_attention)
|
360 |
+
if self.fp32_attention:
|
361 |
+
q = q.float()
|
362 |
+
y_true = (torch.einsum('bhlf,bhfd->bhld', q, kv_state.float()) /
|
363 |
+
torch.einsum('bhlf,bhlf->bhl', q, k_state.float())[..., None]).to(dtype=k.dtype)
|
364 |
+
else:
|
365 |
+
y_true = (torch.einsum('bhlf,bhfd->bhld', q, kv_state) /
|
366 |
+
torch.einsum('bhlf,bhlf->bhl', q, k_state)[..., None])
|
367 |
+
else:
|
368 |
+
kv_state = past_key_value.kv_states[self.layer_idx]
|
369 |
+
k_state = past_key_value.k_states[self.layer_idx]
|
370 |
+
y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps) # Ordinarily the states are ignored
|
371 |
+
past_key_value.update(k.detach(), v.detach(), self.layer_idx,
|
372 |
+
accumulate_in_fp32=self.fp32_attention)
|
373 |
+
# doing some unnecessary recomputation here
|
374 |
+
else:
|
375 |
+
y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps)
|
376 |
+
|
377 |
+
# Concatenate heads and apply output projection
|
378 |
+
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
379 |
+
y_true = self.o_proj(y_true)
|
380 |
+
attn_weights = None
|
381 |
+
|
382 |
+
return y_true, attn_weights, past_key_value
|
383 |
+
|
384 |
+
|
385 |
+
class LinearAttentionState(Cache):
|
386 |
+
"""
|
387 |
+
Handle the KV and K states for linear attention
|
388 |
+
- Adopts HF Transformers `past_key_values` convention
|
389 |
+
- Inherits from `Cache` class
|
390 |
+
- Modified from transformers.cache_utils.DynamicCache (v4.36)
|
391 |
+
"""
|
392 |
+
def __init__(self) -> None:
|
393 |
+
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
394 |
+
self._seen_tokens_by_layer: List[int] = []
|
395 |
+
self.kv_states: List[torch.Tensor] = []
|
396 |
+
self.k_states: List[torch.Tensor] = []
|
397 |
+
|
398 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
399 |
+
"""
|
400 |
+
Returns the sequence length of the cached states. A layer index can be optionally passed.
|
401 |
+
"""
|
402 |
+
if len(self._seen_tokens_by_layer) <= layer_idx: # Initializing kv and k states
|
403 |
+
self._seen_tokens_by_layer.append(0)
|
404 |
+
return self._seen_tokens_by_layer[layer_idx]
|
405 |
+
|
406 |
+
def get_max_length(self) -> Optional[int]:
|
407 |
+
"""
|
408 |
+
Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.
|
409 |
+
"""
|
410 |
+
return None
|
411 |
+
|
412 |
+
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
|
413 |
+
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
|
414 |
+
# Cache without size limit -> all cache is usable
|
415 |
+
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
|
416 |
+
# length, we will need to evict part of the cache (and thus not all cache is usable)
|
417 |
+
max_length = self.get_max_length()
|
418 |
+
previous_seq_length = self.get_seq_length(layer_idx)
|
419 |
+
if max_length is not None and previous_seq_length + new_seq_length > max_length:
|
420 |
+
return max_length - new_seq_length
|
421 |
+
return previous_seq_length
|
422 |
+
|
423 |
+
def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
|
424 |
+
layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None,
|
425 |
+
accumulate_in_fp32: bool = True, **kwargs: any,
|
426 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
427 |
+
|
428 |
+
with torch.no_grad ():
|
429 |
+
if layer_idx == 0:
|
430 |
+
self._seen_tokens += key_states.shape[-2]
|
431 |
+
dtype = key_states.dtype
|
432 |
+
if accumulate_in_fp32:
|
433 |
+
key_states, value_states = key_states.float(), value_states.float()
|
434 |
+
|
435 |
+
kv_state = torch.einsum('bhlf,bhld->bhfd', key_states, value_states).detach()
|
436 |
+
k_state = key_states.sum(dim=-2, keepdim=True).detach() # b, h, 1, f; note the 1
|
437 |
+
# Update the cache
|
438 |
+
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
439 |
+
print('if len(self.k_states) <= layer_idx: # Initializing kv and k states')
|
440 |
+
self.kv_states.append(kv_state.to(dtype))
|
441 |
+
self.k_states.append(k_state.to(dtype))
|
442 |
+
else:
|
443 |
+
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype)
|
444 |
+
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype)
|
445 |
+
self.kv_states[layer_idx] = kv_state
|
446 |
+
self.k_states[layer_idx] = k_state
|
447 |
+
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
448 |
+
return self.kv_states[layer_idx], self.k_states[layer_idx]
|
449 |
+
|
450 |
+
def to_legacy_cache(self):
|
451 |
+
"""Hack, but just return self"""
|
452 |
+
return self
|
453 |
+
|
454 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
455 |
+
"""
|
456 |
+
Reorders the cache for beam search, given the selected beam indices.
|
457 |
+
-> Copied from transformers/src/transformers/cache_utils.py
|
458 |
+
"""
|
459 |
+
raise NotImplementedError('Reordering cache not implemented for LinearAttentionState')
|
src/model/linear_attention/linear_window_attention_sw.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Subquadratic attention combining sliding window and linear attentions
|
3 |
+
- Using "standard" sliding windows
|
4 |
+
- Didactically computes outputs with n^2 attention weights for now
|
5 |
+
- Copied + adapted from linear_window_attention_tk.py for single-file reference
|
6 |
+
|
7 |
+
For each layer:
|
8 |
+
- We first compute (softmax) attention over sliding windows
|
9 |
+
- We then compute standard linear attention to "fill in" the earlier parts
|
10 |
+
- We combine to model the entire sequence
|
11 |
+
"""
|
12 |
+
from typing import List, Tuple, Optional, Callable
|
13 |
+
import math
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
from transformers.cache_utils import Cache
|
19 |
+
|
20 |
+
from .linear_attention import (
|
21 |
+
LolcatsLinearAttention, LinearAttentionState,
|
22 |
+
softmax_attention
|
23 |
+
)
|
24 |
+
|
25 |
+
# ----------------------
|
26 |
+
# Sliding window helpers
|
27 |
+
# ----------------------
|
28 |
+
def get_masks(window_size: int, q_len: int, k_len: int,
|
29 |
+
device: torch.device) -> tuple[torch.Tensor]:
|
30 |
+
"""
|
31 |
+
Return masks for softmax and linear attention terms
|
32 |
+
-> 1 is include, 0 is ignore
|
33 |
+
"""
|
34 |
+
kwargs = {'device': device, 'dtype': int}
|
35 |
+
causal_mask = torch.ones((q_len, k_len), **kwargs).tril(k_len - q_len)
|
36 |
+
linear_mask = torch.ones((q_len, k_len), **kwargs).tril(k_len - q_len - window_size)
|
37 |
+
window_mask = causal_mask - linear_mask
|
38 |
+
# Return softmax mask (window), linear attention mask
|
39 |
+
# -> shapes broadcast over (b, h, q_len, k_len)
|
40 |
+
return window_mask[None, None, ...], linear_mask[None, None, ...]
|
41 |
+
|
42 |
+
|
43 |
+
def hybrid_attention_quadratic(q: torch.Tensor, k: torch.Tensor,
|
44 |
+
f_q: torch.Tensor, f_k: torch.Tensor,
|
45 |
+
v: torch.Tensor,
|
46 |
+
window_factor: torch.Tensor,
|
47 |
+
linear_factor: torch.Tensor,
|
48 |
+
window_size: int,
|
49 |
+
kv_state: torch.Tensor = None,
|
50 |
+
k_state: torch.Tensor = None,
|
51 |
+
eps: float = 1e-12,
|
52 |
+
mask_value: float=-1e8):
|
53 |
+
"""
|
54 |
+
Hybrid attention combining sliding window and linear attentions
|
55 |
+
"""
|
56 |
+
|
57 |
+
mask_window, mask_linear = get_masks(window_size, q.shape[-2], k.shape[-2], q.device)
|
58 |
+
|
59 |
+
# 1. Sliding window (softmax attention)
|
60 |
+
a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5)
|
61 |
+
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
|
62 |
+
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
|
63 |
+
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
64 |
+
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
65 |
+
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
66 |
+
|
67 |
+
# 2. Under window (linear attention)
|
68 |
+
a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float())
|
69 |
+
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
|
70 |
+
sum_ln = a_ln.sum(dim=-1, keepdim=True)
|
71 |
+
|
72 |
+
# 3. Combine
|
73 |
+
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
|
74 |
+
# Allow outputs to also depend on prior kv_state and k_state
|
75 |
+
y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float())
|
76 |
+
if kv_state is not None: # Combine with prior kv_state and k_state
|
77 |
+
y += linear_factor * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float())
|
78 |
+
sum_ln += linear_factor * torch.einsum(
|
79 |
+
'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None]
|
80 |
+
y = (y / (sum_sm + sum_ln)).to(q.dtype)
|
81 |
+
return y, a # attention weights only for the last chunk
|
82 |
+
|
83 |
+
|
84 |
+
# ---------------------
|
85 |
+
# Attention layer class
|
86 |
+
# ---------------------
|
87 |
+
class LolcatsSlidingWindowAttention(LolcatsLinearAttention):
|
88 |
+
"""
|
89 |
+
Lolcats attention combining sliding window and linear attention
|
90 |
+
"""
|
91 |
+
def __init__(self,
|
92 |
+
window_size: int = 64,
|
93 |
+
decode_window_size: int = None,
|
94 |
+
affine_attention_factors: bool = False,
|
95 |
+
init_window_factor: float = 0,
|
96 |
+
train_window_factor: bool = True,
|
97 |
+
state_grad_enabled: bool = False,
|
98 |
+
**kwargs):
|
99 |
+
self.window_size = window_size
|
100 |
+
self.decode_window_size = (
|
101 |
+
decode_window_size if decode_window_size is not None else window_size
|
102 |
+
)
|
103 |
+
self.window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1}
|
104 |
+
super().__init__(**kwargs)
|
105 |
+
self.attention_type = kwargs['attention_type'] # 'hedgehog_llama_window_sw'
|
106 |
+
# Determine how we compute attentions
|
107 |
+
self.quadratic_attention = hybrid_attention_quadratic
|
108 |
+
self.attention_type = kwargs['attention_type'] # 'hedgehog_long_llama_window_sw'
|
109 |
+
# Learnable factor for combining attentions
|
110 |
+
self.affine_attention_factors = affine_attention_factors
|
111 |
+
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
|
112 |
+
if train_window_factor:
|
113 |
+
self.window_factors = nn.Parameter(
|
114 |
+
init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype))
|
115 |
+
else:
|
116 |
+
self.register_buffer(
|
117 |
+
"window_factors", init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
|
118 |
+
)
|
119 |
+
# Whether we use original flash attention 2 inference (use during attention transfer)
|
120 |
+
self.base_inference = False
|
121 |
+
self.state_grad_enabled = state_grad_enabled
|
122 |
+
|
123 |
+
def forward(self,
|
124 |
+
hidden_states: torch.Tensor,
|
125 |
+
attention_mask: Optional[torch.Tensor] = None,
|
126 |
+
position_ids: Optional[torch.LongTensor] = None,
|
127 |
+
past_key_value: Optional[Cache] = None,
|
128 |
+
output_attentions: bool = False,
|
129 |
+
use_cache: bool = False,
|
130 |
+
**kwargs,
|
131 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
132 |
+
"""
|
133 |
+
Forward pass with the option to compute attention weights multiple ways
|
134 |
+
if self.train_attention is True
|
135 |
+
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
136 |
+
"""
|
137 |
+
b, l, _ = hidden_states.size()
|
138 |
+
q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask,
|
139 |
+
position_ids, past_key_value)
|
140 |
+
f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) # Have to do after repeat for grouped-query attn if we use same fmap
|
141 |
+
|
142 |
+
if self.train_attention:
|
143 |
+
# 1. Compute "ground-truth" attention output and weights
|
144 |
+
with torch.no_grad():
|
145 |
+
_y_true, a_true = softmax_attention(q, k, v)[:2]
|
146 |
+
y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
147 |
+
y_true = self.o_proj(y_true)
|
148 |
+
|
149 |
+
# 2. Compute "predicted" attention outputs
|
150 |
+
# compute attn weights under sliding window
|
151 |
+
window_factors = F.sigmoid(self.window_factors)
|
152 |
+
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
153 |
+
y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v,
|
154 |
+
window_factors, linear_factors,
|
155 |
+
window_size=self.window_size)
|
156 |
+
attn_weights = ((a_pred, a_true), (y_pred, _y_true))
|
157 |
+
else:
|
158 |
+
attn_weights = None
|
159 |
+
# attention_mask = None # For now this is always True
|
160 |
+
if past_key_value is None: # Regular training
|
161 |
+
window_factors = F.sigmoid(self.window_factors)
|
162 |
+
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
163 |
+
y_true, a_pred = self.quadratic_attention(q, k, f_q, f_k, v,
|
164 |
+
window_factors, linear_factors,
|
165 |
+
window_size=self.window_size)
|
166 |
+
attn_weights = a_pred
|
167 |
+
else:
|
168 |
+
past_key_value.window_size = self.decode_window_size
|
169 |
+
if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating
|
170 |
+
assert use_cache is True
|
171 |
+
_kv = past_key_value.update_for_decoding(k, v, self.layer_idx,
|
172 |
+
self.feature_map_k,
|
173 |
+
dtype=q.dtype)
|
174 |
+
k_cache, v_cache, f_kv_state, f_k_state = _kv
|
175 |
+
|
176 |
+
# Sliding window + linear attention decode
|
177 |
+
window_factors = F.sigmoid(self.window_factors)
|
178 |
+
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
179 |
+
|
180 |
+
# Softmax attention terms
|
181 |
+
a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5)
|
182 |
+
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
183 |
+
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
184 |
+
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
185 |
+
|
186 |
+
# Combine with linear attention terms
|
187 |
+
y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float())
|
188 |
+
+ linear_factors * torch.einsum('bhlf,bhfd->bhld', f_q.float(), f_kv_state.float()))
|
189 |
+
sum_ln = linear_factors * torch.einsum(
|
190 |
+
'bhlf,bhnf->bhl', f_q.float(), f_k_state.float())[..., None]
|
191 |
+
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
192 |
+
|
193 |
+
else: # Stateful training
|
194 |
+
try:
|
195 |
+
kv_state = past_key_value.kv_states[self.layer_idx]
|
196 |
+
k_state = past_key_value.k_states[self.layer_idx]
|
197 |
+
except IndexError:
|
198 |
+
kv_state, k_state = None, None
|
199 |
+
window_factors = F.sigmoid(self.window_factors)
|
200 |
+
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
201 |
+
y_true, _ = self.quadratic_attention(q, k, f_q, f_k, v,
|
202 |
+
window_factors, linear_factors,
|
203 |
+
window_size=self.window_size,
|
204 |
+
kv_state=kv_state,
|
205 |
+
k_state=k_state)
|
206 |
+
# Save and update KV cache and states
|
207 |
+
# past_key_value.update(k, v.detach(), self.layer_idx,
|
208 |
+
# fmap_key_states=f_k.detach(),
|
209 |
+
# accumulate_in_fp32=True)
|
210 |
+
past_key_value.update(k, v, self.layer_idx,
|
211 |
+
fmap_key_states=f_k,
|
212 |
+
accumulate_in_fp32=True)
|
213 |
+
# Concatenate heads and apply output projection
|
214 |
+
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
215 |
+
y_true = self.o_proj(y_true)
|
216 |
+
return y_true, attn_weights, past_key_value
|
217 |
+
|
218 |
+
|
219 |
+
class LinearAttentionSlidingWindowCache(LinearAttentionState):
|
220 |
+
"""
|
221 |
+
Class for `past_key_values`
|
222 |
+
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
|
223 |
+
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
|
224 |
+
"""
|
225 |
+
def __init__(self, window_size: int = 64) -> None:
|
226 |
+
super().__init__()
|
227 |
+
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
228 |
+
self._seen_tokens_by_layer: List[int] = []
|
229 |
+
self.kv_states: List[torch.Tensor] = []
|
230 |
+
self.k_states: List[torch.Tensor] = []
|
231 |
+
|
232 |
+
# Account for sliding windows
|
233 |
+
self.decode_kv_states: List[torch.Tensor] = []
|
234 |
+
self.decode_k_states: List[torch.Tensor] = []
|
235 |
+
self.k_cache: List[torch.Tensor] = []
|
236 |
+
self.v_cache: List[torch.Tensor] = []
|
237 |
+
self.window_size = window_size
|
238 |
+
|
239 |
+
def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
|
240 |
+
layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None,
|
241 |
+
accumulate_in_fp32: bool = False,
|
242 |
+
fmap_key_states: torch.Tensor = None, # should not be None
|
243 |
+
grad_enabled: bool = False,
|
244 |
+
**kwargs: any,
|
245 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
246 |
+
"""
|
247 |
+
Update KV, K states; and KV cache during training
|
248 |
+
- For decoding, use `self.decode_kv_states` to keep track of KV states
|
249 |
+
up to sliding window terms
|
250 |
+
- For (chunked) training, use `self.kv_states` to keep track of KV states
|
251 |
+
up to end of sequence
|
252 |
+
- Likewise for `self.decode_k_states` and `self.k_states`
|
253 |
+
"""
|
254 |
+
with torch.set_grad_enabled(grad_enabled):
|
255 |
+
if layer_idx == 0:
|
256 |
+
self._seen_tokens += key_states.shape[-2]
|
257 |
+
|
258 |
+
dtype = key_states.dtype
|
259 |
+
if accumulate_in_fp32:
|
260 |
+
# key_states = key_states.float()
|
261 |
+
fmap_key_states = fmap_key_states.float()
|
262 |
+
value_states = value_states.float()
|
263 |
+
|
264 |
+
# Decoding KV state (KV terms up to last window_size)
|
265 |
+
decode_kv_state = torch.einsum(
|
266 |
+
'bhlf,bhld->bhfd', fmap_key_states[:, :, :-self.window_size], value_states[:, :, :-self.window_size]
|
267 |
+
)
|
268 |
+
# KV state
|
269 |
+
kv_state = decode_kv_state + torch.einsum(
|
270 |
+
'bhlf,bhld->bhfd', fmap_key_states[:, :, -self.window_size:], value_states[:, :, -self.window_size:]
|
271 |
+
)
|
272 |
+
# shape is b, h, 1, f; note the 1
|
273 |
+
decode_k_state = fmap_key_states[:, :, :-self.window_size].sum(dim=-2, keepdim=True)
|
274 |
+
k_state = (decode_k_state + fmap_key_states[:, :, -self.window_size:].sum(dim=-2, keepdim=True))
|
275 |
+
|
276 |
+
# Update the cache
|
277 |
+
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
278 |
+
self.kv_states.append(kv_state.to(dtype))
|
279 |
+
self.k_states.append(k_state.to(dtype))
|
280 |
+
|
281 |
+
self.decode_kv_states.append(decode_kv_state.to(dtype))
|
282 |
+
self.decode_k_states.append(decode_k_state.to(dtype))
|
283 |
+
|
284 |
+
self.k_cache.append(key_states[:, :, -self.window_size:, :])
|
285 |
+
self.v_cache.append(value_states[:, :, -self.window_size:, :].to(dtype))
|
286 |
+
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
|
287 |
+
else:
|
288 |
+
# Update kv and k states recurrently
|
289 |
+
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype)
|
290 |
+
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype)
|
291 |
+
self.kv_states[layer_idx] = kv_state
|
292 |
+
self.k_states[layer_idx] = k_state
|
293 |
+
|
294 |
+
decode_kv_state = (self.decode_kv_states[layer_idx].to(kv_state.dtype)
|
295 |
+
+ decode_kv_state).to(dtype)
|
296 |
+
decode_k_state = (self.decode_k_states[layer_idx].to(kv_state.dtype)
|
297 |
+
+ decode_k_state).to(dtype)
|
298 |
+
self.decode_kv_states[layer_idx] = decode_kv_state
|
299 |
+
self.decode_k_states[layer_idx] = decode_k_state
|
300 |
+
|
301 |
+
self.k_cache[layer_idx] = key_states[:, :, -self.window_size:, :]
|
302 |
+
self.v_cache[layer_idx] = value_states[:, :, -self.window_size:, :]
|
303 |
+
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
304 |
+
|
305 |
+
return self.kv_states[layer_idx], self.k_states[layer_idx]
|
306 |
+
|
307 |
+
def update_for_decoding(self, keys: torch.Tensor, values: torch.Tensor,
|
308 |
+
layer_idx: int, feature_map_k: Callable, dtype: torch.dtype):
|
309 |
+
"""
|
310 |
+
Update the decoding KV and K states, and KV cache, during decodeing
|
311 |
+
"""
|
312 |
+
with torch.no_grad():
|
313 |
+
k_cache = self.k_cache[layer_idx]
|
314 |
+
v_cache = self.v_cache[layer_idx]
|
315 |
+
|
316 |
+
if k_cache.shape[-2] < self.window_size: # build window-size cache
|
317 |
+
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
|
318 |
+
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
|
319 |
+
else:
|
320 |
+
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
|
321 |
+
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
|
322 |
+
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
|
323 |
+
# else:
|
324 |
+
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
|
325 |
+
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
|
326 |
+
k_state = feature_map_k(k_cache[:, :, :1, :])
|
327 |
+
v_state = v_cache[:, :, :1, :]
|
328 |
+
kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d
|
329 |
+
self.decode_kv_states[layer_idx] += kv_state
|
330 |
+
self.decode_k_states[layer_idx] += k_state
|
331 |
+
|
332 |
+
self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2)
|
333 |
+
self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], values], dim=-2)
|
334 |
+
|
335 |
+
if layer_idx == 0:
|
336 |
+
self._seen_tokens += keys.shape[-2]
|
337 |
+
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
|
338 |
+
return (self.k_cache[layer_idx], self.v_cache[layer_idx],
|
339 |
+
self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx])
|
src/model/linear_attention/linear_window_attention_sw_linear.py
ADDED
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Subquadratic attention combining sliding window and linear attentions
|
3 |
+
- Using "standard" sliding windows
|
4 |
+
- Didactically computes outputs with n^2 attention weights for now
|
5 |
+
- Copied + adapted from linear_window_attention_tk.py for single-file reference
|
6 |
+
|
7 |
+
For each layer:
|
8 |
+
- We first compute (softmax) attention over sliding windows
|
9 |
+
- We then compute standard linear attention to "fill in" the earlier parts
|
10 |
+
- We combine to model the entire sequence
|
11 |
+
"""
|
12 |
+
from typing import List, Tuple, Optional, Callable
|
13 |
+
import math
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
from transformers.cache_utils import Cache
|
19 |
+
try:
|
20 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
21 |
+
except ModuleNotFoundError:
|
22 |
+
_flash_attention_forward = None # Transformers v4.36
|
23 |
+
|
24 |
+
# Causal linear attention dot product CUDA kernel from fast-transformers
|
25 |
+
from csrc import causal_dot_product
|
26 |
+
|
27 |
+
from src.model.rotary import apply_rotary_pos_emb
|
28 |
+
from .linear_attention import (
|
29 |
+
LolcatsLinearAttention, LinearAttentionState,
|
30 |
+
softmax_attention
|
31 |
+
)
|
32 |
+
|
33 |
+
# ----------------------
|
34 |
+
# Sliding window helpers
|
35 |
+
# ----------------------
|
36 |
+
def get_masks(window_size: int, q_len: int, k_len: int,
|
37 |
+
device: torch.device) -> tuple[torch.Tensor]:
|
38 |
+
"""
|
39 |
+
Return masks for softmax and linear attention terms
|
40 |
+
-> 1 is include, 0 is ignore
|
41 |
+
"""
|
42 |
+
kwargs = {'device': device, 'dtype': int}
|
43 |
+
causal_mask = torch.ones((q_len, k_len), **kwargs).tril(max(k_len - q_len, 0))
|
44 |
+
linear_mask = torch.ones((q_len, k_len), **kwargs).tril(max(k_len - q_len, 0) - window_size)
|
45 |
+
window_mask = causal_mask - linear_mask
|
46 |
+
# Return softmax mask (window), linear attention mask
|
47 |
+
# -> shapes broadcast over (b, h, q_len, k_len)
|
48 |
+
return window_mask[None, None, ...], linear_mask[None, None, ...]
|
49 |
+
|
50 |
+
|
51 |
+
def hybrid_attention_quadratic(q: torch.Tensor, k: torch.Tensor,
|
52 |
+
f_q: torch.Tensor, f_k: torch.Tensor,
|
53 |
+
v: torch.Tensor,
|
54 |
+
window_factor: torch.Tensor,
|
55 |
+
linear_factor: torch.Tensor,
|
56 |
+
window_size: int,
|
57 |
+
kv_state: torch.Tensor = None,
|
58 |
+
k_state: torch.Tensor = None,
|
59 |
+
eps: float = 1e-12,
|
60 |
+
mask_value: float=-1e8):
|
61 |
+
"""
|
62 |
+
Hybrid attention combining sliding window and linear attentions
|
63 |
+
"""
|
64 |
+
|
65 |
+
mask_window, mask_linear = get_masks(window_size, q.shape[-2], k.shape[-2], q.device)
|
66 |
+
|
67 |
+
# 1. Sliding window (softmax attention)
|
68 |
+
a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5)
|
69 |
+
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
|
70 |
+
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
|
71 |
+
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
72 |
+
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
73 |
+
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
74 |
+
|
75 |
+
# 2. Under window (linear attention)
|
76 |
+
a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float())
|
77 |
+
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
|
78 |
+
sum_ln = a_ln.sum(dim=-1, keepdim=True)
|
79 |
+
|
80 |
+
# 3. Combine
|
81 |
+
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
|
82 |
+
# Allow outputs to also depend on prior kv_state and k_state
|
83 |
+
y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float())
|
84 |
+
if kv_state is not None: # Combine with prior kv_state and k_state
|
85 |
+
y += linear_factor * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float())
|
86 |
+
sum_ln += linear_factor * torch.einsum(
|
87 |
+
'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None]
|
88 |
+
y = (y / (sum_sm + sum_ln)).to(q.dtype)
|
89 |
+
return y, a # attention weights only for the last chunk
|
90 |
+
|
91 |
+
|
92 |
+
# ------------------------------
|
93 |
+
# Hybrid window attention linear
|
94 |
+
# ------------------------------
|
95 |
+
def under_window_linear_attention(f_q: torch.Tensor, f_k: torch.Tensor, v: torch.Tensor,
|
96 |
+
window_size: int, linear_factor: float, eps: float=1e-12):
|
97 |
+
"""Compute hybrid window attention dot product with linear complexity in q_len"""
|
98 |
+
dtype = f_q.dtype
|
99 |
+
w = window_size
|
100 |
+
f_k = F.pad(f_k, (0, 0, w, 0), value=0)[:, :, :-w, :]
|
101 |
+
v = F.pad(v, (0, 0, w, 0), value=0)[:, :, :-w, :]
|
102 |
+
qkv = linear_factor * causal_dot_product(f_q.contiguous().to(dtype=torch.float32),
|
103 |
+
f_k.contiguous().to(dtype=torch.float32),
|
104 |
+
v.contiguous().to(dtype=torch.float32)).to(dtype=dtype)
|
105 |
+
sum_f_k = f_k.float().cumsum(dim=2).to(dtype=dtype)
|
106 |
+
sum_qk = linear_factor * torch.einsum("bhld,bhld->bhl", f_q, sum_f_k)[..., None]
|
107 |
+
sum_qk[sum_qk == 0] += eps
|
108 |
+
return qkv, sum_qk
|
109 |
+
|
110 |
+
|
111 |
+
def sliding_window_softmax_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
112 |
+
window_size: int, window_factor: float, mask_value: float=-1e8):
|
113 |
+
"""
|
114 |
+
Compute sliding window softmax attention without materializing
|
115 |
+
O(seq_len^2) attention weights
|
116 |
+
"""
|
117 |
+
d = q.shape[-1]
|
118 |
+
# Compute windows for keys
|
119 |
+
window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1}
|
120 |
+
k = F.pad(k, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
|
121 |
+
v = F.pad(v, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
|
122 |
+
|
123 |
+
# Compute windowed_softmax(qk); causal in its construction
|
124 |
+
a_sm = torch.einsum('bhld,bhldw->bhlw', q, k) * (d ** -0.5)
|
125 |
+
a_sm[a_sm == 0] = -torch.finfo(q.dtype).max # heuristic for zeroing out padding above
|
126 |
+
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
127 |
+
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
128 |
+
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
129 |
+
return torch.einsum('bhlw,bhldw->bhld', a_sm, v), sum_sm
|
130 |
+
# return torch.einsum('bhlw,bhldw->bhld', torch.softmax(qk, dim=-1), v)
|
131 |
+
|
132 |
+
|
133 |
+
def hybrid_attention_linear(q: torch.Tensor, k: torch.Tensor,
|
134 |
+
f_q: torch.Tensor, f_k: torch.Tensor,
|
135 |
+
v: torch.Tensor,
|
136 |
+
window_factor: torch.Tensor = None,
|
137 |
+
linear_factor: torch.Tensor = None,
|
138 |
+
window_size: int = 64,
|
139 |
+
kv_state: torch.Tensor = None,
|
140 |
+
k_state: torch.Tensor = None,
|
141 |
+
eps: float = 1e-12,
|
142 |
+
mask_value: float=-1e8):
|
143 |
+
"""
|
144 |
+
Alternative hybrid attention combining sliding window and linear attentions
|
145 |
+
-> Uses O(n) memory if n is sequence length by padding and unfolding windows
|
146 |
+
"""
|
147 |
+
window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1}
|
148 |
+
# 1. Sliding window (softmax attention)
|
149 |
+
with torch.no_grad():
|
150 |
+
qkv_sm, sum_qk_sm = sliding_window_softmax_attention(q, k, v, window_size, window_factor, mask_value)
|
151 |
+
|
152 |
+
# 2. Under window (linear attention)
|
153 |
+
qkv_ln, sum_qk_ln = under_window_linear_attention(f_q, f_k, v, window_size, linear_factor, eps)
|
154 |
+
|
155 |
+
# 3. Combine
|
156 |
+
y = (qkv_sm + qkv_ln) / (sum_qk_sm + sum_qk_ln)
|
157 |
+
return y, None
|
158 |
+
|
159 |
+
|
160 |
+
# ---------------------
|
161 |
+
# Attention layer class
|
162 |
+
# ---------------------
|
163 |
+
class LolcatsLinearSlidingWindowAttention(LolcatsLinearAttention):
|
164 |
+
"""
|
165 |
+
Lolcats attention combining sliding window and linear attention
|
166 |
+
"""
|
167 |
+
def __init__(self,
|
168 |
+
window_size: int = 64,
|
169 |
+
decode_window_size: int = None,
|
170 |
+
affine_attention_factors: bool = False,
|
171 |
+
init_window_factor: float = 0,
|
172 |
+
train_window_factor: bool = True,
|
173 |
+
state_grad_enabled: bool = False,
|
174 |
+
**kwargs):
|
175 |
+
self.window_size = window_size
|
176 |
+
self.decode_window_size = (
|
177 |
+
decode_window_size if decode_window_size is not None else window_size
|
178 |
+
)
|
179 |
+
self.window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1}
|
180 |
+
super().__init__(**kwargs)
|
181 |
+
# Determine how we compute attentions
|
182 |
+
self.linear_attention = hybrid_attention_linear
|
183 |
+
self.attention_type = 'lolcats_llama_window_sw'
|
184 |
+
# Learnable factor for combining attentions
|
185 |
+
self.affine_attention_factors = affine_attention_factors
|
186 |
+
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
|
187 |
+
if train_window_factor:
|
188 |
+
self.window_factors = nn.Parameter(
|
189 |
+
init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype))
|
190 |
+
else:
|
191 |
+
self.register_buffer(
|
192 |
+
"window_factors", init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
|
193 |
+
)
|
194 |
+
# Whether we use original flash attention 2 inference (use during attention transfer)
|
195 |
+
self.base_inference = False
|
196 |
+
self.state_grad_enabled = state_grad_enabled
|
197 |
+
|
198 |
+
def forward(self,
|
199 |
+
hidden_states: torch.Tensor,
|
200 |
+
attention_mask: Optional[torch.Tensor] = None,
|
201 |
+
position_ids: Optional[torch.LongTensor] = None,
|
202 |
+
past_key_value: Optional[Cache] = None,
|
203 |
+
output_attentions: bool = False,
|
204 |
+
use_cache: bool = False,
|
205 |
+
**kwargs,
|
206 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
207 |
+
"""
|
208 |
+
Forward pass with the option to compute attention weights multiple ways
|
209 |
+
if self.train_attention is True
|
210 |
+
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
211 |
+
"""
|
212 |
+
b, l, _ = hidden_states.size()
|
213 |
+
|
214 |
+
if self.train_attention and self.base_inference:
|
215 |
+
with torch.no_grad():
|
216 |
+
_y_true = flash_attention_2(self, # self.base_attn,
|
217 |
+
hidden_states=hidden_states,
|
218 |
+
attention_mask=None,
|
219 |
+
position_ids=position_ids,
|
220 |
+
past_key_value=None,
|
221 |
+
output_attentions=False,
|
222 |
+
use_cache=False)[0]
|
223 |
+
# _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
|
224 |
+
y_true = _y_true.reshape(b, l, -1).contiguous()
|
225 |
+
y_true = self.o_proj(y_true)
|
226 |
+
# layer_io = (hidden_states, _y_true) # hack
|
227 |
+
layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack
|
228 |
+
return y_true, layer_io, None
|
229 |
+
|
230 |
+
else:
|
231 |
+
q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask,
|
232 |
+
position_ids, past_key_value)
|
233 |
+
f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) # Have to do after repeat for grouped-query attn if we use same fmap
|
234 |
+
|
235 |
+
attn_weights = None
|
236 |
+
# attention_mask = None # For now this is always True
|
237 |
+
if past_key_value is None: # Regular training
|
238 |
+
window_factors = F.sigmoid(self.window_factors)
|
239 |
+
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
240 |
+
y_true, a_pred = self.linear_attention(q, k, f_q, f_k, v,
|
241 |
+
window_factors, linear_factors,
|
242 |
+
window_size=self.window_size)
|
243 |
+
attn_weights = a_pred
|
244 |
+
else:
|
245 |
+
past_key_value.window_size = self.decode_window_size
|
246 |
+
if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating
|
247 |
+
assert use_cache is True
|
248 |
+
_kv = past_key_value.update_for_decoding(k, v, self.layer_idx,
|
249 |
+
self.feature_map_k,
|
250 |
+
dtype=q.dtype)
|
251 |
+
k_cache, v_cache, f_kv_state, f_k_state = _kv
|
252 |
+
|
253 |
+
# Sliding window + linear attention decode
|
254 |
+
window_factors = F.sigmoid(self.window_factors)
|
255 |
+
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
256 |
+
|
257 |
+
# Softmax attention terms
|
258 |
+
a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5)
|
259 |
+
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
260 |
+
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
261 |
+
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
262 |
+
|
263 |
+
# Combine with linear attention terms
|
264 |
+
y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float())
|
265 |
+
+ linear_factors * torch.einsum('bhlf,bhfd->bhld', f_q.float(), f_kv_state.float()))
|
266 |
+
sum_ln = linear_factors * torch.einsum(
|
267 |
+
'bhlf,bhnf->bhl', f_q.float(), f_k_state.float())[..., None]
|
268 |
+
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
269 |
+
|
270 |
+
else: # Stateful training
|
271 |
+
try:
|
272 |
+
kv_state = past_key_value.kv_states[self.layer_idx]
|
273 |
+
k_state = past_key_value.k_states[self.layer_idx]
|
274 |
+
except IndexError:
|
275 |
+
kv_state, k_state = None, None
|
276 |
+
window_factors = F.sigmoid(self.window_factors)
|
277 |
+
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
278 |
+
y_true, _ = self.linear_attention(q, k, f_q, f_k, v,
|
279 |
+
window_factors, linear_factors,
|
280 |
+
window_size=self.window_size,
|
281 |
+
kv_state=kv_state,
|
282 |
+
k_state=k_state)
|
283 |
+
# Save and update KV cache and states
|
284 |
+
# past_key_value.update(k, v.detach(), self.layer_idx,
|
285 |
+
# fmap_key_states=f_k.detach(),
|
286 |
+
# accumulate_in_fp32=True)
|
287 |
+
past_key_value.update(k, v, self.layer_idx,
|
288 |
+
fmap_key_states=f_k,
|
289 |
+
accumulate_in_fp32=True)
|
290 |
+
# Concatenate heads and apply output projection
|
291 |
+
_y_true = y_true.transpose(1, 2).contiguous()
|
292 |
+
y_true = self.o_proj(_y_true.view(b, l, self.hidden_size))
|
293 |
+
|
294 |
+
if self.train_attention:
|
295 |
+
attn_weights = _y_true # flash_attn outputs are shape (b, l, h, d)
|
296 |
+
return y_true, attn_weights, past_key_value
|
297 |
+
|
298 |
+
|
299 |
+
class LinearAttentionSlidingWindowCache(LinearAttentionState):
|
300 |
+
"""
|
301 |
+
Class for `past_key_values`
|
302 |
+
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
|
303 |
+
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
|
304 |
+
"""
|
305 |
+
def __init__(self, window_size: int = 64) -> None:
|
306 |
+
super().__init__()
|
307 |
+
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
308 |
+
self._seen_tokens_by_layer: List[int] = []
|
309 |
+
self.kv_states: List[torch.Tensor] = []
|
310 |
+
self.k_states: List[torch.Tensor] = []
|
311 |
+
|
312 |
+
# Account for sliding windows
|
313 |
+
self.decode_kv_states: List[torch.Tensor] = []
|
314 |
+
self.decode_k_states: List[torch.Tensor] = []
|
315 |
+
self.k_cache: List[torch.Tensor] = []
|
316 |
+
self.v_cache: List[torch.Tensor] = []
|
317 |
+
self.window_size = window_size
|
318 |
+
|
319 |
+
def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
|
320 |
+
layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None,
|
321 |
+
accumulate_in_fp32: bool = False,
|
322 |
+
fmap_key_states: torch.Tensor = None, # should not be None
|
323 |
+
grad_enabled: bool = False,
|
324 |
+
**kwargs: any,
|
325 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
326 |
+
"""
|
327 |
+
Update KV, K states; and KV cache during training
|
328 |
+
- For decoding, use `self.decode_kv_states` to keep track of KV states
|
329 |
+
up to sliding window terms
|
330 |
+
- For (chunked) training, use `self.kv_states` to keep track of KV states
|
331 |
+
up to end of sequence
|
332 |
+
- Likewise for `self.decode_k_states` and `self.k_states`
|
333 |
+
"""
|
334 |
+
with torch.set_grad_enabled(grad_enabled):
|
335 |
+
if layer_idx == 0:
|
336 |
+
self._seen_tokens += key_states.shape[-2]
|
337 |
+
|
338 |
+
dtype = key_states.dtype
|
339 |
+
if accumulate_in_fp32:
|
340 |
+
# key_states = key_states.float()
|
341 |
+
fmap_key_states = fmap_key_states.float()
|
342 |
+
value_states = value_states.float()
|
343 |
+
|
344 |
+
# Decoding KV state (KV terms up to last window_size)
|
345 |
+
decode_kv_state = torch.einsum(
|
346 |
+
'bhlf,bhld->bhfd', fmap_key_states[:, :, :-self.window_size], value_states[:, :, :-self.window_size]
|
347 |
+
)
|
348 |
+
# KV state
|
349 |
+
kv_state = decode_kv_state + torch.einsum(
|
350 |
+
'bhlf,bhld->bhfd', fmap_key_states[:, :, -self.window_size:], value_states[:, :, -self.window_size:]
|
351 |
+
)
|
352 |
+
# shape is b, h, 1, f; note the 1
|
353 |
+
decode_k_state = fmap_key_states[:, :, :-self.window_size].sum(dim=-2, keepdim=True)
|
354 |
+
k_state = (decode_k_state + fmap_key_states[:, :, -self.window_size:].sum(dim=-2, keepdim=True))
|
355 |
+
|
356 |
+
# Update the cache
|
357 |
+
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
358 |
+
self.kv_states.append(kv_state.to(dtype))
|
359 |
+
self.k_states.append(k_state.to(dtype))
|
360 |
+
|
361 |
+
self.decode_kv_states.append(decode_kv_state.to(dtype))
|
362 |
+
self.decode_k_states.append(decode_k_state.to(dtype))
|
363 |
+
|
364 |
+
self.k_cache.append(key_states[:, :, -self.window_size:, :])
|
365 |
+
self.v_cache.append(value_states[:, :, -self.window_size:, :].to(dtype))
|
366 |
+
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
|
367 |
+
else:
|
368 |
+
# Update kv and k states recurrently
|
369 |
+
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype)
|
370 |
+
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype)
|
371 |
+
self.kv_states[layer_idx] = kv_state
|
372 |
+
self.k_states[layer_idx] = k_state
|
373 |
+
|
374 |
+
decode_kv_state = (self.decode_kv_states[layer_idx].to(kv_state.dtype)
|
375 |
+
+ decode_kv_state).to(dtype)
|
376 |
+
decode_k_state = (self.decode_k_states[layer_idx].to(kv_state.dtype)
|
377 |
+
+ decode_k_state).to(dtype)
|
378 |
+
self.decode_kv_states[layer_idx] = decode_kv_state
|
379 |
+
self.decode_k_states[layer_idx] = decode_k_state
|
380 |
+
|
381 |
+
self.k_cache[layer_idx] = key_states[:, :, -self.window_size:, :]
|
382 |
+
self.v_cache[layer_idx] = value_states[:, :, -self.window_size:, :]
|
383 |
+
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
384 |
+
|
385 |
+
return self.kv_states[layer_idx], self.k_states[layer_idx]
|
386 |
+
|
387 |
+
def update_for_decoding(self, keys: torch.Tensor, values: torch.Tensor,
|
388 |
+
layer_idx: int, feature_map_k: Callable, dtype: torch.dtype):
|
389 |
+
"""
|
390 |
+
Update the decoding KV and K states, and KV cache, during decodeing
|
391 |
+
"""
|
392 |
+
with torch.no_grad():
|
393 |
+
k_cache = self.k_cache[layer_idx]
|
394 |
+
v_cache = self.v_cache[layer_idx]
|
395 |
+
|
396 |
+
if k_cache.shape[-2] < self.window_size: # build window-size cache
|
397 |
+
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
|
398 |
+
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
|
399 |
+
else:
|
400 |
+
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
|
401 |
+
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
|
402 |
+
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
|
403 |
+
# else:
|
404 |
+
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
|
405 |
+
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
|
406 |
+
k_state = feature_map_k(k_cache[:, :, :1, :])
|
407 |
+
v_state = v_cache[:, :, :1, :]
|
408 |
+
kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d
|
409 |
+
self.decode_kv_states[layer_idx] += kv_state
|
410 |
+
self.decode_k_states[layer_idx] += k_state
|
411 |
+
|
412 |
+
self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2)
|
413 |
+
self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], values], dim=-2)
|
414 |
+
|
415 |
+
if layer_idx == 0:
|
416 |
+
self._seen_tokens += keys.shape[-2]
|
417 |
+
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
|
418 |
+
return (self.k_cache[layer_idx], self.v_cache[layer_idx],
|
419 |
+
self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx])
|
420 |
+
|
421 |
+
|
422 |
+
# -----------------
|
423 |
+
# Flash Attention 2
|
424 |
+
# -----------------
|
425 |
+
|
426 |
+
def flash_attention_2(self,
|
427 |
+
hidden_states: torch.Tensor,
|
428 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
429 |
+
position_ids: Optional[torch.LongTensor] = None,
|
430 |
+
past_key_value: Optional[Cache] = None,
|
431 |
+
output_attentions: bool = False,
|
432 |
+
use_cache: bool = False,
|
433 |
+
cache_position: Optional[torch.LongTensor] = None,
|
434 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
435 |
+
"""
|
436 |
+
Wrapper for LlamaFlashAttention2
|
437 |
+
Copied and modified from HF Transformers v4.36 and v4.43 implementations
|
438 |
+
- (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
|
439 |
+
- (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
|
440 |
+
"""
|
441 |
+
output_attentions = False
|
442 |
+
|
443 |
+
bsz, q_len, _ = hidden_states.size()
|
444 |
+
|
445 |
+
query_states = self.q_proj(hidden_states)
|
446 |
+
key_states = self.k_proj(hidden_states)
|
447 |
+
value_states = self.v_proj(hidden_states)
|
448 |
+
|
449 |
+
# Flash attention requires the input to have the shape
|
450 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
451 |
+
# therefore we just need to keep the original shape
|
452 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
453 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
454 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
455 |
+
|
456 |
+
try: # As in Transformers v4.36
|
457 |
+
kv_seq_len = key_states.shape[-2]
|
458 |
+
if past_key_value is not None:
|
459 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
460 |
+
cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
|
461 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
462 |
+
except: # As in Transformers v4.39
|
463 |
+
cos, sin = self.rotary_emb(key_states, position_ids)
|
464 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
465 |
+
|
466 |
+
if past_key_value is not None:
|
467 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
468 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
469 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
470 |
+
|
471 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
472 |
+
# to be able to avoid many of these transpose/reshape/view.
|
473 |
+
query_states = query_states.transpose(1, 2)
|
474 |
+
key_states = key_states.transpose(1, 2)
|
475 |
+
value_states = value_states.transpose(1, 2)
|
476 |
+
|
477 |
+
dropout_rate = self.attention_dropout if self.training else 0.0
|
478 |
+
|
479 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
480 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
481 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
482 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
483 |
+
# in fp32. (LlamaRMSNorm handles it correctly)
|
484 |
+
|
485 |
+
input_dtype = query_states.dtype
|
486 |
+
if input_dtype == torch.float32:
|
487 |
+
if torch.is_autocast_enabled():
|
488 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
489 |
+
# Handle the case where the model is quantized
|
490 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
491 |
+
target_dtype = self.config._pre_quantization_dtype
|
492 |
+
else:
|
493 |
+
target_dtype = self.q_proj.weight.dtype
|
494 |
+
|
495 |
+
logger.warning_once(
|
496 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
497 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
498 |
+
f" {target_dtype}."
|
499 |
+
)
|
500 |
+
|
501 |
+
query_states = query_states.to(target_dtype)
|
502 |
+
key_states = key_states.to(target_dtype)
|
503 |
+
value_states = value_states.to(target_dtype)
|
504 |
+
|
505 |
+
if getattr(self, '_flash_attention_forward', False):
|
506 |
+
attn_output = self._flash_attention_forward(
|
507 |
+
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate,
|
508 |
+
is_causal=True,
|
509 |
+
)
|
510 |
+
else:
|
511 |
+
attn_output = _flash_attention_forward(
|
512 |
+
query_states,
|
513 |
+
key_states,
|
514 |
+
value_states,
|
515 |
+
attention_mask,
|
516 |
+
q_len,
|
517 |
+
dropout=0, # dropout_rate,
|
518 |
+
sliding_window=getattr(self, "sliding_window", None),
|
519 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
520 |
+
is_causal=True,
|
521 |
+
)
|
522 |
+
return attn_output, past_key_value
|
src/model/linear_attention/linear_window_attention_sw_long.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LoLCATs attention combining sliding window and linear attentions
|
3 |
+
- Using standard sliding window arrangement
|
4 |
+
- Training over long sequences with fixed memory with recurrent view
|
5 |
+
- During attention transfer, use Flash Attention to compute softmax attention outputs
|
6 |
+
|
7 |
+
For each layer:
|
8 |
+
- We first compute (softmax) attention over sliding windows
|
9 |
+
- We then compute standard linear attention to "fill in" the earlier parts
|
10 |
+
- We combine to model the entire sequence
|
11 |
+
"""
|
12 |
+
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
13 |
+
from .linear_window_attention_sw import hybrid_attention_quadratic
|
14 |
+
|
15 |
+
|
16 |
+
class LolcatsSlidingWindowLongAttention(LolcatsTKWindowLongAttention):
|
17 |
+
"""
|
18 |
+
Lolcats attention combining sliding window and linear attention
|
19 |
+
"""
|
20 |
+
def __init__(self, remove_base_attn=True, **kwargs):
|
21 |
+
# keep self.base_attn for Flash Attention inference
|
22 |
+
super().__init__(remove_base_attn=True, **kwargs)
|
23 |
+
self.quadratic_attention = hybrid_attention_quadratic
|