diff --git a/configs/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml b/configs/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e71170dee7ac8b7742af040033a0c90801338560 --- /dev/null +++ b/configs/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml @@ -0,0 +1,52 @@ +dataset: + name: alpaca_clean + dataset_config: + name: default + path: yahma/alpaca-cleaned + chunk_size: 1024 # sequence length for distilling + concat_data: true + cache_dir: 'data/alpaca' # Change this to where you want to save + pretrained_model_config: # will be updated based on model_config + pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3-8B' + cache_dir: '/scratch/' + preprocess_config: null + +dataloader: + batch_size: 1 + num_workers: 2 + drop_last: false + pin_memory: true + +optimizer: + optim: adamw_torch_fused + lr: 0.01 + weight_decay: 0.0 + +lr_scheduler: + lr_scheduler_type: reduce_lr_on_plateau + mode: min + factor: 0.1 + patience: 10 + min_lr: 0.00001 + +trainer: # HuggingFace Trainer-like arguments + name: distill_attention_xent_mse + reverse_kl: false + mse_factor: 1000 + xent_factor: 0 + + bf16: true + train_split: train + val_split: validation + num_train_epochs: 2 + gradient_accumulation_steps: 8 + seed: 42 + batch_size: 1 + load_best_model_at_end: true + greater_is_better: false + metric_for_best_model: distill/eval/loss + logging_steps: 100 + evaluation_strategy: steps + max_steps: -1 + eval_steps: 100 + max_eval_batches: null diff --git a/configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml b/configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c1bfa91b63d4a029a969fa63092e5661019536d3 --- /dev/null +++ b/configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml @@ -0,0 +1,52 @@ +dataset: + name: alpaca_clean + dataset_config: + name: default + path: yahma/alpaca-cleaned + chunk_size: 1024 # sequence length for distilling + concat_data: true + cache_dir: 'data/alpaca' # Change this to where you want to save + pretrained_model_config: # will be updated based on model_config + pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3.1-8B' + cache_dir: '/data_persistent2/sim_data/llama-3_1-8b/' + preprocess_config: null + +dataloader: + batch_size: 1 + num_workers: 2 + drop_last: false + pin_memory: true + +optimizer: + optim: adamw_torch_fused + lr: 0.01 + weight_decay: 0.0 + +lr_scheduler: + lr_scheduler_type: reduce_lr_on_plateau + mode: min + factor: 0.1 + patience: 10 + min_lr: 0.00001 + +trainer: # HuggingFace Trainer-like arguments + name: distill_attention_xent_mse + reverse_kl: false + mse_factor: 1000 + xent_factor: 1 + + bf16: true + train_split: train + val_split: validation + num_train_epochs: 2 + gradient_accumulation_steps: 8 + seed: 42 + batch_size: 1 + load_best_model_at_end: true + greater_is_better: false + metric_for_best_model: distill/eval/loss + logging_steps: 100 + evaluation_strategy: steps + max_steps: -1 + eval_steps: 100 + max_eval_batches: null diff --git a/configs/experiment/eval_alpaca_clean.yaml b/configs/experiment/eval_alpaca_clean.yaml new file mode 100644 index 0000000000000000000000000000000000000000..af1005ae42f46dd02c790ba2df988ec1cf65d7b6 --- /dev/null +++ b/configs/experiment/eval_alpaca_clean.yaml @@ -0,0 +1,56 @@ +dataset: + name: alpaca_clean + dataset_config: + name: alpaca + path: yahma/alpaca-cleaned + chunk_size: 1024 # sequence length for distilling + 
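+    # concat_data packs tokenized examples end-to-end and splits them into chunk_size-token sequences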
concat_data: true + cache_dir: 'data/alpaca' # Change this to where you want to save + pretrained_model_config: + pretrained_model_name_or_path: 'mistralai/Mistral-7B-v0.1' # will be updated based on model_config + cache_dir: '/scratch/' + preprocess_config: null + +dataloader: + batch_size: 1 + num_workers: 2 + drop_last: false + pin_memory: true + +optimizer: + optim: adamw_torch_fused + lr: 1e-4 + weight_decay: 0.0 + +lr_scheduler: + lr_scheduler_type: reduce_lr_on_plateau + mode: min + factor: 0.1 + patience: 10 + min_lr: 0.00001 + +trainer: # HuggingFace Trainer-like arguments + name: finetune_seq2seq + bf16: true + train_split: train + val_split: test + num_train_epochs: 2 + gradient_accumulation_steps: 8 + seed: 42 + batch_size: 1 + load_best_model_at_end: true + greater_is_better: true + metric_for_best_model: eval/rouge/geometric_mean + logging_steps: 100 + evaluation_strategy: steps + max_steps: -1 + eval_steps: 100 + max_eval_batches: null + +finetune: + method: lora + kwargs: + r: 8 + lora_alpha: 16 + lora_dropout: 0 # 0.05 + target_modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj'] \ No newline at end of file diff --git a/configs/experiment/finetune_lora_fqkvo_alpaca_clean.yaml b/configs/experiment/finetune_lora_fqkvo_alpaca_clean.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6d94d5f8acfd89e98054ccdea483d069df7562da --- /dev/null +++ b/configs/experiment/finetune_lora_fqkvo_alpaca_clean.yaml @@ -0,0 +1,58 @@ +dataset: + name: alpaca_clean + dataset_config: + name: default + path: yahma/alpaca-cleaned + chunk_size: 1024 + concat_data: true + cache_dir: "data/alpaca" + pretrained_model_config: + pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" # will be updated based on model_config + cache_dir: "/data_persistent2/sim_data/" + preprocess_config: null + +dataloader: + batch_size: 1 + num_workers: 2 + drop_last: false + pin_memory: true + +optimizer: + optim: adamw_torch_fused + lr: 1e-4 + weight_decay: 0.0 + +lr_scheduler: + lr_scheduler_type: reduce_lr_on_plateau + mode: min + factor: 0.1 + patience: 10 + min_lr: 0.00001 + +trainer: # HuggingFace Trainer-like arguments + name: default_lm + bf16: true + train_split: train + val_split: validation + num_train_epochs: 2 + gradient_accumulation_steps: 8 + seed: 42 + batch_size: 1 + load_best_model_at_end: true + greater_is_better: false + metric_for_best_model: eval/loss # eval/rouge/geometric_mean + logging_steps: 100 + evaluation_strategy: steps + max_steps: -1 + eval_steps: 100 + max_eval_batches: null + num_save_ckpt_steps: 200 + +finetune: + method: lora + kwargs: + r: 8 + lora_alpha: 16 + lora_dropout: 0 # 0.05 + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] + trainable_weights: ['feature_map_q.mlp.layer', 'feature_map_k.mlp.layer', 'window_factors'] diff --git a/configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml b/configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml new file mode 100644 index 0000000000000000000000000000000000000000..10323a0afd698831fdbaaa9783b696fa6e08ab1b --- /dev/null +++ b/configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml @@ -0,0 +1,56 @@ +dataset: + name: alpaca_clean + dataset_config: + name: default + path: yahma/alpaca-cleaned + chunk_size: 1024 + concat_data: true + cache_dir: "data/alpaca" + pretrained_model_config: + pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" # will be updated based on model_config + cache_dir: "/scratch/" + preprocess_config: null + +dataloader: + batch_size: 1 + num_workers: 2 + drop_last: false + 
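+  # pin_memory keeps batches in page-locked host memory for faster CPU-to-GPU transfer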
pin_memory: true + +optimizer: + optim: adamw_torch_fused + lr: 1e-4 + weight_decay: 0.0 + +lr_scheduler: + lr_scheduler_type: reduce_lr_on_plateau + mode: min + factor: 0.1 + patience: 10 + min_lr: 0.00001 + +trainer: # HuggingFace Trainer-like arguments + name: default_lm + bf16: true + train_split: train + val_split: validation + num_train_epochs: 2 + gradient_accumulation_steps: 8 + seed: 42 + batch_size: 1 + load_best_model_at_end: true + greater_is_better: false + metric_for_best_model: eval/loss # eval/rouge/geometric_mean + logging_steps: 100 + evaluation_strategy: steps + max_steps: -1 + eval_steps: 100 + max_eval_batches: null + +finetune: + method: lora + kwargs: + r: 8 + lora_alpha: 16 + lora_dropout: 0 # 0.05 + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] diff --git a/configs/experiment/no_distill_alpaca_clean.yaml b/configs/experiment/no_distill_alpaca_clean.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4b57c3411b47219c97d28992bea043eea3d67661 --- /dev/null +++ b/configs/experiment/no_distill_alpaca_clean.yaml @@ -0,0 +1,29 @@ +dataset: + name: alpaca_clean + dataset_config: + name: alpaca + path: yahma/alpaca-cleaned + chunk_size: 1024 # sequence length for distilling + concat_data: true + cache_dir: 'data/alpaca' # Change this to where you want to save + pretrained_model_config: + pretrained_model_name_or_path: 'mistralai/Mistral-7B-v0.1' # will be updated based on model_config + cache_dir: '/scr-ssd/mzhang/models/mistral-v0.1' + preprocess_config: null + +dataloader: + batch_size: 1 + num_workers: 2 + drop_last: false + pin_memory: true + +optimizer: + optim: adamw_torch_fused + lr: 0.01 + weight_decay: 0.0 + +lr_scheduler: + lr_scheduler_type: none + +trainer: # HuggingFace Trainer-like arguments + name: null diff --git a/configs/model/base_llama3_1_8b.yaml b/configs/model/base_llama3_1_8b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8c6292dc98c337697d4762590f8aa5c5ee2e3f2c --- /dev/null +++ b/configs/model/base_llama3_1_8b.yaml @@ -0,0 +1,15 @@ +name: llama +model: + pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3.1-8B' + cache_dir: '/scr-ssd/mzhang/models/llama3' # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 + rope_theta: 500000.0 + +attention: + attention_type: softmax diff --git a/configs/model/base_llama3_8b.yaml b/configs/model/base_llama3_8b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ae172d05d9fe099871dfb8b3515b4169bec3e760 --- /dev/null +++ b/configs/model/base_llama3_8b.yaml @@ -0,0 +1,15 @@ +name: llama +model: + pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3-8B' + cache_dir: '/scr-ssd/mzhang/models/llama3' # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 + rope_theta: 500000.0 + +attention: + attention_type: softmax diff --git a/configs/model/base_mistral_7b.yaml b/configs/model/base_mistral_7b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cf58f66880f690e0ee82c088202eb985b29d5088 --- /dev/null +++ b/configs/model/base_mistral_7b.yaml @@ -0,0 +1,15 @@ +name: llama +model: + pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" + cache_dir: 
"/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 + rope_theta: 10000.0 + +attention: + attention_type: softmax diff --git a/configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml b/configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5c21f66525d265dffc5146d3616d39f9371110be --- /dev/null +++ b/configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml @@ -0,0 +1,40 @@ +# Experimental config for chunked linear attention +name: llama +model: + pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B" + cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 + rope_theta: 500000.0 + rope_scaling: + factor: 8.0 + low_freq_factor: 1.0 + high_freq_factor: 4.0 + original_max_position_embeddings: 8192 + rope_type: llama3 + +attention: + attention_type: lolcats_long_llama_window_sw + state_chunk_len: 1024 + window_size: 64 + affine_attention_factors: false + init_window_factor: -2.1972245773362196 + feature_map: softmax_dim + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 64 + skip_connection: false + bias: false + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml b/configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7ea8039707a539835c723182000df3c58a68589f --- /dev/null +++ b/configs/model/chunked_experimental/distill_long_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml @@ -0,0 +1,40 @@ +# Experimental config for chunked linear attention +name: llama +model: + pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B" + cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 + rope_theta: 500000.0 + rope_scaling: + factor: 8.0 + low_freq_factor: 1.0 + high_freq_factor: 4.0 + original_max_position_embeddings: 8192 + rope_type: llama3 + +attention: + attention_type: lolcats_long_llama_window_tk + state_chunk_len: 1024 + window_size: 64 + affine_attention_factors: false + init_window_factor: -2.1972245773362196 + feature_map: softmax_dim + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 64 + skip_connection: false + bias: false + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wsw64_fd64_w01.yaml b/configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wsw64_fd64_w01.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..75b0283ca45c41ae2ad46caf5eba7c0dbe44b9e7 --- /dev/null +++ b/configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wsw64_fd64_w01.yaml @@ -0,0 +1,34 @@ +# Experimental config for chunked linear attention +name: llama +model: + pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B" + cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 + rope_theta: 500000.0 + +attention: + attention_type: lolcats_long_llama_window_sw + state_chunk_len: 1024 + window_size: 64 + affine_attention_factors: false + init_window_factor: -2.1972245773362196 + feature_map: softmax_dim + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 64 + skip_connection: false + bias: false + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01.yaml b/configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cdfbc73e0bd23b040e415dcdc804046aad335d99 --- /dev/null +++ b/configs/model/chunked_experimental/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01.yaml @@ -0,0 +1,34 @@ +# Experimental config for chunked linear attention +name: llama +model: + pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B" + cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 + rope_theta: 500000.0 + +attention: + attention_type: lolcats_long_llama_window_tk + state_chunk_len: 1024 + window_size: 64 + affine_attention_factors: false + init_window_factor: -2.1972245773362196 + feature_map: softmax_dim + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 64 + skip_connection: false + bias: false + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wsw64_fd64_w01.yaml b/configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wsw64_fd64_w01.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5518f4ecf40c5099e208d0ae8722f6b1c1557dc4 --- /dev/null +++ b/configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wsw64_fd64_w01.yaml @@ -0,0 +1,36 @@ +# Experimental config for chunked linear attention +name: llama +model: + pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" + cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 # eager # so we can load attention weights + rope_theta: 10000.0 + +attention: + attention_type: lolcats_long_llama_window_sw + state_chunk_len: 512 # 1024 + window_size: 64 + affine_attention_factors: false + init_window_factor: -2.1972245773362196 
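+  # -2.1972245773362196 = ln(0.1 / 0.9), i.e. logit(0.1); the window factor presumably starts near 0.1 after a sigmoid (the `w01` in the config name)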
+ train_window_factor: true + train_attention_weights: false + feature_map: softmax_dim + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 64 + skip_connection: false + bias: false + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wtk64_fd64_w01.yaml b/configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wtk64_fd64_w01.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15260921af4f492131ae6d9fd2068c85f190235a --- /dev/null +++ b/configs/model/chunked_experimental/distill_long_mistral_7b_lk_smd_wtk64_fd64_w01.yaml @@ -0,0 +1,35 @@ +name: llama +model: + pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" + cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 # eager # so we can load attention weights + rope_theta: 10000.0 + +attention: + attention_type: lolcats_long_llama_window_tk + state_chunk_len: 512 # 1024 + window_size: 64 + affine_attention_factors: false + init_window_factor: -2.1972245773362196 + train_window_factor: true + train_attention_weights: false + feature_map: softmax_dim + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 64 + skip_connection: false + bias: false + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/distill_llama3_1_8b_lk_smd_fd64.yaml b/configs/model/distill_llama3_1_8b_lk_smd_fd64.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fa4547fb5db9a1c6082eabb85271efc23af2492f --- /dev/null +++ b/configs/model/distill_llama3_1_8b_lk_smd_fd64.yaml @@ -0,0 +1,35 @@ +name: llama +model: + pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B" + cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: eager + rope_theta: 500000.0 + rope_scaling: + factor: 8.0 + low_freq_factor: 1.0 + high_freq_factor: 4.0 + original_max_position_embeddings: 8192 + rope_type: llama3 + +attention: + attention_type: lolcats_llama + feature_map: softmax_dim + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 64 + skip_connection: false + bias: false + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/distill_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml b/configs/model/distill_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0313f3ab45936a308bc0c775e3ff5a65bf5627ca --- /dev/null +++ b/configs/model/distill_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml @@ -0,0 +1,39 @@ +name: llama +model: + pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B" + cache_dir: "/scratch/" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: 
auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: eager + rope_theta: 500000.0 + rope_scaling: + factor: 8.0 + low_freq_factor: 1.0 + high_freq_factor: 4.0 + original_max_position_embeddings: 8192 + rope_type: llama3 + +attention: + attention_type: lolcats_llama_window_sw + state_chunk_len: 1024 + window_size: 64 + affine_attention_factors: false + init_window_factor: -2.1972245773362196 + feature_map: softmax_dim + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 64 + skip_connection: false + bias: false + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml b/configs/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7bc9a4dc56c3705bdc445c2a42aa99deeac0dd9f --- /dev/null +++ b/configs/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml @@ -0,0 +1,39 @@ +name: llama +model: + pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B" + cache_dir: "/scratch/" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: eager + rope_theta: 500000.0 + rope_scaling: + factor: 8.0 + low_freq_factor: 1.0 + high_freq_factor: 4.0 + original_max_position_embeddings: 8192 + rope_type: llama3 + +attention: + attention_type: lolcats_llama_window_tk + state_chunk_len: 1024 + window_size: 64 + affine_attention_factors: false + init_window_factor: -2.1972245773362196 + feature_map: softmax_dim + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 64 + skip_connection: false + bias: false + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/distill_llama3_1_8b_lk_t2r.yaml b/configs/model/distill_llama3_1_8b_lk_t2r.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a678d65cd54b1f55f6f86503298378257ff6b4ba --- /dev/null +++ b/configs/model/distill_llama3_1_8b_lk_t2r.yaml @@ -0,0 +1,35 @@ +name: llama +model: + pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B" + cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: eager + rope_theta: 500000.0 + rope_scaling: + factor: 8.0 + low_freq_factor: 1.0 + high_freq_factor: 4.0 + original_max_position_embeddings: 8192 + rope_type: llama3 + +attention: + attention_type: lolcats_llama + feature_map: relu + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 128 + skip_connection: false + bias: true + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/distill_llama3_8b_lk_smd_fd64.yaml b/configs/model/distill_llama3_8b_lk_smd_fd64.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4684a8eb3235ad6e77dc63c25240a5a9b31bfe7b --- /dev/null +++ b/configs/model/distill_llama3_8b_lk_smd_fd64.yaml @@ -0,0 +1,29 @@ +name: llama +model: + 
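+  # the keys below mirror Hugging Face from_pretrained() arguments (cache_dir, load_in_*, device_map, torch_dtype, attn_implementation)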
pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B" + cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 + rope_theta: 500000.0 + +attention: + attention_type: lolcats_llama + feature_map: softmax_dim + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 64 + skip_connection: false + bias: false + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/distill_llama3_8b_lk_smd_wsw64_fd64_w01.yaml b/configs/model/distill_llama3_8b_lk_smd_wsw64_fd64_w01.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7fb2e70160b380494aad531a34060bb27a6d36a1 --- /dev/null +++ b/configs/model/distill_llama3_8b_lk_smd_wsw64_fd64_w01.yaml @@ -0,0 +1,33 @@ +name: llama +model: + pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B" + cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 + rope_theta: 500000.0 + +attention: + attention_type: lolcats_llama_window_sw + state_chunk_len: 1024 + window_size: 64 + affine_attention_factors: false + init_window_factor: -2.1972245773362196 + feature_map: softmax_dim + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 64 + skip_connection: false + bias: false + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/distill_llama3_8b_lk_smd_wtk64_fd64_w01.yaml b/configs/model/distill_llama3_8b_lk_smd_wtk64_fd64_w01.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b86ae07b6040a4ce1011d32c0cce32bd17c348d5 --- /dev/null +++ b/configs/model/distill_llama3_8b_lk_smd_wtk64_fd64_w01.yaml @@ -0,0 +1,33 @@ +name: llama +model: + pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B" + cache_dir: '/scr-ssd/mzhang/models/llama3' # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 + rope_theta: 500000.0 + +attention: + attention_type: lolcats_llama_window_tk + state_chunk_len: 1024 + window_size: 64 + affine_attention_factors: false + init_window_factor: -2.1972245773362196 + feature_map: softmax_dim + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 64 + skip_connection: false + bias: false + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/distill_llama3_8b_lk_t2r.yaml b/configs/model/distill_llama3_8b_lk_t2r.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5d79ec8799dcbfbf29f060539266f0f0c3f2b8a7 --- /dev/null +++ b/configs/model/distill_llama3_8b_lk_t2r.yaml @@ -0,0 +1,29 @@ +name: llama +model: + pretrained_model_name_or_path: "meta-llama/Meta-Llama-3-8B" + cache_dir: "/scr-ssd/mzhang/models/llama3" # Set this 
to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 + rope_theta: 500000.0 + +attention: + attention_type: lolcats_llama + feature_map: relu + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 128 + skip_connection: false + bias: true + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/distill_mistral_7b_lk_smd_fd64.yaml b/configs/model/distill_mistral_7b_lk_smd_fd64.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9b6910758ea2acc1f105303b18519dec1b83232 --- /dev/null +++ b/configs/model/distill_mistral_7b_lk_smd_fd64.yaml @@ -0,0 +1,29 @@ +name: llama +model: + pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" + cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 # eager # so we can load attention weights + rope_theta: 10000.0 + +attention: + attention_type: lolcats_llama + feature_map: softmax_dim + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 64 + skip_connection: false + bias: false + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/distill_mistral_7b_lk_smd_wsw64_fd64_w01.yaml b/configs/model/distill_mistral_7b_lk_smd_wsw64_fd64_w01.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3b063bb9e5ac3105e73b921d6ce59954ec5c995c --- /dev/null +++ b/configs/model/distill_mistral_7b_lk_smd_wsw64_fd64_w01.yaml @@ -0,0 +1,35 @@ +name: llama +model: + pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" + cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 # eager # so we can load attention weights + rope_theta: 10000.0 + +attention: + attention_type: lolcats_llama_window_sw + state_chunk_len: 512 # 1024 + window_size: 64 + affine_attention_factors: false + init_window_factor: -2.1972245773362196 + train_window_factor: true + train_attention_weights: false + feature_map: softmax_dim + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 64 + skip_connection: false + bias: false + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/distill_mistral_7b_lk_smd_wtk64_fd64_w01.yaml b/configs/model/distill_mistral_7b_lk_smd_wtk64_fd64_w01.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ab4c534967dfd76c760b1c039d96d7e9cd74295d --- /dev/null +++ b/configs/model/distill_mistral_7b_lk_smd_wtk64_fd64_w01.yaml @@ -0,0 +1,35 @@ +name: llama +model: + pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" + cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights + 
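+  # return_dict: have the Hugging Face model return ModelOutput objects rather than tuples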
return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 # eager # so we can load attention weights + rope_theta: 10000.0 + +attention: + attention_type: lolcats_llama_window_tk + state_chunk_len: 512 # 1024 + window_size: 64 + affine_attention_factors: false + init_window_factor: -2.1972245773362196 + train_window_factor: true + train_attention_weights: false + feature_map: softmax_dim + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 64 + skip_connection: false + bias: false + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/configs/model/distill_mistral_7b_lk_t2r.yaml b/configs/model/distill_mistral_7b_lk_t2r.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bbcb048eb992c474546072acced94758b9a9b583 --- /dev/null +++ b/configs/model/distill_mistral_7b_lk_t2r.yaml @@ -0,0 +1,29 @@ +name: llama +model: + pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" + cache_dir: "/scr-ssd/mzhang/models/mistral-7b-v0.1" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 # eager # so we can load attention weights + rope_theta: 10000.0 + +attention: + attention_type: lolcats_llama + feature_map: relu + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 128 + skip_connection: false + bias: true + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/csrc/__init__.py b/csrc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5caecd2401c2b061a5ed786aabf828e986af1c9 --- /dev/null +++ b/csrc/__init__.py @@ -0,0 +1,6 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# +from .causal_attention import causal_dot_product diff --git a/csrc/causal_attention.cpp b/csrc/causal_attention.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4fa77571c8fe402d674214458edf52da756174f3 --- /dev/null +++ b/csrc/causal_attention.cpp @@ -0,0 +1,225 @@ +// +// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +// Written by Angelos Katharopoulos , +// Apoorv Vyas +// + +#include + + +/** + * Compute a*b^T and save it into out. 
+ * + * a \in R^A + * b \in R^B + */ +inline void vvt_dot(float *a, float *b, float *out, int A, int B) { + for (int i=0; i(); + auto ka = keys.accessor(); + auto va = values.accessor(); + auto pa = product.accessor(); + + #pragma omp parallel for collapse(2) + for (int n=0; n(); + for (int l=0; l(); + auto ka = keys.accessor(); + auto va = values.accessor(); + auto ga = grad_out.accessor(); + auto gqa = grad_queries.accessor(); + auto gka = grad_keys.accessor(); + auto gva = grad_values.accessor(); + + #pragma omp parallel for collapse(2) + for (int n=0; n(); + + // Compute the gradient wrt the queries + for (int l=0; l=0; l--) { + vvt_dot( + &qa[n][h][l][0], + &ga[n][h][l][0], + kvp, + E, + M + ); + vmt_dot( + &va[n][h][l][0], + kvp, + &gka[n][h][l][0], + E, + M + ); + vm_dot( + &ka[n][h][l][0], + kvp, + &gva[n][h][l][0], + E, + M + ); + } + } + } +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "causal_dot_product", + &causal_dot_product, + "Compute the weighted sum of values but attending only to previous " + "values." + ); + m.def( + "causal_dot_backward", + &causal_dot_backward, + "Compute the gradient of queries, keys and values given the gradient " + "of causal_dot_product." + ); +} \ No newline at end of file diff --git a/csrc/causal_attention.py b/csrc/causal_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..6865d38d28b277867befa337ed143e72904041b0 --- /dev/null +++ b/csrc/causal_attention.py @@ -0,0 +1,77 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +import torch + +try: + from causal_attention_cuda import causal_dot_product as causal_dot_product_cuda + from causal_attention_cuda import causal_dot_backward as causal_dot_backward_cuda +except ImportError as e: + print(e) + causal_dot_product_cuda = causal_dot_backward_cuda = None + + +class CausalDotProduct(torch.autograd.Function): + """Compute the weighted sum of values but attending only to previous + values.""" + dot = { + # "cpu": causal_dot_product_cpu, + "cuda": causal_dot_product_cuda + } + dot_backward = { + # "cpu": causal_dot_backward_cpu, + "cuda": causal_dot_backward_cuda + } + + @staticmethod + def forward(ctx, Q, K, V): + # Save the inputs for the gradient computation + ctx.save_for_backward(Q, K, V) + + # Create the output tensor + device = Q.device + N, H, L, _ = Q.shape + _, _, _, M = V.shape + product = torch.zeros((N, H, L, M), dtype=Q.dtype, device=device) + + # Actually perform the dot product + CausalDotProduct.dot[device.type]( + Q.data, + K.data, + V.data, + product + ) + # breakpoint() + # CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product) + + return product + + @staticmethod + def backward(ctx, grad_out): + # Extract the saved tensors + Q, K, V = ctx.saved_tensors + + # Allocate memory for the gradients + grad_Q = torch.zeros_like(Q) + grad_K = torch.zeros_like(K) + grad_V = torch.zeros_like(V) + + # Actually compute the gradients + CausalDotProduct.dot_backward[Q.device.type]( + Q.data, + K.data, + V.data, + grad_out, + grad_Q, + grad_K, + grad_V + ) + + return grad_Q, grad_K, grad_V + + +# Alias the autograd functions to python style snake case naming +causal_dot_product = CausalDotProduct.apply \ No newline at end of file diff --git a/csrc/causal_attention_cuda.cu b/csrc/causal_attention_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..e2970e1a1fd92dc2c4844bd3dd69fb22fedc35cc --- /dev/null +++ b/csrc/causal_attention_cuda.cu 
@@ -0,0 +1,1483 @@ +// +// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +// Written by Angelos Katharopoulos , +// Apoorv Vyas +// + +// +// For modifications made inside namespace nvidia (authored by jdemouth): +// +// Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// + +#include +#include +#include + +#define ENABLE_NVIDIA_OPTIMIZATIONS + +#ifdef ENABLE_NVIDIA_OPTIMIZATIONS +namespace nvidia { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr int THREADS_PER_WARP = 32; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr int LOW_OCCUPANCY_THRESHOLD = 40; // TODO: Make it HW specific (like 1/2 SMs). + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ __host__ int div_up(int m, int n) { + return (m + n-1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ __host__ int round_up(int m, int n) { + return div_up(m, n) * n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T > +struct Lmha_params { + + // The output buffer. Dimensions [B, H, L, M]. + T *out; + + // The input Qs. Dimensions [B, H, L, E]. + const T *q; + // The input Ks. Dimensions [B, H, L, E]. + const T *k; + // The input Vs. Dimensions [B, H, L, M]. + const T *v; + + // The different dimensions. + int B, L, H, E, M; + + // The strides for the different tensors. + int q_stride_B, q_stride_H, q_stride_L; + int k_stride_B, k_stride_H, k_stride_L; + int v_stride_B, v_stride_H, v_stride_L; + int o_stride_B, o_stride_H, o_stride_L; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, bool GO_BACKWARD, int WARPS, int COLS_PER_THREAD = 4 > +__global__ __launch_bounds__(WARPS * THREADS_PER_WARP) +void lmha_low_occupancy_kernel(Lmha_params params) { + + // The number of threads per block. + constexpr int THREADS_PER_BLOCK = WARPS * THREADS_PER_WARP; + // The number of rows per thread. + constexpr int ROWS_PER_THREAD = E / THREADS_PER_WARP; + // The number of steps per iteration. + constexpr int COLS_PER_ITER = WARPS * COLS_PER_THREAD; + + // Make sure E is a multiple of the warp size. 
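+  // (each of the 32 lanes then owns exactly ROWS_PER_THREAD = E/32 rows of the running K*V^T state)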
+ static_assert(E % THREADS_PER_WARP == 0, ""); + + // Shared memory to store V/O. + __shared__ float smem_v[COLS_PER_ITER], smem_o[COLS_PER_ITER]; + // Shared memory buffer to performance the reductions. + __shared__ float smem_reds[E * WARPS]; + + // The sequence processed by that block. + const int bi = blockIdx.z; + // The head processed by that block. + const int hi = blockIdx.y; + // The hidden cell in the V/output buffers. + const int vi = blockIdx.x; + + // The linear index of the thread. + const int tidx = threadIdx.x; + + // Decompose the block in warp/lane. + const int warp = tidx / THREADS_PER_WARP; + const int lane = tidx % THREADS_PER_WARP; + + // The base offset loaded by the thread in Q and K. + int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + lane; + int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + lane; + + // If we walk backward, account for the extra offset. + if( GO_BACKWARD ) { + offset_q += (params.L-1)*params.q_stride_L; + offset_k += (params.L-1)*params.k_stride_L; + } + + // Position the warp at the beginning of the proper timestep. + if( GO_BACKWARD ) { + offset_q -= warp*COLS_PER_THREAD*params.q_stride_L; + offset_k -= warp*COLS_PER_THREAD*params.k_stride_L; + } else { + offset_q += warp*COLS_PER_THREAD*params.q_stride_L; + offset_k += warp*COLS_PER_THREAD*params.k_stride_L; + } + + // Determine the base pointers for Q and K. + const float *ptr_q = ¶ms.q[offset_q]; + const float *ptr_k = ¶ms.k[offset_k]; + + // Is a given row valid? + int valid_qk[ROWS_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < ROWS_PER_THREAD; ++ii ) { + valid_qk[ii] = lane + ii*THREADS_PER_WARP < params.E; + } + + // The offset to the position loaded by the thread in V. + int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + vi; + int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + vi; + + // If we walk backward, account for the extra offset. + if( GO_BACKWARD ) { + offset_v += (params.L-1)*params.v_stride_L; + offset_o += (params.L-1)*params.o_stride_L; + } + + // We load/store a strided matrix of COLS_PER_ITER x OUTPUTS_PER_BLOCK. + if( GO_BACKWARD ) { + offset_v -= tidx*params.v_stride_L; + offset_o -= tidx*params.o_stride_L; + } else { + offset_v += tidx*params.v_stride_L; + offset_o += tidx*params.o_stride_L; + } + + // Determine the base pointer for V. + const float *ptr_v = ¶ms.v[offset_v]; + // The output pointer. + float *ptr_o = ¶ms.out[offset_o]; + + // The running KVs. + float running_kv[ROWS_PER_THREAD]; + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + running_kv[ri] = 0.f; + } + + // Iterate over the timesteps. TODO: Use params.loop_count!!! + for( int iter = 0; iter < params.L; iter += COLS_PER_ITER ) { + + // Each thread loads a matrix of elements. + float q[ROWS_PER_THREAD][COLS_PER_THREAD], k[ROWS_PER_THREAD][COLS_PER_THREAD]; + + // Trigger the memory loads for Q and K. + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + + // For Q/K, each warp loads from various timesteps. + int ti = iter + warp*COLS_PER_THREAD; + if( GO_BACKWARD ) { + ti = params.L - 1 - ti; + } + + // Is it a valid access? + int valid; + if( GO_BACKWARD ) { + valid = valid_qk[ri] && ti - ci >= 0; + } else { + valid = valid_qk[ri] && ti + ci < params.L; + } + + // The extra offset to add. 
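+        // (offsets are taken relative to ptr_q/ptr_k, which already point at this warp's first timestep of the current tile)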
+ if( GO_BACKWARD ) { + offset_q = ri*THREADS_PER_WARP - ci*params.q_stride_L; + offset_k = ri*THREADS_PER_WARP - ci*params.k_stride_L; + } else { + offset_q = ri*THREADS_PER_WARP + ci*params.q_stride_L; + offset_k = ri*THREADS_PER_WARP + ci*params.k_stride_L; + } + + // Load Q/K if they are valid. + q[ri][ci] = valid ? ptr_q[offset_q] : 0.f; + k[ri][ci] = valid ? ptr_k[offset_k] : 0.f; + } + } + + // For the V tensor, we assign contiguous thread to different loads. So, ti is different. + int ti = iter + tidx; + if( GO_BACKWARD ) { + ti = params.L - 1 - ti; + } + + // Is it a valid access? + int valid_vo = tidx < COLS_PER_ITER; + if( GO_BACKWARD ) { + valid_vo &= ti >= 0; + } else { + valid_vo &= ti < params.L; + } + + // Trigger the loads for V. + float ldg_v = valid_vo ? *ptr_v : 0.f; + + // Move the load pointers. + if( GO_BACKWARD ) { + ptr_q -= COLS_PER_ITER*params.q_stride_L; + ptr_k -= COLS_PER_ITER*params.k_stride_L; + ptr_v -= COLS_PER_ITER*params.v_stride_L; + } else { + ptr_q += COLS_PER_ITER*params.q_stride_L; + ptr_k += COLS_PER_ITER*params.k_stride_L; + ptr_v += COLS_PER_ITER*params.v_stride_L; + } + + // Store to shared memory. + if( tidx < COLS_PER_ITER ) { + smem_v[tidx] = ldg_v; + } + + // Make sure V is in shared memory. + __syncthreads(); + + // Read V from shared memory. + float v[COLS_PER_THREAD]; + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + v[ci] = smem_v[warp*COLS_PER_THREAD + ci]; + } + + // Each thread computes local K*V products. + float kv[ROWS_PER_THREAD][COLS_PER_THREAD]; + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + kv[ri][ci] = 0.f; + } + } + + // Update the K*V^T product. + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + kv[ri][ci] += k[ri][ci] * v[ci]; + } + } + + // We must perform the prefix sums within the thread-block. Start with the thread. + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + #pragma unroll + for( int ci = 1; ci < COLS_PER_THREAD; ++ci ) { + kv[ri][ci] += kv[ri][ci-1]; + } + } + + // Store the partial sums to shared memory. Unless we have no inter-warp reduction to perform. + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + smem_reds[warp*E + ri*THREADS_PER_WARP + lane] = kv[ri][COLS_PER_THREAD-1]; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Each thread deals with one or more column(s) of the matrix. + constexpr int SUMS_PER_THREAD = (E + THREADS_PER_BLOCK-1) / THREADS_PER_BLOCK; + #pragma unroll + for( int ii = 0, idx = tidx; ii < SUMS_PER_THREAD; ++ii, idx += THREADS_PER_BLOCK ) { + if( idx < E ) { + float sum = smem_reds[idx]; + #pragma unroll + for( int jj = 1; jj < WARPS; ++jj ) { + smem_reds[idx + jj*E] = sum += smem_reds[idx + jj*E]; + } + } + } + + // Make sure the reductions are stored in shared memory. + __syncthreads(); + + // Each thread updates his partial products. + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + float sum = running_kv[ri]; + if( warp > 0 ) { + sum += smem_reds[(warp-1)*E + lane + ri*THREADS_PER_WARP]; + } + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + kv[ri][ci] += sum; + } + } + + // Compute the partial output values for that thread. 
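+    // (per-column dot products of q with the accumulated K*V^T over this thread's rows; the warp shuffle below finishes the reduction across the E dimension)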
+ float sum[COLS_PER_THREAD]; + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + sum[ci] = q[0][ci] * kv[0][ci]; + #pragma unroll + for( int ri = 1; ri < ROWS_PER_THREAD; ++ri ) { + sum[ci] += q[ri][ci] * kv[ri][ci]; + } + } + + // Run the parallel reductions inside the warp. + #pragma unroll + for( int mask = THREADS_PER_WARP / 2; mask >= 1; mask /= 2 ) { + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + sum[ci] += __shfl_xor_sync(uint32_t(-1), sum[ci], mask); + } + } + + // Store the final output to shared memory. + if( lane == 0 ) { + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + smem_o[warp*COLS_PER_THREAD + ci] = sum[ci]; + } + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Store the output. + if( valid_vo ) { + *ptr_o = smem_o[tidx]; + } + + // Each thread updates his running kv. + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + running_kv[ri] += smem_reds[(WARPS-1)*E + lane + ri*THREADS_PER_WARP]; + } + + // Move to next location. + if( GO_BACKWARD ) { + ptr_o -= COLS_PER_ITER*params.o_stride_L; + } else { + ptr_o += COLS_PER_ITER*params.o_stride_L; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, bool GO_BACKWARD, int WARPS > +int lmha_low_occupancy_(const Lmha_params ¶ms) { + + // Make sure we are not going to launch an invalid grid. + if( params.H > 65535 || params.B > 65535 ) { + return 1; + } + + // Prepare the grid and trigger the CUDA kernel. + dim3 grid; + grid.x = params.M; + grid.y = params.H; + grid.z = params.B; + lmha_low_occupancy_kernel<<>>(params); + return 0; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, bool GO_BACKWARD > +int lmha_low_occupancy_(const Lmha_params ¶ms, int blocks) { + if( params.M * blocks >= 8*LOW_OCCUPANCY_THRESHOLD ) { + return lmha_low_occupancy_(params); + } else if( params.M * blocks >= 4*LOW_OCCUPANCY_THRESHOLD ) { + return lmha_low_occupancy_(params); + } else { + return lmha_low_occupancy_(params); + } + return 1; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, typename Params > +static inline __device__ __host__ int smem_buffer_elts_(const Params ¶ms) { + int M = round_up(params.M, 4); + return 2*E + 2*M; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD > +__global__ +void lmha_kernel(Lmha_params params) { + + // Make sure E is a multiple of 4. + static_assert(E % 4 == 0, ""); + + // The amount of shared memory per buffer (2 buffers for double-buffering). + const int smem_buffer_elts = smem_buffer_elts_(params); + // The M dimension for shared memory. + const int M = round_up(params.M, 4); + + // Shared memory to store Q, K and V. Size is 2*smem_buffer_elts. + extern __shared__ float smem_[]; + + // The various shared memory buffers. + float *smem_q = &smem_[0*E]; + float *smem_k = &smem_[1*E]; + float *smem_v = &smem_[2*E]; + float *smem_o = &smem_[2*E + M]; + + // The index of the shared memory buffer (for double-buffering). + int smem_curr = 0; + + // The sequence processed by that block. + const int bi = blockIdx.y; + // The head processed by that block. + const int hi = blockIdx.x; + + // The linear index of the thread. 
+ const int tidx = threadIdx.x; + + // The offset to the position loaded by the thread in Q. + int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + tidx; + // The offset to the position loaded by the thread in K. + int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + tidx; + + // If we walk backward, account for the extra offset. + if( GO_BACKWARD ) { + offset_q += (params.L-1)*params.q_stride_L; + offset_k += (params.L-1)*params.k_stride_L; + } + + // Determine the base pointers for Q and K. + const float *ptr_q = ¶ms.q[offset_q]; + const float *ptr_k = ¶ms.k[offset_k]; + + // The offset to the position loaded by the thread in V and O. + int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + tidx; + int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + tidx; + + // If we walk backward, account for the extra offset. + if( GO_BACKWARD ) { + offset_v += (params.L-1)*params.v_stride_L; + offset_o += (params.L-1)*params.o_stride_L; + } + + // Determine the base pointers for V. + const float *ptr_v = ¶ms.v[offset_v]; + + // Is it an active Q/K thread? + const int active_qk = tidx < params.E; + + // Trigger the memory loads for Q and K. + float ldg_q = 0.f, ldg_k = 0.f; + if( active_qk ) { + ldg_q = *ptr_q; + ldg_k = *ptr_k; + } + + // Is it an active V thread? + const int active_v = tidx < params.M; + + // Trigger the memory loads for V. + float ldg_v = 0.f; + if( active_v ) { + ldg_v = *ptr_v; + } + + // Move the load pointers. + if( GO_BACKWARD ) { + ptr_q -= params.q_stride_L; + ptr_k -= params.k_stride_L; + ptr_v -= params.v_stride_L; + } else { + ptr_q += params.q_stride_L; + ptr_k += params.k_stride_L; + ptr_v += params.v_stride_L; + } + + // The number of FLOAT4s per head. + constexpr int FLOAT4s_PER_HEAD = E / 4; + // The number of FLOAT4s per thread. + constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD; + + // The storage for the K*V^T values. + float4 kv[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + kv[ii] = make_float4(0.f, 0.f, 0.f, 0.f); + } + + // The output pointer. + float *out_ptr = ¶ms.out[offset_o]; + + // Store to shared memory Q and K. + if( tidx < E ) { + smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q; + smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k; + } + + // Store to shared memory V. All threads store valid values. + if( tidx < M ) { + smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v; + } + + // The position of the thread in the V dimension. + int vo = tidx / THREADS_PER_HEAD; + int vi = tidx % THREADS_PER_HEAD; + + // Iterate over the timesteps. + for( int ti = 0; ti < params.L; ++ti ) { + + // Is it the last iteration? + int is_last = ti == params.L - 1; + + // Trigger the next loads for Q and K. + if( !is_last && active_qk ) { + ldg_q = *ptr_q; + ldg_k = *ptr_k; + } + + // Trigger the next loads for V. + if( !is_last && active_v ) { + ldg_v = *ptr_v; + } + + // Move the load pointers. + if( GO_BACKWARD ) { + ptr_q -= params.q_stride_L; + ptr_k -= params.k_stride_L; + ptr_v -= params.v_stride_L; + } else { + ptr_q += params.q_stride_L; + ptr_k += params.k_stride_L; + ptr_v += params.v_stride_L; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Each thread loads 4 values from K. 
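+    // (vectorized float4 reads from the current double-buffered shared-memory tile)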
+ float4 k[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + int ki = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4; + k[ii] = *reinterpret_cast(&smem_k[smem_curr*smem_buffer_elts + ki]); + } + + // Each thread loads a single V value. + float v = 0.f; + if( vo < params.M ) { + v = *reinterpret_cast(&smem_v[smem_curr*smem_buffer_elts + vo]); + } + + // Update the K*V^T product. + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + kv[ii].x += k[ii].x * v; + kv[ii].y += k[ii].y * v; + kv[ii].z += k[ii].z * v; + kv[ii].w += k[ii].w * v; + } + + // Load the Q values from shared memory. + float4 q[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + int qi = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4; + q[ii] = *reinterpret_cast(&smem_q[smem_curr*smem_buffer_elts + qi]); + } + + // Compute the partial output value for that thread. + float sum = 0.f; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + sum += q[ii].x * kv[ii].x; + sum += q[ii].y * kv[ii].y; + sum += q[ii].z * kv[ii].z; + sum += q[ii].w * kv[ii].w; + } + + // Finalize the computation of the sum (if we have more than 1 thread per head). + if( THREADS_PER_HEAD > 1 ) { + + // Finalize the sum for each head. + #pragma unroll + for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Store to shared memory. + if( vo < M && vi == 0 ) { + smem_o[smem_curr*smem_buffer_elts + vo] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Active threads read the data to store. + if( active_v ) { + sum = smem_o[smem_curr*smem_buffer_elts + tidx]; + } + + } // THREADS_PER_HEAD > 1. + + // Store the output. All the threads are active. + if( active_v ) { + *out_ptr = sum; + } + + // Move to next location. + if( GO_BACKWARD ) { + out_ptr -= params.o_stride_L; + } else { + out_ptr += params.o_stride_L; + } + + // Move the shared memory buffer. + smem_curr = (smem_curr + 1) % 2; + + // Store to shared memory for Q and K. + if( !is_last && tidx < E ) { + smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q; + smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k; + } + + // Store to shared memory for V. + if( !is_last && tidx < M ) { + smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD > +int lmha_(const Lmha_params ¶ms) { + // The M dimension rounded up to 4. + int M = round_up(params.M, 4); + + // The number of threads in the block. + int block = round_up(max(E, M*THREADS_PER_HEAD), 32); + if( block > 512 || params.B > 65535 ) { + return 1; + } + + // Prepare the kernel. 
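+  // (one thread block per (head, sequence) pair; dynamic shared memory holds two Q/K/V/O tiles for double-buffering)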
+  dim3 grid(params.H, params.B);
+  size_t smem = smem_buffer_elts_<E>(params)*2*sizeof(float);
+  lmha_kernel<E, THREADS_PER_HEAD, GO_BACKWARD><<<grid, block, smem>>>(params);
+  return 0;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< bool GO_BACKWARD >
+int lmha(const Lmha_params<float> &params) {
+  int blocks = params.B * params.H;
+  int res = 1;
+  if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
+    if( params.E <= 32 ) {
+      res = lmha_low_occupancy_< 32, GO_BACKWARD>(params, blocks);
+    } else if( params.E <= 64 ) {
+      res = lmha_low_occupancy_< 64, GO_BACKWARD>(params, blocks);
+    } else if( params.E <= 128 ) {
+      res = lmha_low_occupancy_<128, GO_BACKWARD>(params, blocks);
+    } else if( params.E <= 256 ) {
+      res = lmha_low_occupancy_<256, GO_BACKWARD>(params, blocks);
+    }
+  } else {
+    if( params.E <= 32 ) {
+      res = lmha_< 32, 1, GO_BACKWARD>(params);
+    } else if( params.E <= 48 ) {
+      res = lmha_< 48, 1, GO_BACKWARD>(params);
+    } else if( params.E <= 64 ) {
+      res = lmha_< 64, 1, GO_BACKWARD>(params);
+    } else if( params.E <= 128 ) {
+      res = lmha_<128, 2, GO_BACKWARD>(params);
+    } else if( params.E <= 256 ) {
+      res = lmha_<256, 4, GO_BACKWARD>(params);
+    }
+  }
+  return res;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< typename T >
+inline void set_params(Lmha_params<T> &params,
+                       const torch::Tensor q,
+                       const torch::Tensor k,
+                       const torch::Tensor v,
+                       torch::Tensor o) {
+
+  // Define the pointers.
+  params.out = o.data_ptr<T>();
+  params.q = q.data_ptr<T>();
+  params.k = k.data_ptr<T>();
+  params.v = v.data_ptr<T>();
+
+  // Define the strides.
+  params.q_stride_B = (int) q.stride(0);
+  params.q_stride_H = (int) q.stride(1);
+  params.q_stride_L = (int) q.stride(2);
+  params.k_stride_B = (int) k.stride(0);
+  params.k_stride_H = (int) k.stride(1);
+  params.k_stride_L = (int) k.stride(2);
+  params.v_stride_B = (int) v.stride(0);
+  params.v_stride_H = (int) v.stride(1);
+  params.v_stride_L = (int) v.stride(2);
+  params.o_stride_B = (int) o.stride(0);
+  params.o_stride_H = (int) o.stride(1);
+  params.o_stride_L = (int) o.stride(2);
+
+  // Extract the dimensions.
+  int N = q.size(0);
+  int H = q.size(1);
+  int L = q.size(2);
+  int E = q.size(3);
+  int M = v.size(3);
+
+  params.B = N;
+  params.L = L;
+  params.H = H;
+  params.E = E;
+  params.M = M;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+int lmha_fwd(const torch::Tensor queries,
+             const torch::Tensor keys,
+             const torch::Tensor values,
+             torch::Tensor product) {
+
+  // Make sure that we are using the correct GPU device
+  torch::DeviceGuard _guard(queries.device());
+
+  // Make sure the inner-most dimension of the tensors is packed.
+  assert(queries.stride(3) == 1);
+  assert(keys   .stride(3) == 1);
+  assert(values .stride(3) == 1);
+  assert(product.stride(3) == 1);
+
+  // Extract the dimensions.
+  int N = queries.size(0);
+  int H = queries.size(1);
+  int L = queries.size(2);
+  int E = queries.size(3);
+  int M = values.size (3);
+
+  // The structure of params.
+  Lmha_params<float> params;
+  set_params(params, queries, keys, values, product);
+
+  // Launch the kernel.
+  return lmha<false>(params);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< typename T >
+struct Lmha_bwd_params {
+
+  // The output buffer for K. Dimensions [B, H, L, D].
+  T *out_k;
+  // The output buffer for V. Dimensions [B, H, L, D].
+  T *out_v;
+
+  // The input Qs. Dimensions [B, H, L, D].
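+  // (D is the per-head feature size; the dispatch below instantiates it as the
+  //  smallest supported size covering both E and M.)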
+ const T *q; + // The input Ks. Dimensions [B, H, L, D]. + const T *k; + // The input Vs. Dimensions [B, H, L, D]. + const T *v; + // The input Gs. Dimensions [B, H, L, D]. + const T *g; + + // The dimensions. + int B, L, H, M, E; + + // The strides for the input tensors. + int q_stride_B, q_stride_L, q_stride_H; + int k_stride_B, k_stride_L, k_stride_H; + int v_stride_B, v_stride_L, v_stride_H; + int g_stride_B, g_stride_L, g_stride_H; + + // The strides for the outputs. + int out_k_stride_B, out_k_stride_L, out_k_stride_H; + int out_v_stride_B, out_v_stride_L, out_v_stride_H; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int D, int THREADS_PER_HEAD > +__global__ __launch_bounds__(D*THREADS_PER_HEAD*2) +void lmha_bwd_kernel(Lmha_bwd_params params) { + + // Make sure D is a multiple of 4. + static_assert(D % 4 == 0, ""); + + // The shared memory buffers. + __shared__ struct Smem { float qg[2*D], kv[2*D], out_kv[2*D]; } smem_[2]; + + // The index of the shared memory buffer (for double-buffering). + int smem_curr = 0; + + // The sequence processed by that block. + const int bi = blockIdx.y; + // The head processed by that block. + const int hi = blockIdx.x; + + // The linear index of the thread. + const int tidx = threadIdx.x; + + // Split the threads into two slices. + int so = tidx / (D*THREADS_PER_HEAD); + int si = tidx % (D*THREADS_PER_HEAD); + + // The strides for B/L/H for the Q/G tensors. + int qg_stride_B, qg_stride_L, qg_stride_H; + if( so == 0 ) { + qg_stride_B = params.q_stride_B; + qg_stride_L = params.q_stride_L; + qg_stride_H = params.q_stride_H; + } else { + qg_stride_B = params.g_stride_B; + qg_stride_L = params.g_stride_L; + qg_stride_H = params.g_stride_H; + } + + // The strides for B/L/H for the K/V tensors. + int kv_stride_B, kv_stride_L, kv_stride_H; + if( so == 0 ) { + kv_stride_B = params.k_stride_B; + kv_stride_L = params.k_stride_L; + kv_stride_H = params.k_stride_H; + } else { + kv_stride_B = params.v_stride_B; + kv_stride_L = params.v_stride_L; + kv_stride_H = params.v_stride_H; + } + + // The hidden size. + int hidden_size_per_head = 0; + if( so == 0 ) { + hidden_size_per_head = params.E; + } else { + hidden_size_per_head = params.M; + } + + // Where to start reading from. + int offset_qg = bi*qg_stride_B + hi*qg_stride_H + si; + int offset_kv = bi*kv_stride_B + hi*kv_stride_H + si; + + // We walk backward, account for the extra offset. + offset_qg += (params.L-1)*qg_stride_L; + offset_kv += (params.L-1)*kv_stride_L; + + // Determine the base pointers for Q, K, V and G. + const float *ptr_qg = &(so == 0 ? params.q : params.g)[offset_qg]; + const float *ptr_kv = &(so == 0 ? params.k : params.v)[offset_kv]; + + // Is it an active thread? + const int active = si < hidden_size_per_head; + + // Trigger the memory loads for Q, K, V and G. + float ldg_qg = 0.f, ldg_kv = 0.f; + if( active ) { + ldg_qg = *ptr_qg; + ldg_kv = *ptr_kv; + } + + // Move the load pointers (backward). + ptr_qg -= qg_stride_L; + ptr_kv -= kv_stride_L; + + // The number of FLOAT4s per head. + constexpr int FLOAT4s_PER_HEAD = D / 4; + // The number of FLOAT4s per thread. + constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD; + + // The storage for the G*Q^T or Q^T*G values. + float4 gq[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + gq[ii] = make_float4(0.f, 0.f, 0.f, 0.f); + } + + // The strides for B/L/H for the K/V tensors. 
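+  // (more precisely, the strides of the gradient outputs dK / dV; each thread
+  //  slice picks the set that matches the tensor it writes.)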
+ int out_kv_stride_B, out_kv_stride_L, out_kv_stride_H; + if( so == 0 ) { + out_kv_stride_B = params.out_k_stride_B; + out_kv_stride_L = params.out_k_stride_L; + out_kv_stride_H = params.out_k_stride_H; + } else { + out_kv_stride_B = params.out_v_stride_B; + out_kv_stride_L = params.out_v_stride_L; + out_kv_stride_H = params.out_v_stride_H; + } + + // Where to start reading from. + int offset_out_kv = bi*out_kv_stride_B + hi*out_kv_stride_H + si; + + // We walk backward, account for the extra offset. + offset_out_kv += (params.L-1)*out_kv_stride_L; + + // The output pointer. + float *ptr_out_kv = &(so == 0 ? params.out_k : params.out_v)[offset_out_kv]; + + // Store to shared memory. + if( si < D ) { + smem_[smem_curr].qg[so*D + si] = ldg_qg; + smem_[smem_curr].kv[so*D + si] = ldg_kv; + } + + // The position of the thread in the output dimension. + int oo = si / THREADS_PER_HEAD % D; + int oi = si % THREADS_PER_HEAD * 4; + + // Iterate over the timesteps. + for( int ti = 0; ti < params.L; ++ti ) { + + // Is it the last iteration? + int is_last = ti == params.L - 1; + + // Trigger the next loads. + if( !is_last && active ) { + ldg_qg = *ptr_qg; + ldg_kv = *ptr_kv; + } + + // Move the load pointers. + ptr_qg -= qg_stride_L; + ptr_kv -= kv_stride_L; + + // Make sure the data is in shared memory. + __syncthreads(); + + // Each thread loads 4 values from G or Q. + float4 g[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + float *smem_ptr = &smem_[smem_curr].qg[(so^1)*D + oi]; + g[ii] = *reinterpret_cast(&smem_ptr[ii*THREADS_PER_HEAD*4]); + } + + // Each thread loads a single from Q or G value. + float q = smem_[smem_curr].qg[so*D + oo]; + + // Update the G*Q^T or Q*G^T product. + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + gq[ii].x += g[ii].x * q; + gq[ii].y += g[ii].y * q; + gq[ii].z += g[ii].z * q; + gq[ii].w += g[ii].w * q; + } + + // Load the V or K values from shared memory. + float4 v[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + float *smem_ptr = &smem_[smem_curr].kv[(so^1)*D + oi]; + v[ii] = *reinterpret_cast(&smem_ptr[ii*THREADS_PER_HEAD*4]); + } + + // Compute the partial output value for that thread. + float sum = 0.f; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + sum += v[ii].x * gq[ii].x; + sum += v[ii].y * gq[ii].y; + sum += v[ii].z * gq[ii].z; + sum += v[ii].w * gq[ii].w; + } + + // Finalize the computation of the sum (if we have more than 1 thread per head). + if( THREADS_PER_HEAD > 1 ) { + + // Finalize the sum for each head. + #pragma unroll + for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Store to shared memory. + if( oi == 0 ) { + smem_[smem_curr].out_kv[so*D + oo] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Active threads read the data to store. + if( si < hidden_size_per_head ) { + sum = smem_[smem_curr].out_kv[so*D + si]; + } + + } // THREADS_PER_HEAD > 1. + + // Store the output. All the threads are active. + if( si < hidden_size_per_head ) { + *ptr_out_kv = sum; + } + + // Move to next location. + ptr_out_kv -= out_kv_stride_L; + + // Move the shared memory buffer. + smem_curr = (smem_curr + 1) % 2; + + // Store to shared memory for Q and K. 
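+    // (each half of the block refills its own pair for the next timestep:
+    //  slice so == 0 stages Q and K, slice so == 1 stages G and V.)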
+    if( !is_last && si < D ) {
+      smem_[smem_curr].qg[so*D + si] = ldg_qg;
+      smem_[smem_curr].kv[so*D + si] = ldg_kv;
+    }
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< int D, int THREADS_PER_HEAD >
+int lmha_bwd_(const Lmha_bwd_params<float> &params) {
+  int block = D*THREADS_PER_HEAD*2;
+  if( block >= 1024 || params.B > 65535 ) {
+    return 1;
+  }
+  dim3 grid(params.H, params.B);
+  lmha_bwd_kernel<D, THREADS_PER_HEAD><<<grid, block>>>(params);
+  return 0;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+int lmha_bwd(const Lmha_bwd_params<float> &params) {
+  int blocks = params.B * params.H;
+  if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
+    return 1;
+  }
+
+  int hidden_size_per_head = max(params.E, params.M);
+  int res = 1;
+  if( hidden_size_per_head <= 32 ) {
+    res = lmha_bwd_< 32, 1>(params);
+  } else if( hidden_size_per_head <= 64 ) {
+    res = lmha_bwd_< 64, 1>(params);
+  } else if( hidden_size_per_head <= 128 ) {
+    res = lmha_bwd_<128, 2>(params);
+  } else if( hidden_size_per_head <= 256 ) {
+    res = lmha_bwd_<256, 4>(params);
+  }
+  return res;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+int lmha_bwd(const torch::Tensor queries,
+             const torch::Tensor keys,
+             const torch::Tensor values,
+             const torch::Tensor grad_out,
+             torch::Tensor grad_queries,
+             torch::Tensor grad_keys,
+             torch::Tensor grad_values) {
+
+  // Make sure that we are using the correct GPU device
+  torch::DeviceGuard _guard(queries.device());
+
+  // Make sure the inner-most dimension of the tensors is packed.
+  assert(queries     .stride(3) == 1);
+  assert(keys        .stride(3) == 1);
+  assert(values      .stride(3) == 1);
+  assert(grad_out    .stride(3) == 1);
+  assert(grad_queries.stride(3) == 1);
+  assert(grad_keys   .stride(3) == 1);
+  assert(grad_values .stride(3) == 1);
+
+  // Extract the dimensions.
+  int N = queries.size(0);
+  int H = queries.size(1);
+  int L = queries.size(2);
+  int E = queries.size(3);
+  int M = values.size (3);
+
+  // Gradient on Q.
+
+  // The structure of params.
+  Lmha_params<float> params;
+  set_params(params, grad_out, values, keys, grad_queries);
+
+  // Launch the kernel.
+  int res = lmha<false>(params);
+  if( res ) {
+    return res;
+  }
+
+  // Gradient on K and V together.
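+  // Walking the sequence in reverse, the fused kernel accumulates the running
+  // outer product T_t = sum_{s >= t} q_s * g_s^T once and reads out both
+  // dK_t = T_t * v_t and dV_t = T_t^T * k_t from it.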
+ + Lmha_bwd_params bwd_params; + bwd_params.out_k = grad_keys.data_ptr(); + bwd_params.out_v = grad_values.data_ptr(); + bwd_params.q = queries.data_ptr(); + bwd_params.k = keys.data_ptr(); + bwd_params.v = values.data_ptr(); + bwd_params.g = grad_out.data_ptr(); + + bwd_params.B = N; + bwd_params.L = L; + bwd_params.H = H; + bwd_params.E = E; + bwd_params.M = M; + + bwd_params.q_stride_B = queries.stride(0); + bwd_params.q_stride_H = queries.stride(1); + bwd_params.q_stride_L = queries.stride(2); + bwd_params.k_stride_B = keys.stride(0); + bwd_params.k_stride_H = keys.stride(1); + bwd_params.k_stride_L = keys.stride(2); + bwd_params.v_stride_B = values.stride(0); + bwd_params.v_stride_H = values.stride(1); + bwd_params.v_stride_L = values.stride(2); + bwd_params.g_stride_B = grad_out.stride(0); + bwd_params.g_stride_H = grad_out.stride(1); + bwd_params.g_stride_L = grad_out.stride(2); + + bwd_params.out_k_stride_B = grad_keys.stride(0); + bwd_params.out_k_stride_H = grad_keys.stride(1); + bwd_params.out_k_stride_L = grad_keys.stride(2); + bwd_params.out_v_stride_B = grad_values.stride(0); + bwd_params.out_v_stride_H = grad_values.stride(1); + bwd_params.out_v_stride_L = grad_values.stride(2); + + // Try to run the fused kernel. + int fallback = lmha_bwd(bwd_params); + + // If it failed, fallback on separate kernels for K and V. + if( fallback ) { + + // Gradient on K. + + // Launch the kernel. + set_params(params, values, grad_out, queries, grad_keys); + res = lmha(params); + if( res ) { + return res; + } + + // Gradient on V. + + // Launch the kernel. + set_params(params, keys, queries, grad_out, grad_values); + return lmha(params); + } + + // It worked... + return 0; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace nvidia +#endif // #ifdef ENABLE_NVIDIA_OPTIMIZATIONS + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +typedef torch::PackedTensorAccessor32 float_accessor; + +#define E_BLOCK_SIZE 8 + +__global__ void causal_dot_product_kernel( + const float_accessor queries, + const float_accessor keys, + const float_accessor values, + float_accessor result, + const int N, + const int H, + const int L, + const int E, + const int M +) { + int n = blockIdx.y; + int h = blockIdx.z; + + int e_start = blockIdx.x * E_BLOCK_SIZE; + int m = threadIdx.x % M; + + extern __shared__ float shared_mem[]; + float* shared_kv = shared_mem; + + for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) { + shared_kv[m + e_local * M] = 0; + } + + for (int t=0; t>>( + queries.packed_accessor32(), + keys.packed_accessor32(), + values.packed_accessor32(), + product.packed_accessor32(), + N, H, L, E, M + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void causal_dot_product(const torch::Tensor queries, + const torch::Tensor keys, + const torch::Tensor values, + torch::Tensor product) { +#ifdef ENABLE_NVIDIA_OPTIMIZATIONS + int fallback = nvidia::lmha_fwd(queries, keys, values, product); +#else + int fallback = 1; +#endif + if( fallback ) { + causal_dot_product_(queries, keys, values, product); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define M_BLOCK_SIZE 4 + +// we need shared memory to store +// kv +// Backward direction +// kv_backwards +// Shared memory usage +__global__ void causal_dot_backward_query_key_kernel( + 
const float_accessor queries, + const float_accessor keys, + const float_accessor values, + const float_accessor grad_out, + float_accessor grad_queries, + float_accessor grad_keys, + int N, + int H, + int L, + int E, + int M +) { + int n = blockIdx.y; + int h = blockIdx.z; + + int m_start = blockIdx.x * M_BLOCK_SIZE; + int e = threadIdx.x % E; + + extern __shared__ float shared_mem[]; + const int shared_kv_size = M_BLOCK_SIZE * E; + float* shared_kv = shared_mem; + float* shared_kv_bw = shared_mem + shared_kv_size; + + for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) { + shared_kv[m_local * E + e] = 0; + shared_kv_bw[m_local * E + e] = 0; + } + + for (int l=0; l>>( + queries.packed_accessor32(), + keys.packed_accessor32(), + values.packed_accessor32(), + grad_out.packed_accessor32(), + grad_queries.packed_accessor32(), + grad_keys.packed_accessor32(), + N, H, L, E, M + ); + + const int blocks_per_sequence_value = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE; + + dim3 blockDimv(M, 1, 1); + dim3 gridDimv(blocks_per_sequence_value, N, H); + const int shared_mem_v_backward = E_BLOCK_SIZE * M * sizeof(float); + causal_dot_backward_value_kernel<<>>( + queries.packed_accessor32(), + keys.packed_accessor32(), + values.packed_accessor32(), + grad_out.packed_accessor32(), + grad_keys.packed_accessor32(), + grad_values.packed_accessor32(), + N, H, L, E, M + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void causal_dot_backward(const torch::Tensor queries, + const torch::Tensor keys, + const torch::Tensor values, + const torch::Tensor grad_out, + torch::Tensor grad_queries, + torch::Tensor grad_keys, + torch::Tensor grad_values) { +#ifdef ENABLE_NVIDIA_OPTIMIZATIONS + int fallback = nvidia::lmha_bwd(queries, + keys, + values, + grad_out, + grad_queries, + grad_keys, + grad_values); +#else + int fallback = 1; +#endif + if( fallback ) { + // Make sure that the gradient tensors are 0. This is needed because the + // bwd pass might have partially executed and filled in some values in + // grad_queries or grad_keys. + // + // This adds a small overhead every time we have to fall back to the old + // kernel for the backward pass. + grad_queries.zero_(); + grad_keys.zero_(); + causal_dot_backward_(queries, keys, values, grad_out, grad_queries, grad_keys, grad_values); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "causal_dot_product", + &causal_dot_product, + "Compute the weighted sum of values but attending only to previous " + "values." + ); + m.def( + "causal_dot_backward", + &causal_dot_backward, + "Compute the gradients for the causal dot product." + ); +} diff --git a/csrc/causal_attention_kv_cuda.cu b/csrc/causal_attention_kv_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..e2970e1a1fd92dc2c4844bd3dd69fb22fedc35cc --- /dev/null +++ b/csrc/causal_attention_kv_cuda.cu @@ -0,0 +1,1483 @@ +// +// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +// Written by Angelos Katharopoulos , +// Apoorv Vyas +// + +// +// For modifications made inside namespace nvidia (authored by jdemouth): +// +// Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// + +#include +#include +#include + +#define ENABLE_NVIDIA_OPTIMIZATIONS + +#ifdef ENABLE_NVIDIA_OPTIMIZATIONS +namespace nvidia { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr int THREADS_PER_WARP = 32; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr int LOW_OCCUPANCY_THRESHOLD = 40; // TODO: Make it HW specific (like 1/2 SMs). + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ __host__ int div_up(int m, int n) { + return (m + n-1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ __host__ int round_up(int m, int n) { + return div_up(m, n) * n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T > +struct Lmha_params { + + // The output buffer. Dimensions [B, H, L, M]. + T *out; + + // The input Qs. Dimensions [B, H, L, E]. + const T *q; + // The input Ks. Dimensions [B, H, L, E]. + const T *k; + // The input Vs. Dimensions [B, H, L, M]. + const T *v; + + // The different dimensions. + int B, L, H, E, M; + + // The strides for the different tensors. + int q_stride_B, q_stride_H, q_stride_L; + int k_stride_B, k_stride_H, k_stride_L; + int v_stride_B, v_stride_H, v_stride_L; + int o_stride_B, o_stride_H, o_stride_L; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, bool GO_BACKWARD, int WARPS, int COLS_PER_THREAD = 4 > +__global__ __launch_bounds__(WARPS * THREADS_PER_WARP) +void lmha_low_occupancy_kernel(Lmha_params params) { + + // The number of threads per block. + constexpr int THREADS_PER_BLOCK = WARPS * THREADS_PER_WARP; + // The number of rows per thread. + constexpr int ROWS_PER_THREAD = E / THREADS_PER_WARP; + // The number of steps per iteration. + constexpr int COLS_PER_ITER = WARPS * COLS_PER_THREAD; + + // Make sure E is a multiple of the warp size. + static_assert(E % THREADS_PER_WARP == 0, ""); + + // Shared memory to store V/O. + __shared__ float smem_v[COLS_PER_ITER], smem_o[COLS_PER_ITER]; + // Shared memory buffer to performance the reductions. + __shared__ float smem_reds[E * WARPS]; + + // The sequence processed by that block. 
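+  // Grid layout for this low-occupancy path: blockIdx.z = sequence, blockIdx.y =
+  // head, blockIdx.x = one output channel in [0, M). The WARPS warps of the block
+  // split the time dimension (COLS_PER_THREAD steps per warp and iteration) and
+  // stitch their partial K*V^T sums together through shared-memory prefix sums.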
+ const int bi = blockIdx.z; + // The head processed by that block. + const int hi = blockIdx.y; + // The hidden cell in the V/output buffers. + const int vi = blockIdx.x; + + // The linear index of the thread. + const int tidx = threadIdx.x; + + // Decompose the block in warp/lane. + const int warp = tidx / THREADS_PER_WARP; + const int lane = tidx % THREADS_PER_WARP; + + // The base offset loaded by the thread in Q and K. + int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + lane; + int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + lane; + + // If we walk backward, account for the extra offset. + if( GO_BACKWARD ) { + offset_q += (params.L-1)*params.q_stride_L; + offset_k += (params.L-1)*params.k_stride_L; + } + + // Position the warp at the beginning of the proper timestep. + if( GO_BACKWARD ) { + offset_q -= warp*COLS_PER_THREAD*params.q_stride_L; + offset_k -= warp*COLS_PER_THREAD*params.k_stride_L; + } else { + offset_q += warp*COLS_PER_THREAD*params.q_stride_L; + offset_k += warp*COLS_PER_THREAD*params.k_stride_L; + } + + // Determine the base pointers for Q and K. + const float *ptr_q = ¶ms.q[offset_q]; + const float *ptr_k = ¶ms.k[offset_k]; + + // Is a given row valid? + int valid_qk[ROWS_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < ROWS_PER_THREAD; ++ii ) { + valid_qk[ii] = lane + ii*THREADS_PER_WARP < params.E; + } + + // The offset to the position loaded by the thread in V. + int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + vi; + int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + vi; + + // If we walk backward, account for the extra offset. + if( GO_BACKWARD ) { + offset_v += (params.L-1)*params.v_stride_L; + offset_o += (params.L-1)*params.o_stride_L; + } + + // We load/store a strided matrix of COLS_PER_ITER x OUTPUTS_PER_BLOCK. + if( GO_BACKWARD ) { + offset_v -= tidx*params.v_stride_L; + offset_o -= tidx*params.o_stride_L; + } else { + offset_v += tidx*params.v_stride_L; + offset_o += tidx*params.o_stride_L; + } + + // Determine the base pointer for V. + const float *ptr_v = ¶ms.v[offset_v]; + // The output pointer. + float *ptr_o = ¶ms.out[offset_o]; + + // The running KVs. + float running_kv[ROWS_PER_THREAD]; + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + running_kv[ri] = 0.f; + } + + // Iterate over the timesteps. TODO: Use params.loop_count!!! + for( int iter = 0; iter < params.L; iter += COLS_PER_ITER ) { + + // Each thread loads a matrix of elements. + float q[ROWS_PER_THREAD][COLS_PER_THREAD], k[ROWS_PER_THREAD][COLS_PER_THREAD]; + + // Trigger the memory loads for Q and K. + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + + // For Q/K, each warp loads from various timesteps. + int ti = iter + warp*COLS_PER_THREAD; + if( GO_BACKWARD ) { + ti = params.L - 1 - ti; + } + + // Is it a valid access? + int valid; + if( GO_BACKWARD ) { + valid = valid_qk[ri] && ti - ci >= 0; + } else { + valid = valid_qk[ri] && ti + ci < params.L; + } + + // The extra offset to add. + if( GO_BACKWARD ) { + offset_q = ri*THREADS_PER_WARP - ci*params.q_stride_L; + offset_k = ri*THREADS_PER_WARP - ci*params.k_stride_L; + } else { + offset_q = ri*THREADS_PER_WARP + ci*params.q_stride_L; + offset_k = ri*THREADS_PER_WARP + ci*params.k_stride_L; + } + + // Load Q/K if they are valid. + q[ri][ci] = valid ? ptr_q[offset_q] : 0.f; + k[ri][ci] = valid ? 
ptr_k[offset_k] : 0.f; + } + } + + // For the V tensor, we assign contiguous thread to different loads. So, ti is different. + int ti = iter + tidx; + if( GO_BACKWARD ) { + ti = params.L - 1 - ti; + } + + // Is it a valid access? + int valid_vo = tidx < COLS_PER_ITER; + if( GO_BACKWARD ) { + valid_vo &= ti >= 0; + } else { + valid_vo &= ti < params.L; + } + + // Trigger the loads for V. + float ldg_v = valid_vo ? *ptr_v : 0.f; + + // Move the load pointers. + if( GO_BACKWARD ) { + ptr_q -= COLS_PER_ITER*params.q_stride_L; + ptr_k -= COLS_PER_ITER*params.k_stride_L; + ptr_v -= COLS_PER_ITER*params.v_stride_L; + } else { + ptr_q += COLS_PER_ITER*params.q_stride_L; + ptr_k += COLS_PER_ITER*params.k_stride_L; + ptr_v += COLS_PER_ITER*params.v_stride_L; + } + + // Store to shared memory. + if( tidx < COLS_PER_ITER ) { + smem_v[tidx] = ldg_v; + } + + // Make sure V is in shared memory. + __syncthreads(); + + // Read V from shared memory. + float v[COLS_PER_THREAD]; + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + v[ci] = smem_v[warp*COLS_PER_THREAD + ci]; + } + + // Each thread computes local K*V products. + float kv[ROWS_PER_THREAD][COLS_PER_THREAD]; + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + kv[ri][ci] = 0.f; + } + } + + // Update the K*V^T product. + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + kv[ri][ci] += k[ri][ci] * v[ci]; + } + } + + // We must perform the prefix sums within the thread-block. Start with the thread. + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + #pragma unroll + for( int ci = 1; ci < COLS_PER_THREAD; ++ci ) { + kv[ri][ci] += kv[ri][ci-1]; + } + } + + // Store the partial sums to shared memory. Unless we have no inter-warp reduction to perform. + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + smem_reds[warp*E + ri*THREADS_PER_WARP + lane] = kv[ri][COLS_PER_THREAD-1]; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Each thread deals with one or more column(s) of the matrix. + constexpr int SUMS_PER_THREAD = (E + THREADS_PER_BLOCK-1) / THREADS_PER_BLOCK; + #pragma unroll + for( int ii = 0, idx = tidx; ii < SUMS_PER_THREAD; ++ii, idx += THREADS_PER_BLOCK ) { + if( idx < E ) { + float sum = smem_reds[idx]; + #pragma unroll + for( int jj = 1; jj < WARPS; ++jj ) { + smem_reds[idx + jj*E] = sum += smem_reds[idx + jj*E]; + } + } + } + + // Make sure the reductions are stored in shared memory. + __syncthreads(); + + // Each thread updates his partial products. + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + float sum = running_kv[ri]; + if( warp > 0 ) { + sum += smem_reds[(warp-1)*E + lane + ri*THREADS_PER_WARP]; + } + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + kv[ri][ci] += sum; + } + } + + // Compute the partial output values for that thread. + float sum[COLS_PER_THREAD]; + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + sum[ci] = q[0][ci] * kv[0][ci]; + #pragma unroll + for( int ri = 1; ri < ROWS_PER_THREAD; ++ri ) { + sum[ci] += q[ri][ci] * kv[ri][ci]; + } + } + + // Run the parallel reductions inside the warp. 
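+    // (butterfly __shfl_xor_sync over the 32 lanes, so every lane ends up with
+    //  the warp's full dot product for each of its COLS_PER_THREAD timesteps.)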
+ #pragma unroll + for( int mask = THREADS_PER_WARP / 2; mask >= 1; mask /= 2 ) { + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + sum[ci] += __shfl_xor_sync(uint32_t(-1), sum[ci], mask); + } + } + + // Store the final output to shared memory. + if( lane == 0 ) { + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + smem_o[warp*COLS_PER_THREAD + ci] = sum[ci]; + } + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Store the output. + if( valid_vo ) { + *ptr_o = smem_o[tidx]; + } + + // Each thread updates his running kv. + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + running_kv[ri] += smem_reds[(WARPS-1)*E + lane + ri*THREADS_PER_WARP]; + } + + // Move to next location. + if( GO_BACKWARD ) { + ptr_o -= COLS_PER_ITER*params.o_stride_L; + } else { + ptr_o += COLS_PER_ITER*params.o_stride_L; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, bool GO_BACKWARD, int WARPS > +int lmha_low_occupancy_(const Lmha_params ¶ms) { + + // Make sure we are not going to launch an invalid grid. + if( params.H > 65535 || params.B > 65535 ) { + return 1; + } + + // Prepare the grid and trigger the CUDA kernel. + dim3 grid; + grid.x = params.M; + grid.y = params.H; + grid.z = params.B; + lmha_low_occupancy_kernel<<>>(params); + return 0; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, bool GO_BACKWARD > +int lmha_low_occupancy_(const Lmha_params ¶ms, int blocks) { + if( params.M * blocks >= 8*LOW_OCCUPANCY_THRESHOLD ) { + return lmha_low_occupancy_(params); + } else if( params.M * blocks >= 4*LOW_OCCUPANCY_THRESHOLD ) { + return lmha_low_occupancy_(params); + } else { + return lmha_low_occupancy_(params); + } + return 1; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, typename Params > +static inline __device__ __host__ int smem_buffer_elts_(const Params ¶ms) { + int M = round_up(params.M, 4); + return 2*E + 2*M; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD > +__global__ +void lmha_kernel(Lmha_params params) { + + // Make sure E is a multiple of 4. + static_assert(E % 4 == 0, ""); + + // The amount of shared memory per buffer (2 buffers for double-buffering). + const int smem_buffer_elts = smem_buffer_elts_(params); + // The M dimension for shared memory. + const int M = round_up(params.M, 4); + + // Shared memory to store Q, K and V. Size is 2*smem_buffer_elts. + extern __shared__ float smem_[]; + + // The various shared memory buffers. + float *smem_q = &smem_[0*E]; + float *smem_k = &smem_[1*E]; + float *smem_v = &smem_[2*E]; + float *smem_o = &smem_[2*E + M]; + + // The index of the shared memory buffer (for double-buffering). + int smem_curr = 0; + + // The sequence processed by that block. + const int bi = blockIdx.y; + // The head processed by that block. + const int hi = blockIdx.x; + + // The linear index of the thread. + const int tidx = threadIdx.x; + + // The offset to the position loaded by the thread in Q. + int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + tidx; + // The offset to the position loaded by the thread in K. + int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + tidx; + + // If we walk backward, account for the extra offset. 
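+  // (GO_BACKWARD starts the walk at timestep L-1 and scans toward 0; the reverse
+  //  direction is what the K/V gradient passes need, since their sums run over
+  //  future timesteps.)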
+ if( GO_BACKWARD ) { + offset_q += (params.L-1)*params.q_stride_L; + offset_k += (params.L-1)*params.k_stride_L; + } + + // Determine the base pointers for Q and K. + const float *ptr_q = ¶ms.q[offset_q]; + const float *ptr_k = ¶ms.k[offset_k]; + + // The offset to the position loaded by the thread in V and O. + int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + tidx; + int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + tidx; + + // If we walk backward, account for the extra offset. + if( GO_BACKWARD ) { + offset_v += (params.L-1)*params.v_stride_L; + offset_o += (params.L-1)*params.o_stride_L; + } + + // Determine the base pointers for V. + const float *ptr_v = ¶ms.v[offset_v]; + + // Is it an active Q/K thread? + const int active_qk = tidx < params.E; + + // Trigger the memory loads for Q and K. + float ldg_q = 0.f, ldg_k = 0.f; + if( active_qk ) { + ldg_q = *ptr_q; + ldg_k = *ptr_k; + } + + // Is it an active V thread? + const int active_v = tidx < params.M; + + // Trigger the memory loads for V. + float ldg_v = 0.f; + if( active_v ) { + ldg_v = *ptr_v; + } + + // Move the load pointers. + if( GO_BACKWARD ) { + ptr_q -= params.q_stride_L; + ptr_k -= params.k_stride_L; + ptr_v -= params.v_stride_L; + } else { + ptr_q += params.q_stride_L; + ptr_k += params.k_stride_L; + ptr_v += params.v_stride_L; + } + + // The number of FLOAT4s per head. + constexpr int FLOAT4s_PER_HEAD = E / 4; + // The number of FLOAT4s per thread. + constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD; + + // The storage for the K*V^T values. + float4 kv[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + kv[ii] = make_float4(0.f, 0.f, 0.f, 0.f); + } + + // The output pointer. + float *out_ptr = ¶ms.out[offset_o]; + + // Store to shared memory Q and K. + if( tidx < E ) { + smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q; + smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k; + } + + // Store to shared memory V. All threads store valid values. + if( tidx < M ) { + smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v; + } + + // The position of the thread in the V dimension. + int vo = tidx / THREADS_PER_HEAD; + int vi = tidx % THREADS_PER_HEAD; + + // Iterate over the timesteps. + for( int ti = 0; ti < params.L; ++ti ) { + + // Is it the last iteration? + int is_last = ti == params.L - 1; + + // Trigger the next loads for Q and K. + if( !is_last && active_qk ) { + ldg_q = *ptr_q; + ldg_k = *ptr_k; + } + + // Trigger the next loads for V. + if( !is_last && active_v ) { + ldg_v = *ptr_v; + } + + // Move the load pointers. + if( GO_BACKWARD ) { + ptr_q -= params.q_stride_L; + ptr_k -= params.k_stride_L; + ptr_v -= params.v_stride_L; + } else { + ptr_q += params.q_stride_L; + ptr_k += params.k_stride_L; + ptr_v += params.v_stride_L; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Each thread loads 4 values from K. + float4 k[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + int ki = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4; + k[ii] = *reinterpret_cast(&smem_k[smem_curr*smem_buffer_elts + ki]); + } + + // Each thread loads a single V value. + float v = 0.f; + if( vo < params.M ) { + v = *reinterpret_cast(&smem_v[smem_curr*smem_buffer_elts + vo]); + } + + // Update the K*V^T product. 
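+    // Conceptually this is one step of the causal linear-attention recurrence
+    //     S_t = S_{t-1} + k_t * v_t^T    (an E x M running state per head)
+    //     o_t = S_t^T * q_t              (the length-M output at timestep t)
+    // with "kv" holding this thread's float4 slices of the state column S_t[:, vo].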
+ #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + kv[ii].x += k[ii].x * v; + kv[ii].y += k[ii].y * v; + kv[ii].z += k[ii].z * v; + kv[ii].w += k[ii].w * v; + } + + // Load the Q values from shared memory. + float4 q[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + int qi = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4; + q[ii] = *reinterpret_cast(&smem_q[smem_curr*smem_buffer_elts + qi]); + } + + // Compute the partial output value for that thread. + float sum = 0.f; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + sum += q[ii].x * kv[ii].x; + sum += q[ii].y * kv[ii].y; + sum += q[ii].z * kv[ii].z; + sum += q[ii].w * kv[ii].w; + } + + // Finalize the computation of the sum (if we have more than 1 thread per head). + if( THREADS_PER_HEAD > 1 ) { + + // Finalize the sum for each head. + #pragma unroll + for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Store to shared memory. + if( vo < M && vi == 0 ) { + smem_o[smem_curr*smem_buffer_elts + vo] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Active threads read the data to store. + if( active_v ) { + sum = smem_o[smem_curr*smem_buffer_elts + tidx]; + } + + } // THREADS_PER_HEAD > 1. + + // Store the output. All the threads are active. + if( active_v ) { + *out_ptr = sum; + } + + // Move to next location. + if( GO_BACKWARD ) { + out_ptr -= params.o_stride_L; + } else { + out_ptr += params.o_stride_L; + } + + // Move the shared memory buffer. + smem_curr = (smem_curr + 1) % 2; + + // Store to shared memory for Q and K. + if( !is_last && tidx < E ) { + smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q; + smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k; + } + + // Store to shared memory for V. + if( !is_last && tidx < M ) { + smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD > +int lmha_(const Lmha_params ¶ms) { + // The M dimension rounded up to 4. + int M = round_up(params.M, 4); + + // The number of threads in the block. + int block = round_up(max(E, M*THREADS_PER_HEAD), 32); + if( block > 512 || params.B > 65535 ) { + return 1; + } + + // Prepare the kernel. 
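+  // One thread block per (head, sequence) pair, i.e. grid = (H, B); the block is
+  // sized to cover both the E query/key lanes and the M*THREADS_PER_HEAD output
+  // lanes, rounded up to a whole warp.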
+  dim3 grid(params.H, params.B);
+  size_t smem = smem_buffer_elts_<E>(params)*2*sizeof(float);
+  lmha_kernel<E, THREADS_PER_HEAD, GO_BACKWARD><<<grid, block, smem>>>(params);
+  return 0;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< bool GO_BACKWARD >
+int lmha(const Lmha_params<float> &params) {
+  int blocks = params.B * params.H;
+  int res = 1;
+  if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
+    if( params.E <= 32 ) {
+      res = lmha_low_occupancy_< 32, GO_BACKWARD>(params, blocks);
+    } else if( params.E <= 64 ) {
+      res = lmha_low_occupancy_< 64, GO_BACKWARD>(params, blocks);
+    } else if( params.E <= 128 ) {
+      res = lmha_low_occupancy_<128, GO_BACKWARD>(params, blocks);
+    } else if( params.E <= 256 ) {
+      res = lmha_low_occupancy_<256, GO_BACKWARD>(params, blocks);
+    }
+  } else {
+    if( params.E <= 32 ) {
+      res = lmha_< 32, 1, GO_BACKWARD>(params);
+    } else if( params.E <= 48 ) {
+      res = lmha_< 48, 1, GO_BACKWARD>(params);
+    } else if( params.E <= 64 ) {
+      res = lmha_< 64, 1, GO_BACKWARD>(params);
+    } else if( params.E <= 128 ) {
+      res = lmha_<128, 2, GO_BACKWARD>(params);
+    } else if( params.E <= 256 ) {
+      res = lmha_<256, 4, GO_BACKWARD>(params);
+    }
+  }
+  return res;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< typename T >
+inline void set_params(Lmha_params<T> &params,
+                       const torch::Tensor q,
+                       const torch::Tensor k,
+                       const torch::Tensor v,
+                       torch::Tensor o) {
+
+  // Define the pointers.
+  params.out = o.data_ptr<T>();
+  params.q = q.data_ptr<T>();
+  params.k = k.data_ptr<T>();
+  params.v = v.data_ptr<T>();
+
+  // Define the strides.
+  params.q_stride_B = (int) q.stride(0);
+  params.q_stride_H = (int) q.stride(1);
+  params.q_stride_L = (int) q.stride(2);
+  params.k_stride_B = (int) k.stride(0);
+  params.k_stride_H = (int) k.stride(1);
+  params.k_stride_L = (int) k.stride(2);
+  params.v_stride_B = (int) v.stride(0);
+  params.v_stride_H = (int) v.stride(1);
+  params.v_stride_L = (int) v.stride(2);
+  params.o_stride_B = (int) o.stride(0);
+  params.o_stride_H = (int) o.stride(1);
+  params.o_stride_L = (int) o.stride(2);
+
+  // Extract the dimensions.
+  int N = q.size(0);
+  int H = q.size(1);
+  int L = q.size(2);
+  int E = q.size(3);
+  int M = v.size(3);
+
+  params.B = N;
+  params.L = L;
+  params.H = H;
+  params.E = E;
+  params.M = M;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+int lmha_fwd(const torch::Tensor queries,
+             const torch::Tensor keys,
+             const torch::Tensor values,
+             torch::Tensor product) {
+
+  // Make sure that we are using the correct GPU device
+  torch::DeviceGuard _guard(queries.device());
+
+  // Make sure the inner-most dimension of the tensors is packed.
+  assert(queries.stride(3) == 1);
+  assert(keys   .stride(3) == 1);
+  assert(values .stride(3) == 1);
+  assert(product.stride(3) == 1);
+
+  // Extract the dimensions.
+  int N = queries.size(0);
+  int H = queries.size(1);
+  int L = queries.size(2);
+  int E = queries.size(3);
+  int M = values.size (3);
+
+  // The structure of params.
+  Lmha_params<float> params;
+  set_params(params, queries, keys, values, product);
+
+  // Launch the kernel.
+  return lmha<false>(params);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< typename T >
+struct Lmha_bwd_params {
+
+  // The output buffer for K. Dimensions [B, H, L, D].
+  T *out_k;
+  // The output buffer for V. Dimensions [B, H, L, D].
+  T *out_v;
+
+  // The input Qs. Dimensions [B, H, L, D].
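+  // (D is the per-head feature size; the dispatch below instantiates it as the
+  //  smallest supported size covering both E and M.)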
+ const T *q; + // The input Ks. Dimensions [B, H, L, D]. + const T *k; + // The input Vs. Dimensions [B, H, L, D]. + const T *v; + // The input Gs. Dimensions [B, H, L, D]. + const T *g; + + // The dimensions. + int B, L, H, M, E; + + // The strides for the input tensors. + int q_stride_B, q_stride_L, q_stride_H; + int k_stride_B, k_stride_L, k_stride_H; + int v_stride_B, v_stride_L, v_stride_H; + int g_stride_B, g_stride_L, g_stride_H; + + // The strides for the outputs. + int out_k_stride_B, out_k_stride_L, out_k_stride_H; + int out_v_stride_B, out_v_stride_L, out_v_stride_H; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int D, int THREADS_PER_HEAD > +__global__ __launch_bounds__(D*THREADS_PER_HEAD*2) +void lmha_bwd_kernel(Lmha_bwd_params params) { + + // Make sure D is a multiple of 4. + static_assert(D % 4 == 0, ""); + + // The shared memory buffers. + __shared__ struct Smem { float qg[2*D], kv[2*D], out_kv[2*D]; } smem_[2]; + + // The index of the shared memory buffer (for double-buffering). + int smem_curr = 0; + + // The sequence processed by that block. + const int bi = blockIdx.y; + // The head processed by that block. + const int hi = blockIdx.x; + + // The linear index of the thread. + const int tidx = threadIdx.x; + + // Split the threads into two slices. + int so = tidx / (D*THREADS_PER_HEAD); + int si = tidx % (D*THREADS_PER_HEAD); + + // The strides for B/L/H for the Q/G tensors. + int qg_stride_B, qg_stride_L, qg_stride_H; + if( so == 0 ) { + qg_stride_B = params.q_stride_B; + qg_stride_L = params.q_stride_L; + qg_stride_H = params.q_stride_H; + } else { + qg_stride_B = params.g_stride_B; + qg_stride_L = params.g_stride_L; + qg_stride_H = params.g_stride_H; + } + + // The strides for B/L/H for the K/V tensors. + int kv_stride_B, kv_stride_L, kv_stride_H; + if( so == 0 ) { + kv_stride_B = params.k_stride_B; + kv_stride_L = params.k_stride_L; + kv_stride_H = params.k_stride_H; + } else { + kv_stride_B = params.v_stride_B; + kv_stride_L = params.v_stride_L; + kv_stride_H = params.v_stride_H; + } + + // The hidden size. + int hidden_size_per_head = 0; + if( so == 0 ) { + hidden_size_per_head = params.E; + } else { + hidden_size_per_head = params.M; + } + + // Where to start reading from. + int offset_qg = bi*qg_stride_B + hi*qg_stride_H + si; + int offset_kv = bi*kv_stride_B + hi*kv_stride_H + si; + + // We walk backward, account for the extra offset. + offset_qg += (params.L-1)*qg_stride_L; + offset_kv += (params.L-1)*kv_stride_L; + + // Determine the base pointers for Q, K, V and G. + const float *ptr_qg = &(so == 0 ? params.q : params.g)[offset_qg]; + const float *ptr_kv = &(so == 0 ? params.k : params.v)[offset_kv]; + + // Is it an active thread? + const int active = si < hidden_size_per_head; + + // Trigger the memory loads for Q, K, V and G. + float ldg_qg = 0.f, ldg_kv = 0.f; + if( active ) { + ldg_qg = *ptr_qg; + ldg_kv = *ptr_kv; + } + + // Move the load pointers (backward). + ptr_qg -= qg_stride_L; + ptr_kv -= kv_stride_L; + + // The number of FLOAT4s per head. + constexpr int FLOAT4s_PER_HEAD = D / 4; + // The number of FLOAT4s per thread. + constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD; + + // The storage for the G*Q^T or Q^T*G values. + float4 gq[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + gq[ii] = make_float4(0.f, 0.f, 0.f, 0.f); + } + + // The strides for B/L/H for the K/V tensors. 
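+  // (more precisely, the strides of the gradient outputs dK / dV; each thread
+  //  slice picks the set that matches the tensor it writes.)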
+ int out_kv_stride_B, out_kv_stride_L, out_kv_stride_H; + if( so == 0 ) { + out_kv_stride_B = params.out_k_stride_B; + out_kv_stride_L = params.out_k_stride_L; + out_kv_stride_H = params.out_k_stride_H; + } else { + out_kv_stride_B = params.out_v_stride_B; + out_kv_stride_L = params.out_v_stride_L; + out_kv_stride_H = params.out_v_stride_H; + } + + // Where to start reading from. + int offset_out_kv = bi*out_kv_stride_B + hi*out_kv_stride_H + si; + + // We walk backward, account for the extra offset. + offset_out_kv += (params.L-1)*out_kv_stride_L; + + // The output pointer. + float *ptr_out_kv = &(so == 0 ? params.out_k : params.out_v)[offset_out_kv]; + + // Store to shared memory. + if( si < D ) { + smem_[smem_curr].qg[so*D + si] = ldg_qg; + smem_[smem_curr].kv[so*D + si] = ldg_kv; + } + + // The position of the thread in the output dimension. + int oo = si / THREADS_PER_HEAD % D; + int oi = si % THREADS_PER_HEAD * 4; + + // Iterate over the timesteps. + for( int ti = 0; ti < params.L; ++ti ) { + + // Is it the last iteration? + int is_last = ti == params.L - 1; + + // Trigger the next loads. + if( !is_last && active ) { + ldg_qg = *ptr_qg; + ldg_kv = *ptr_kv; + } + + // Move the load pointers. + ptr_qg -= qg_stride_L; + ptr_kv -= kv_stride_L; + + // Make sure the data is in shared memory. + __syncthreads(); + + // Each thread loads 4 values from G or Q. + float4 g[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + float *smem_ptr = &smem_[smem_curr].qg[(so^1)*D + oi]; + g[ii] = *reinterpret_cast(&smem_ptr[ii*THREADS_PER_HEAD*4]); + } + + // Each thread loads a single from Q or G value. + float q = smem_[smem_curr].qg[so*D + oo]; + + // Update the G*Q^T or Q*G^T product. + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + gq[ii].x += g[ii].x * q; + gq[ii].y += g[ii].y * q; + gq[ii].z += g[ii].z * q; + gq[ii].w += g[ii].w * q; + } + + // Load the V or K values from shared memory. + float4 v[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + float *smem_ptr = &smem_[smem_curr].kv[(so^1)*D + oi]; + v[ii] = *reinterpret_cast(&smem_ptr[ii*THREADS_PER_HEAD*4]); + } + + // Compute the partial output value for that thread. + float sum = 0.f; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + sum += v[ii].x * gq[ii].x; + sum += v[ii].y * gq[ii].y; + sum += v[ii].z * gq[ii].z; + sum += v[ii].w * gq[ii].w; + } + + // Finalize the computation of the sum (if we have more than 1 thread per head). + if( THREADS_PER_HEAD > 1 ) { + + // Finalize the sum for each head. + #pragma unroll + for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Store to shared memory. + if( oi == 0 ) { + smem_[smem_curr].out_kv[so*D + oo] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Active threads read the data to store. + if( si < hidden_size_per_head ) { + sum = smem_[smem_curr].out_kv[so*D + si]; + } + + } // THREADS_PER_HEAD > 1. + + // Store the output. All the threads are active. + if( si < hidden_size_per_head ) { + *ptr_out_kv = sum; + } + + // Move to next location. + ptr_out_kv -= out_kv_stride_L; + + // Move the shared memory buffer. + smem_curr = (smem_curr + 1) % 2; + + // Store to shared memory for Q and K. 
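+    // (each half of the block refills its own pair for the next timestep:
+    //  slice so == 0 stages Q and K, slice so == 1 stages G and V.)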
+    if( !is_last && si < D ) {
+      smem_[smem_curr].qg[so*D + si] = ldg_qg;
+      smem_[smem_curr].kv[so*D + si] = ldg_kv;
+    }
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< int D, int THREADS_PER_HEAD >
+int lmha_bwd_(const Lmha_bwd_params<float> &params) {
+  int block = D*THREADS_PER_HEAD*2;
+  if( block >= 1024 || params.B > 65535 ) {
+    return 1;
+  }
+  dim3 grid(params.H, params.B);
+  lmha_bwd_kernel<D, THREADS_PER_HEAD><<<grid, block>>>(params);
+  return 0;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+int lmha_bwd(const Lmha_bwd_params<float> &params) {
+  int blocks = params.B * params.H;
+  if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
+    return 1;
+  }
+
+  int hidden_size_per_head = max(params.E, params.M);
+  int res = 1;
+  if( hidden_size_per_head <= 32 ) {
+    res = lmha_bwd_< 32, 1>(params);
+  } else if( hidden_size_per_head <= 64 ) {
+    res = lmha_bwd_< 64, 1>(params);
+  } else if( hidden_size_per_head <= 128 ) {
+    res = lmha_bwd_<128, 2>(params);
+  } else if( hidden_size_per_head <= 256 ) {
+    res = lmha_bwd_<256, 4>(params);
+  }
+  return res;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+int lmha_bwd(const torch::Tensor queries,
+             const torch::Tensor keys,
+             const torch::Tensor values,
+             const torch::Tensor grad_out,
+             torch::Tensor grad_queries,
+             torch::Tensor grad_keys,
+             torch::Tensor grad_values) {
+
+  // Make sure that we are using the correct GPU device
+  torch::DeviceGuard _guard(queries.device());
+
+  // Make sure the inner-most dimension of the tensors is packed.
+  assert(queries     .stride(3) == 1);
+  assert(keys        .stride(3) == 1);
+  assert(values      .stride(3) == 1);
+  assert(grad_out    .stride(3) == 1);
+  assert(grad_queries.stride(3) == 1);
+  assert(grad_keys   .stride(3) == 1);
+  assert(grad_values .stride(3) == 1);
+
+  // Extract the dimensions.
+  int N = queries.size(0);
+  int H = queries.size(1);
+  int L = queries.size(2);
+  int E = queries.size(3);
+  int M = values.size (3);
+
+  // Gradient on Q.
+
+  // The structure of params.
+  Lmha_params<float> params;
+  set_params(params, grad_out, values, keys, grad_queries);
+
+  // Launch the kernel.
+  int res = lmha<false>(params);
+  if( res ) {
+    return res;
+  }
+
+  // Gradient on K and V together.
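+  // Walking the sequence in reverse, the fused kernel accumulates the running
+  // outer product T_t = sum_{s >= t} q_s * g_s^T once and reads out both
+  // dK_t = T_t * v_t and dV_t = T_t^T * k_t from it.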
+ + Lmha_bwd_params bwd_params; + bwd_params.out_k = grad_keys.data_ptr(); + bwd_params.out_v = grad_values.data_ptr(); + bwd_params.q = queries.data_ptr(); + bwd_params.k = keys.data_ptr(); + bwd_params.v = values.data_ptr(); + bwd_params.g = grad_out.data_ptr(); + + bwd_params.B = N; + bwd_params.L = L; + bwd_params.H = H; + bwd_params.E = E; + bwd_params.M = M; + + bwd_params.q_stride_B = queries.stride(0); + bwd_params.q_stride_H = queries.stride(1); + bwd_params.q_stride_L = queries.stride(2); + bwd_params.k_stride_B = keys.stride(0); + bwd_params.k_stride_H = keys.stride(1); + bwd_params.k_stride_L = keys.stride(2); + bwd_params.v_stride_B = values.stride(0); + bwd_params.v_stride_H = values.stride(1); + bwd_params.v_stride_L = values.stride(2); + bwd_params.g_stride_B = grad_out.stride(0); + bwd_params.g_stride_H = grad_out.stride(1); + bwd_params.g_stride_L = grad_out.stride(2); + + bwd_params.out_k_stride_B = grad_keys.stride(0); + bwd_params.out_k_stride_H = grad_keys.stride(1); + bwd_params.out_k_stride_L = grad_keys.stride(2); + bwd_params.out_v_stride_B = grad_values.stride(0); + bwd_params.out_v_stride_H = grad_values.stride(1); + bwd_params.out_v_stride_L = grad_values.stride(2); + + // Try to run the fused kernel. + int fallback = lmha_bwd(bwd_params); + + // If it failed, fallback on separate kernels for K and V. + if( fallback ) { + + // Gradient on K. + + // Launch the kernel. + set_params(params, values, grad_out, queries, grad_keys); + res = lmha(params); + if( res ) { + return res; + } + + // Gradient on V. + + // Launch the kernel. + set_params(params, keys, queries, grad_out, grad_values); + return lmha(params); + } + + // It worked... + return 0; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace nvidia +#endif // #ifdef ENABLE_NVIDIA_OPTIMIZATIONS + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +typedef torch::PackedTensorAccessor32 float_accessor; + +#define E_BLOCK_SIZE 8 + +__global__ void causal_dot_product_kernel( + const float_accessor queries, + const float_accessor keys, + const float_accessor values, + float_accessor result, + const int N, + const int H, + const int L, + const int E, + const int M +) { + int n = blockIdx.y; + int h = blockIdx.z; + + int e_start = blockIdx.x * E_BLOCK_SIZE; + int m = threadIdx.x % M; + + extern __shared__ float shared_mem[]; + float* shared_kv = shared_mem; + + for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) { + shared_kv[m + e_local * M] = 0; + } + + for (int t=0; t>>( + queries.packed_accessor32(), + keys.packed_accessor32(), + values.packed_accessor32(), + product.packed_accessor32(), + N, H, L, E, M + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void causal_dot_product(const torch::Tensor queries, + const torch::Tensor keys, + const torch::Tensor values, + torch::Tensor product) { +#ifdef ENABLE_NVIDIA_OPTIMIZATIONS + int fallback = nvidia::lmha_fwd(queries, keys, values, product); +#else + int fallback = 1; +#endif + if( fallback ) { + causal_dot_product_(queries, keys, values, product); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define M_BLOCK_SIZE 4 + +// we need shared memory to store +// kv +// Backward direction +// kv_backwards +// Shared memory usage +__global__ void causal_dot_backward_query_key_kernel( + 
const float_accessor queries, + const float_accessor keys, + const float_accessor values, + const float_accessor grad_out, + float_accessor grad_queries, + float_accessor grad_keys, + int N, + int H, + int L, + int E, + int M +) { + int n = blockIdx.y; + int h = blockIdx.z; + + int m_start = blockIdx.x * M_BLOCK_SIZE; + int e = threadIdx.x % E; + + extern __shared__ float shared_mem[]; + const int shared_kv_size = M_BLOCK_SIZE * E; + float* shared_kv = shared_mem; + float* shared_kv_bw = shared_mem + shared_kv_size; + + for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) { + shared_kv[m_local * E + e] = 0; + shared_kv_bw[m_local * E + e] = 0; + } + + for (int l=0; l>>( + queries.packed_accessor32(), + keys.packed_accessor32(), + values.packed_accessor32(), + grad_out.packed_accessor32(), + grad_queries.packed_accessor32(), + grad_keys.packed_accessor32(), + N, H, L, E, M + ); + + const int blocks_per_sequence_value = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE; + + dim3 blockDimv(M, 1, 1); + dim3 gridDimv(blocks_per_sequence_value, N, H); + const int shared_mem_v_backward = E_BLOCK_SIZE * M * sizeof(float); + causal_dot_backward_value_kernel<<>>( + queries.packed_accessor32(), + keys.packed_accessor32(), + values.packed_accessor32(), + grad_out.packed_accessor32(), + grad_keys.packed_accessor32(), + grad_values.packed_accessor32(), + N, H, L, E, M + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void causal_dot_backward(const torch::Tensor queries, + const torch::Tensor keys, + const torch::Tensor values, + const torch::Tensor grad_out, + torch::Tensor grad_queries, + torch::Tensor grad_keys, + torch::Tensor grad_values) { +#ifdef ENABLE_NVIDIA_OPTIMIZATIONS + int fallback = nvidia::lmha_bwd(queries, + keys, + values, + grad_out, + grad_queries, + grad_keys, + grad_values); +#else + int fallback = 1; +#endif + if( fallback ) { + // Make sure that the gradient tensors are 0. This is needed because the + // bwd pass might have partially executed and filled in some values in + // grad_queries or grad_keys. + // + // This adds a small overhead every time we have to fall back to the old + // kernel for the backward pass. + grad_queries.zero_(); + grad_keys.zero_(); + causal_dot_backward_(queries, keys, values, grad_out, grad_queries, grad_keys, grad_values); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "causal_dot_product", + &causal_dot_product, + "Compute the weighted sum of values but attending only to previous " + "values." + ); + m.def( + "causal_dot_backward", + &causal_dot_backward, + "Compute the gradients for the causal dot product." 
+ ); +} diff --git a/csrc/setup.py b/csrc/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..fb2cd6be93eff6a9b3344cbf1b284fa3e5752591 --- /dev/null +++ b/csrc/setup.py @@ -0,0 +1,53 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +import torch +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME +import subprocess + +def get_last_arch_torch(): + arch = torch.cuda.get_arch_list()[-1] + print(f"Found arch: {arch} from existing torch installation") + return arch + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + return raw_output, bare_metal_major, bare_metal_minor + +def append_nvcc_threads(nvcc_extra_args): + _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) + if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args + +arch = get_last_arch_torch() +sm_num = arch[-2:] +cc_flag = ['--generate-code=arch=compute_90,code=compute_90'] # for H100 +# cc_flag = ['--generate-code=arch=compute_80,code=compute_80'] # for A100 +# cc_flag = ['--generate-code=arch=compute_89,code=compute_89'] # for RTX 6000, 4090 +# cc_flag = ['--generate-code=arch=compute_86,code=compute_86'] # for A6000, 3090 +# cc_flag = ['--generate-code=arch=compute_75,code=compute_75'] + +setup( + name='causal_attention_cuda_cpp', + ext_modules=[ + CUDAExtension('causal_attention_cuda', [ + # 'causal_attention.cpp', + 'causal_attention_cuda.cu', + ], + extra_compile_args={'cxx': ['-O3'], + 'nvcc': append_nvcc_threads(['-O3', '-lineinfo', '--use_fast_math', '-std=c++17'] + cc_flag) + }) + ], + cmdclass={ + 'build_ext': BuildExtension + }) diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/dataloaders/__init__.py b/src/dataloaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..84745f94c7a4faff73a8ebff41efc47d8f13a88c --- /dev/null +++ b/src/dataloaders/__init__.py @@ -0,0 +1,22 @@ +""" +Load dataloaders +""" +import importlib + + +def load_data(dataset_config: dict, dataloader_config: dict): + """Return dataloaders from dataset_config""" + try: + dataset_module = importlib.import_module(f'dataloaders.{dataset_config["name"]}') + except Exception: + try: + dataset_module = importlib.import_module(f'src.dataloaders.{dataset_config["name"]}') + except Exception as e2: + print(e2) + try: # e.g., tasks like GLUE where name is benchmark and path specifies the dataset / task + dataset_module = importlib.import_module(f'dataloaders.{dataset_config["path"]}') + except Exception as e3: + print(f'Error from {dataset_config}') + raise e3 + _load_data = getattr(dataset_module, 'load_data') + return _load_data(**dataset_config, **dataloader_config) \ No newline at end of file diff --git a/src/dataloaders/alpaca_clean.py b/src/dataloaders/alpaca_clean.py new file mode 100644 index 0000000000000000000000000000000000000000..d5dce8e7c8d74946ce13732af187a0a98c5e2efe --- /dev/null +++ b/src/dataloaders/alpaca_clean.py @@ -0,0 +1,149 @@ +""" +Alpaca training 
dataloaders + +We adopt the original prompt template; goes something like: +``` +Below is an instruction that describes a task. +Write a response that appropriately completes the request. +### Instruction: +{instruction} + +### Response: +{response} +``` +See `PROMPT_DICT` for more. +""" +from functools import partial +from os.path import join + +from datasets import load_metric, load_dataset + +from .utils import ( + get_lm_loader, get_seq2seq_loader, + convert_to_hf_dataset, + get_tokenizer_from_config, + download_scrolls_metric as download_metric +) +from .utils.packing import ConcatDataset + + +PROMPT_DICT = { + "prompt_input": ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" + ), + "prompt_no_input": ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Response:\n" + ), +} + + +def load_data(name: str, dataset_config: dict, pretrained_model_config: dict, + preprocess_config: dict, **loader_kwargs: any): + """ + Shared function to load dataset from experiment config + -> e.g., see configs/experiments/distill_alpaca_clean_lr1e-2.yaml + """ + # Misc. setup + cache_dir = dataset_config['cache_dir'] + input_len = dataset_config['chunk_size'] + concat_data = dataset_config['concat_data'] + + tokenizer_name = pretrained_model_config['pretrained_model_name_or_path'] + tokenizer_name = tokenizer_name.split('/')[-1] + # save_path = join(cache_dir, f'{name}_{tokenizer_name}') + + # Setup tokenizer + tokenizer = get_tokenizer_from_config(pretrained_model_config) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}') + + tokenizer.padding_side = 'left' # for decoder-only generation + # Get initial data + ignore_kwargs = ['concat_data', 'chunk_size', 'pose_kwargs'] + dataset = load_dataset( + **{k: v for k, v in dataset_config.items() if k not in ignore_kwargs} + ) + if dataset_config['name'] == 'samsum': # hack + dataset = dataset.rename_column('dialogue', 'input') + dataset = dataset.rename_column('summary', 'output') + _instruction = 'Summarize this dialogue.' 
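+        # The loop below adds a constant 'instruction' column so that SAMSum samples
+        # match the Alpaca-style (instruction, input, output) schema expected by
+        # PROMPT_DICT and template_and_tokenize, e.g. (illustrative sample):
+        #   {'instruction': 'Summarize this dialogue.',
+        #    'input': '<dialogue text>', 'output': '<summary text>'}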
+ for split in dataset.keys(): + dataset[split] = dataset[split].add_column( + 'instruction', [_instruction] * len(dataset[split]) + ) + train_set, val_set, test_set = dataset['train'], dataset['validation'], dataset['test'] + dataset = train_set # hack to work with below code + else: + dataset = dataset['train'] + train_set = convert_to_hf_dataset([dataset[ix] for ix in range(200, len(dataset))], cache_dir) + val_set = convert_to_hf_dataset([dataset[ix] for ix in range(200)], cache_dir) + test_set = convert_to_hf_dataset([dataset[ix] for ix in range(200)], cache_dir) + + # Convert to dicts of {input_ids, attention_mask, labels} + train_set = train_set.map( + partial(template_and_tokenize, tokenizer=tokenizer, include_label=True), + remove_columns=list(dataset.features),) # load_from_cache_file=False) + val_set = val_set.map( + partial(template_and_tokenize, tokenizer=tokenizer, include_label=True), + remove_columns=list(dataset.features),) # load_from_cache_file=False) + test_set = test_set.map( + partial(template_and_tokenize, tokenizer=tokenizer, include_label=False), + remove_columns=list(dataset.features),) # load_from_cache_file=False) + + # Chunk together train and val sets + if concat_data: + train_set = ConcatDataset(train_set, chunk_size=input_len) + val_set = ConcatDataset(val_set, chunk_size=input_len) + + # Get dataloaders + dataloaders = { + 'train': get_lm_loader(train_set, tokenizer, 'train', input_len, **loader_kwargs), + 'validation': get_lm_loader(val_set, tokenizer, 'validation', input_len, **loader_kwargs), + 'test': get_seq2seq_loader(test_set, tokenizer, 'test', **loader_kwargs), + } + # Evaluation metric + try: + metric = load_metric(download_metric(), 'gov_report') # hack but we want rouge + except Exception as e: + print(f'Error loading metric: {e}') + metric = None + + # Finishing touches + for k, v in dataloaders.items(): # Make tokenizer accessible + dataloaders[k].dataset.tokenizer = tokenizer + dataloaders[k].dataset.metric = metric + return dataloaders + + +def template_and_tokenize(sample, tokenizer, include_label: bool = True): + """ + Format dataset context and answers into single-sequence prompts + """ + if sample.get('input', '') == '': + prompt = PROMPT_DICT["prompt_no_input"].format_map(sample) + else: + prompt = PROMPT_DICT["prompt_input"].format_map(sample) + + prompt = tokenizer.encode(prompt, add_special_tokens=True) + if include_label: + answer = tokenizer.encode(f'{sample["output"]}{tokenizer.eos_token}', + add_special_tokens=False) + target = None + else: + answer = [] + target = tokenizer.encode(f'{sample["output"]}{tokenizer.eos_token}', + add_special_tokens=False) + input_ids = prompt + answer + attn_mask = [1] * len(input_ids) + + sample = { + "input_ids": input_ids, + "attention_mask" : attn_mask, + "labels": [-100] * len(prompt) + answer if include_label else target, + } + return sample diff --git a/src/dataloaders/alpaca_clean_instruct.py b/src/dataloaders/alpaca_clean_instruct.py new file mode 100644 index 0000000000000000000000000000000000000000..58fbec854f8a4175f0b8ae5f6ec83b80892b28c0 --- /dev/null +++ b/src/dataloaders/alpaca_clean_instruct.py @@ -0,0 +1,148 @@ +""" +Alpaca Clean dataset with Llama3-Instruct prompt formatting +""" + +from functools import partial +from os.path import join + +import numpy as np +from tqdm import tqdm + +import torch +from torch.utils.data import Dataset, DataLoader + +from datasets import load_metric, load_dataset +from transformers import AutoTokenizer +from transformers import 
DataCollatorForSeq2Seq, DefaultDataCollator, DataCollatorWithPadding + +from .utils import ( + get_lm_loader, get_seq2seq_loader, + convert_to_hf_dataset, + get_tokenizer_from_config, + download_scrolls_metric as download_metric +) +from .utils.packing import ConcatDataset + + +SYSTEM_PROMPT = "You are a helpful AI assistant who always responds to appropriately complete a user's request." + + +def encode_response(response: str, tokenizer) -> list[int]: + tokens = tokenizer.encode(response.strip(), add_special_tokens=False) + # For Llama 3 Instruct: tokens.append(tokenizer.get_added_vocab()["<|eot_id|>"]) + tokens.append(tokenizer.eos_token_id) + try: # Llama 3 Instruct + tokens.append(tokenizer.get_added_vocab()["<|end_of_text|>"]) + except KeyError: + pass + return tokens + + +def load_data(name: str, dataset_config: dict, pretrained_model_config: dict, + preprocess_config: dict, **loader_kwargs: any): + + # Misc. setup + cache_dir = dataset_config['cache_dir'] + input_len = dataset_config['chunk_size'] + concat_data = dataset_config['concat_data'] + load_from_cache_file = False # False if want to retokenize dataset + + # Hard-code system prompt handling + if 'istral' in pretrained_model_config['pretrained_model_name_or_path']: + system_prompt = '' + else: + system_prompt = SYSTEM_PROMPT + + tokenizer_name = pretrained_model_config['pretrained_model_name_or_path'] + tokenizer_name = tokenizer_name.split('/')[-1] + save_path = join(cache_dir, f'{name}_{tokenizer_name}') + + # Setup tokenizer + tokenizer = get_tokenizer_from_config(pretrained_model_config) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}') + + tokenizer.padding_side = 'left' # for decoder-only generation + + # Get initial data + ignore_kwargs = ['concat_data', 'chunk_size', 'pose_kwargs', 'system_prompt', 'name'] + train_set = load_dataset( + **{k: v for k, v in dataset_config.items() if k not in ignore_kwargs}, + split='train[100:-100]', + ) + val_set = load_dataset( # we just use this dataset as a validation set + **{k: v for k, v in dataset_config.items() if k not in ignore_kwargs}, + split='train[:100]+train[-100:]', + ) + test_set = load_dataset( + **{k: v for k, v in dataset_config.items() if k not in ignore_kwargs}, + split='train[:100]+train[-100:]', + ) + + # Convert to dicts of {input_ids, attention_mask, labels} + train_set = train_set.map(partial(template_and_tokenize, tokenizer=tokenizer, + include_label=True, system_prompt=system_prompt), + remove_columns=list(train_set.features), + load_from_cache_file=load_from_cache_file) + val_set = val_set.map(partial(template_and_tokenize, tokenizer=tokenizer, + include_label=True, system_prompt=system_prompt), + remove_columns=list(val_set.features), + load_from_cache_file=load_from_cache_file) + test_set = test_set.map(partial(template_and_tokenize, tokenizer=tokenizer, + include_label=False, system_prompt=system_prompt), + remove_columns=list(test_set.features), + load_from_cache_file=load_from_cache_file) + + # Chunk together train and val sets + if concat_data: + train_set = ConcatDataset(train_set, chunk_size=input_len) + val_set = ConcatDataset(val_set, chunk_size=input_len) + + # Get dataloaders + dataloaders = { + 'train': get_lm_loader(train_set, tokenizer, 'train', input_len, **loader_kwargs), + 'validation': get_lm_loader(val_set, tokenizer, 'validation', input_len, **loader_kwargs), + 'test': get_seq2seq_loader(test_set, tokenizer, 'test', **loader_kwargs), + } + # 
Evaluation metric + metric = load_metric(download_metric(), 'gov_report') # hack but we want rouge + + # Finishing touches + for k, v in dataloaders.items(): # Make tokenizer accessible + dataloaders[k].dataset.tokenizer = tokenizer + dataloaders[k].dataset.metric = metric + return dataloaders + + +def template_and_tokenize(sample, tokenizer, include_label: bool = True, + system_prompt: str = None): + if system_prompt is None: + system_prompt = SYSTEM_PROMPT + + prompt = sample['instruction'] + if sample['input'] != '': + prompt += f"\n\n{sample['input']}" + + messages = [ + {"role": "system", "content": system_prompt}, + ] if system_prompt != '' else [] + messages.append({"role": "user", "content": prompt}) + prompt_ids = tokenizer.apply_chat_template( + messages, tokenize=True, add_generation_prompt=True, + ) + if include_label: + answer = encode_response(sample['output'], tokenizer) + else: + answer = [] + target = encode_response(sample['output'], tokenizer) + + input_ids = prompt_ids + answer + attn_mask = [1] * len(input_ids) + sample = { + "input_ids": input_ids, + "attention_mask" : attn_mask, + "labels": [-100] * len(prompt_ids) + answer if include_label else target, + } + return sample + diff --git a/src/dataloaders/utils/__init__.py b/src/dataloaders/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7af0c87aaf472d536cf41bfadaefa31edec966 --- /dev/null +++ b/src/dataloaders/utils/__init__.py @@ -0,0 +1,4 @@ +""" +Helper functions dataset setup and loading +""" +from .setup import * diff --git a/src/dataloaders/utils/llama3.py b/src/dataloaders/utils/llama3.py new file mode 100644 index 0000000000000000000000000000000000000000..d13a4fc8e055ad7b6a686a1219c29607d42e1c03 --- /dev/null +++ b/src/dataloaders/utils/llama3.py @@ -0,0 +1,62 @@ +""" +Data utils for Llama3 +""" + +def encode_header(message: str, tokenizer) -> list[int]: + tokens = [] + tokens.append(tokenizer.get_added_vocab()["<|start_header_id|>"]) + tokens.extend(tokenizer.encode(message["role"], add_special_tokens=False)) + tokens.append(tokenizer.get_added_vocab()["<|end_header_id|>"]) + tokens.extend(tokenizer.encode("\n\n", add_special_tokens=False)) + return tokens + + +def encode_message(message: str, tokenizer, include_header: bool = True) -> list[int]: + tokens = encode_header(message, tokenizer) if include_header else [] + tokens.extend( + tokenizer.encode(message["content"].strip(), add_special_tokens=False) + ) + tokens.append(tokenizer.get_added_vocab()["<|eot_id|>"]) + return tokens + + +def template_and_tokenize(sample, tokenizer, include_label: bool = True, + system_prompt: str = None): + if system_prompt is not None: + dialog = [{'role': 'system', 'content': system_prompt}] + else: + dialog = [] + + chat = [] + instruction = sample['instruction'] + if sample['input'] != '': + instruction += f"\n\n{sample['input']}" + dialog.extend([ + {'role': 'user', 'content': instruction}, + {'role': 'assistant', 'content': sample['output']}, + ]) + + prompt = [] + prompt.append(tokenizer.get_added_vocab()["<|begin_of_text|>"]) + for message in dialog[:-1]: + prompt.extend(encode_message(message, tokenizer)) + + if include_label: + answer = encode_message(dialog[-1], tokenizer) + answer.append(tokenizer.get_added_vocab()["<|end_of_text|>"]) + else: + answer = [] + target = encode_message(dialog[-1], tokenizer, include_header=False) + target.append(tokenizer.get_added_vocab()["<|end_of_text|>"]) + # Add the start of an assistant message for the model to complete. 
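+        # With encode_header/encode_message above, the assembled prompt is laid out
+        # roughly as (sketch, using the special tokens defined in this file):
+        #   <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>
+        #   <|start_header_id|>user<|end_header_id|>\n\n{instruction}<|eot_id|>
+        #   <|start_header_id|>assistant<|end_header_id|>\n\n
+        # i.e., the line below appends an empty assistant header so that generation
+        # begins right at the response.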
+ prompt.extend(encode_header({"role": "assistant", "content": ""}, tokenizer)) + + input_ids = prompt + answer + attn_mask = [1] * len(input_ids) + + sample = { + "input_ids": input_ids, + "attention_mask" : attn_mask, + "labels": [-100] * len(prompt) + answer if include_label else target, + } + return sample \ No newline at end of file diff --git a/src/dataloaders/utils/packing.py b/src/dataloaders/utils/packing.py new file mode 100644 index 0000000000000000000000000000000000000000..b51914e4e1b433848f7bacbcb617fdec363c895e --- /dev/null +++ b/src/dataloaders/utils/packing.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. +""" +Copied from https://github.com/meta-llama/llama-recipes/blob/9b3dabcaac78980eae40005bbc8b1a8276c82af3/src/llama_recipes/data/concatenator.py#L1 +""" +import random +from itertools import chain +from tqdm import tqdm + + +from torch.utils.data import Dataset + + +class Concatenator(object): + def __init__(self, chunk_size=2048): + self.chunk_size=chunk_size + self.residual = {"input_ids": [], "attention_mask": []} + + def __call__(self, batch): + concatenated_samples = { + k: v + list(chain(*batch[k])) for k, v in self.residual.items() + } + + total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]]) + + if total_length >= self.chunk_size: + chunk_num = total_length // self.chunk_size + result = { + k: [ + v[i : i + self.chunk_size] + for i in range(0, chunk_num * self.chunk_size, self.chunk_size) + ] + for k, v in concatenated_samples.items() + } + self.residual = { + k: v[(chunk_num * self.chunk_size) :] + for k, v in concatenated_samples.items() + } + else: + result = concatenated_samples + self.residual = {k: [] for k in concatenated_samples.keys()} + + result["labels"] = result["input_ids"].copy() + + return result + +class ConcatDataset(Dataset): + """ + Concatenates or packs samples of a dataset into chunks of size `chunk_size` + """ + def __init__(self, dataset, chunk_size: int = 1024, seed: int = 42,) -> None: + self.dataset = dataset + self.chunk_size = chunk_size + self.samples = [] + buffer = { + "input_ids": [], + "attention_mask": [], + "labels": [], + } + random.seed(seed) + for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True): + buffer = {k: v + sample[k] for k,v in buffer.items()} + + while len(next(iter(buffer.values()))) > self.chunk_size: + self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()}) + buffer = {k: v[self.chunk_size:] for k,v in buffer.items()} + # Slow hack, but filter out any samples without valid labels (all -100) + self.filtered_samples = [] + for s in self.samples: + if sum(s['labels']) != chunk_size * -100: + self.filtered_samples.append(s) + if len(self.filtered_samples) < len(self.samples): + print(f'OG dataset: {len(self.samples)} samples -> Filtered dataset: {len(self.filtered_samples)}') + print(f'-> Filtered out {len(self.samples) - len(self.filtered_samples)} samples') + + def __getitem__(self, idx): + return self.filtered_samples[idx] + + def __len__(self): + return len(self.filtered_samples) diff --git a/src/dataloaders/utils/setup.py b/src/dataloaders/utils/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..8be8912d7a5db67fe8b1a9366b1b78154af321c0 --- /dev/null +++ b/src/dataloaders/utils/setup.py @@ -0,0 +1,123 @@ +""" +Helper functions dataset setup and loading +""" +import os +from os.path 
import join +import shutil +import numpy as np + +from torch.utils.data import Dataset, DataLoader + +from datasets import Dataset as HFDataset +from huggingface_hub import hf_hub_download +from transformers import AutoTokenizer, LlamaTokenizer +from transformers import DataCollatorForSeq2Seq +# from transformers import DefaultDataCollator, DataCollatorWithPadding + + +def get_seq2seq_loader(dataset: Dataset, tokenizer: AutoTokenizer, + split: str, **loader_kwargs: any): + """ + Get dataloader for seq2seq tasks (evaluation) + """ + tokenizer.padding_side = 'right' + collate_fn = DataCollatorForSeq2Seq( + tokenizer, label_pad_token_id=-100, return_tensors='pt') + return DataLoader( + dataset, shuffle='train' in split, collate_fn=collate_fn, **loader_kwargs) + + +def get_lm_loader(dataset: Dataset, tokenizer: AutoTokenizer, + split: str, max_length: int = None, **loader_kwargs: any): + """ + Get dataloader for language modeling (training) + -> Currently this ends up being the same as get_seq2seq_loader + """ + # collate_fn = DefaultDataCollator(return_tensors='pt') + # collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, padding=True, + # max_length=max_length, return_tensors='pt') + collate_fn = DataCollatorForSeq2Seq( + tokenizer, label_pad_token_id=-100, return_tensors='pt') + return DataLoader( + dataset, shuffle='train' in split, collate_fn=collate_fn, **loader_kwargs) + + +def convert_to_hf_dataset(dataset, cache_dir: str): + """ + Convert iterable dataset to HuggingFace HFDataset object + """ + def gen(): + for _, sample in enumerate(dataset): + yield sample # dataset[idx] + return HFDataset.from_generator(gen, cache_dir=cache_dir) + + +def get_tokenizer_from_config(model_config): + """ + Get pretrained tokenizer based on (pretrained) model config + """ + # Get tokenizer + if 'llama' in model_config['pretrained_model_name_or_path']: + try: # if we store locally + model_path = join(model_config['cache_dir'], + model_config['pretrained_model_name_or_path']) + tokenizer = LlamaTokenizer.from_pretrained(model_path) + except Exception as e: + try: + tokenizer = AutoTokenizer.from_pretrained(**model_config) + print("-> Bad LlamaTokenizer.from_pretrained(model_path)", e) + print("-> But resolved with: AutoTokenizer.from_pretrained(**model_config)") + except Exception as e2: + print("-> Error with AutoTokenizer.from_pretrained(**model_config)", e2) + # tokenizer = LlamaTokenizer.from_pretrained(**model_config) # v4.43 errors with `*** TypeError: not a string` + elif 'Mistral-7B-Instruct-v0.3' in model_config['pretrained_model_name_or_path']: + tokenizer = LlamaTokenizer.from_pretrained(**model_config) # hack where AutoTokenizer doesn't recognize + elif 'Mistral-7B' in model_config['pretrained_model_name_or_path']: + tokenizer = AutoTokenizer.from_pretrained(**model_config) + else: + tokenizer = AutoTokenizer.from_pretrained(**model_config) + return tokenizer + + +def add_special_tokens_to_dataset(dataset, tokenizer): + """ + Add special tokens as attributes to a dataset object + """ + token_map = {k: v for k, v in tokenizer.special_tokens_map.items()} + special_ids = tokenizer.all_special_ids + for idx, k in enumerate(tokenizer.special_tokens_map.keys()): + token_map[f'{k}_id'] = special_ids[idx] + for k, v in token_map.items(): + setattr(dataset, k, v) + return dataset + + +def train_test_split(samples: any, train_size: int, test_size: int, seed: int): + """ + Split samples into train and test sets + """ + try: + assert len(samples) == train_size + test_size + except Exception as e: + 
print(len(samples), train_size + test_size) + raise e + arange = np.arange(len(samples)) + np.random.seed(seed) + test_idx = np.random.choice(arange, size=test_size, replace=False) + train_idx = np.setdiff1d(arange, test_idx) + return samples[train_idx], samples[test_idx] + + +def download_scrolls_metric(): + """ + Download ROUGE, F1, and other accuracy metrics included in the SCROLLS dataset + """ + scrolls_metric_path = hf_hub_download( + repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset" + ) + updated_scrolls_metric_path = ( + os.path.dirname(scrolls_metric_path) + + os.path.basename(scrolls_metric_path).replace(".", "_") + ".py" + ) + shutil.copy(scrolls_metric_path, updated_scrolls_metric_path) + return updated_scrolls_metric_path diff --git a/src/finetune.py b/src/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..3ef7f867e68cb7738b56c08b0e10454e66e4fffb --- /dev/null +++ b/src/finetune.py @@ -0,0 +1,68 @@ +""" +Finetuning functions to do post-distillation +""" +from os.path import join +from omegaconf import OmegaConf + +import torch +from torch.nn import Module + +from src.utils.setup import update_config_from_args +from src.dataloaders import load_data +from src.trainer import get_trainer, get_optimizer, get_scheduler + + +def prepare_finetune_configs(args, model_config: dict, + finetune_config_name: str = None, + finetune_checkpoint_name: str = None, + config_dir='./configs/experiment'): + """ + Prepare finetuning configs + """ + # Load finetuning config + finetune_config = (finetune_config_name if finetune_config_name is not None else + finetune_checkpoint_name.split('-f=')[-1].split('-')[0]) + finetune_config_path = join(config_dir, f'{finetune_config}.yaml') + finetune_config = OmegaConf.load(finetune_config_path) + finetune_config = update_config_from_args(finetune_config, args, + ignore_args=['lr', 'weight_decay']) + # Update data tokenizer to match model + if getattr(finetune_config.dataset, 'pretrained_model_config', None) is not None: + for k in ['pretrained_model_name_or_path', 'cache_dir']: + finetune_config.dataset.pretrained_model_config[k] = model_config['model'][k] + # Set finetuning args + for arg, argv in finetune_config.trainer.items(): + if arg != 'name': + setattr(args, arg, argv) + for _config in ['dataloader', 'optimizer', 'lr_scheduler']: + setattr(args, _config, OmegaConf.to_container(getattr(finetune_config, _config))) + return finetune_config, args + + +def get_finetuner(model: Module, finetune_config: dict, device: torch.device, + args: any, wandb: any, initial_eval: bool = False): + """ + Initialize finetuning trainer + """ + model.to(device) # if using a fused optimizer + model.train() + + # Initialize optimizer and scheduler + optimizer = get_optimizer(model=model, **finetune_config.optimizer) + scheduler = get_scheduler(optimizer=optimizer, **finetune_config.lr_scheduler) + + dataloaders = load_data(finetune_config.dataset, finetune_config.dataloader) + train_loader = dataloaders[finetune_config.trainer.train_split] + eval_loader = dataloaders[finetune_config.trainer.val_split] + + OurTrainer = get_trainer(finetune_config.trainer.name) + trainer = OurTrainer(model=model, + args=args, + train_loader=train_loader, + eval_loader=eval_loader, + optimizer_and_scheduler=(optimizer, scheduler), + device=device, + wandb=wandb, + checkpoint_suffix='_ft', + **finetune_config.trainer) + return trainer diff --git a/src/model/__init__.py b/src/model/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/model/convert_model.py b/src/model/convert_model.py new file mode 100644 index 0000000000000000000000000000000000000000..41dd1dc8ead96e8cdaa0442abef6083d4d14a06f --- /dev/null +++ b/src/model/convert_model.py @@ -0,0 +1,173 @@ +""" +Attention conversion helpers +""" +from functools import partial +from tqdm import tqdm +import torch.nn as nn + + +def convert_attention(model: nn.Module, + attention_config: dict, + train_attention: bool = False, + remove_base_attn: bool = True,): + """ + Call to convert all attention layers + """ + softmax_attns = [] + if 'softmax_attentions' in attention_config: + softmax_attns = attention_config['softmax_attentions'] + if attention_config.attention_type != 'softmax': + layers = traverse_layers(model) + for layer_idx, layer in enumerate(tqdm(layers, desc='Converting attentions...')): + if layer_idx not in softmax_attns: + layer.self_attn = convert_llama_attention( + layer, attention_config, layers, train_attention, remove_base_attn, + ) + layer.self_attn.converted = True + else: # Freeze any preserved softmax attention layers + for p in layer.parameters(): + p.requires_grad = False + else: + print(f'-> attention_config.attention_type is {attention_config.attention_type}; not converting attentions') + return model + + +def toggle_attention(llama_model: nn.Module, train: bool = False): + """ + Make attentions trainable if train is True + -> Set train_attention = False when finetuning + """ + for layer in traverse_layers(llama_model): + layer.self_attn.train_attention = train + return llama_model + + +def remove_base_attention(llama_model: nn.Module): + """ + Remove teacher attention after distillation (if we keep it) + """ + for layer in traverse_layers(llama_model): + if getattr(layer.self_attn, 'base_attn', False): + del layer.self_attn.base_attn + return llama_model + + +def traverse_layers(model: nn.Module, verbose: bool = False): + """ + Return list of model layers + """ + try: + layers = model.model.layers + if verbose: + print('-> Loading from model.model.layers') + except AttributeError as e: # if base model + if verbose: + print(e) + try: + layers = model.layers + if verbose: + print('-> Loading from model.layers') + except AttributeError as e1: # If we make a PEFT model + if verbose: + print(e1) + layers = model.base_model.model.model.layers + if verbose: + print('-> Loading from model.base_model.model.model.layers') + return layers + + +def convert_llama_attention(layer: nn.Module, + attention_config: dict, + layers: list[nn.Module], # list of layers + train_attention: bool = False, + remove_base_attn: bool = True): + """ + Converts a single layer's attention layer as specified by attention_config + """ + return get_attention(**attention_config)( + base_attn=layer.self_attn, + layer_idx=layer.self_attn.layer_idx, # Transformers v4.36 + max_layer_idx=len(layers) - 1, + train_attention=train_attention, + remove_base_attn=remove_base_attn, + ) + + +def get_attention(attention_type: str, **kwargs: any): + """ + Get the linear attention class; either purely linear or linear with sliding window + -> 'linear' == 'lolcats_llama' + -> 'linear and sliding_window' == 'lolcats_llama_window_*' + """ + kwargs['attention_type'] = attention_type + + if attention_type == 'lolcats_llama': + from .linear_attention import LolcatsLinearAttention + return partial(LolcatsLinearAttention, **kwargs) + + elif attention_type == 'lolcats_llama_window_tk': + from 
.linear_attention import LolcatsTKWindowAttention + return partial(LolcatsTKWindowAttention, **kwargs) + + elif attention_type == 'lolcats_llama_window_sw': + from .linear_attention import LolcatsSlidingWindowAttention + return partial(LolcatsSlidingWindowAttention, **kwargs) + + elif attention_type == 'lolcats_llama_window_sw_linear': + from .linear_attention.linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention + return partial(LolcatsLinearSlidingWindowAttention, **kwargs) + + ## Experimental chunked linear attentions below + elif attention_type == 'lolcats_long_llama_window_tk': + from .linear_attention import LolcatsTKWindowLongAttention + return partial(LolcatsTKWindowLongAttention, **kwargs) + + elif attention_type == 'lolcats_long_llama_window_sw': + from .linear_attention import LolcatsSlidingWindowLongAttention + return partial(LolcatsSlidingWindowLongAttention, **kwargs) + + ## TK generation build (requires Thunderkittens) + elif attention_type == 'lolcats_llama_window_tk_gen': + from .linear_attention import LolcatsWindowAttentionTKGen + return partial(LolcatsWindowAttentionTKGen, **kwargs) + + else: + print(f'-> attention_type {attention_type} not handled... returning None') + return None + + +def get_attention_cache(attention_type: str, past_key_values: any = None): + """ + Determine how we store past keys and values when generating + """ + if attention_type is None: + return past_key_values + + # print(f'Returning attention cache based on attention_type == {attention_type}') + elif 'lolcats_llama_window_tk_gen' in attention_type: + from .linear_attention import LinearAttentionTKWindowGenerationCache + return LinearAttentionTKWindowGenerationCache() + + elif 'llama_window_tk' in attention_type: + from .linear_attention import LinearAttentionTKWindowCache + return LinearAttentionTKWindowCache() + + elif 'llama_window_sw' in attention_type: + from .linear_attention import LinearAttentionSlidingWindowCache + return LinearAttentionSlidingWindowCache() + + elif 'llama_window_sw_linear' in attention_type: + from .linear_attention import LinearAttentionSlidingWindowCache + return LinearAttentionSlidingWindowCache() + + ## TK generation build (requires Thunderkittens) + elif attention_type == 'lolcats_llama_window_tk_gen': + from .linear_attention.linear_window_attention_tk_gen import LinearAttentionTKWindowGenerationCache + return LinearAttentionTKWindowGenerationCache() + + elif 'softmax' in attention_type: + return past_key_values + + else: + from .linear_attention import LinearAttentionState + return LinearAttentionState() diff --git a/src/model/feature_map.py b/src/model/feature_map.py new file mode 100644 index 0000000000000000000000000000000000000000..fae1064132f45f2f76f60bfef0f98f127a78cc4a --- /dev/null +++ b/src/model/feature_map.py @@ -0,0 +1,306 @@ +""" +Learnable linear attention feature map classes and functions +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def init_feature_map(name: str, mlp: nn.Module, **kwargs: dict): + """ + Initialize feature map final activation for linear attention + """ + return FeatureMap(activation_name=name, mlp=mlp, **kwargs) + + +def init_feature_map_act(name: str, fullspace: bool = True, **kwargs): + """ + Initialize feature map final activation for linear attention + """ + if name == 'softmax_dim' and fullspace: + return SoftmaxDim(**kwargs) + elif name == 'softmax_dim' and not fullspace: + return SoftmaxDimHalfspace(**kwargs) + elif name == 'exp_dim' and fullspace: + return 
Exp(**kwargs) + elif name == 'exp_dim' and not fullspace: + return ExpHalfspace(**kwargs) + elif name == 'pos_elu': + return PosELU(**kwargs) + elif name == 'relu': + return ReLU(**kwargs) + + else: + raise NotImplementedError + + +def init_learned_kernel(name: str, **kwargs: any): + """ + Initialize feature map MLP for linear attention + """ + if name == 'untied_head_einsum': + return FeatureMapMLP(**kwargs) + elif name == 'untied_head_adapter': + return FeatureMapAdapter(**kwargs) + else: + raise NotImplementedError + + +class FeatureMap(nn.Module): + """ + Final 'activation' of feature map. Can probably be combined with + `FeatureMapMLP` below + + Full feature map is like f(xW + b) + -> This is the `f` part + """ + def __init__(self, + activation_name: str, + head_dim_idx: int = -1, + eps: float = 1e-12, + mlp: nn.Module = None, + fullspace: bool = True,): + super().__init__() + self.head_dim_idx = head_dim_idx + self.eps = eps + self.mlp = mlp if mlp is not None else nn.Identity() + self.activation = init_feature_map_act(activation_name, fullspace, eps=eps) + + def forward(self, x: torch.Tensor, *mlp_args: any, **mlp_kwargs: any): + """ + Assume x.shape is (batch_size, n_heads, seq_len, head_dim) + """ + return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x) + + def q_map(self, *args: any, **kwargs: any): + """ + Use for inference in case q and k feature maps differ + """ + return self.forward(*args, **kwargs) + + def k_map(self, *args: any, **kwargs: any): + """ + Use for inference in case q and k feature maps differ + """ + return self.forward(*args, **kwargs) + + +# ----------------------- +# Feature map activations +# ----------------------- +class FeatureMapAct(nn.Module): + """ + Base class for feature map activations + """ + def __init__(self, eps: float = 1e-12): + super().__init__() + self.eps = eps + + def forward(self, x: torch.Tensor, *args: any, **kwargs: any): + """ + x.shape is (batch_size, n_heads, seq_len, head_dim) + """ + return x + + +class PosELU(FeatureMapAct): + """ + 1 + ELU activation as in https://arxiv.org/abs/2006.16236 + """ + def forward(self, x: torch.Tensor, *args: any, **kwargs: any): + return (1 + F.elu(x)).clamp(min=self.eps) + + +class ReLU(FeatureMapAct): + """ + ReLU activation as in https://arxiv.org/abs/2103.13076 + """ + def forward(self, x: torch.Tensor, *args: any, **kwargs: any): + return F.relu(x).clamp(min=self.eps) + + +class SoftmaxDim(FeatureMapAct): + """ + Softmax activation as in https://arxiv.org/abs/2402.04347 + """ + def forward(self, x: torch.Tensor, *args: any, **kwargs: any): + return torch.cat([ + torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1) + ], dim=-1).clamp(min=self.eps) + + +class SoftmaxDimHalfspace(FeatureMapAct): + """ + Softmax activation as in https://arxiv.org/abs/2402.04347 + """ + def forward(self, x: torch.Tensor, *args: any, **kwargs: any): + return torch.softmax(x, dim=-1).clamp(min=self.eps) + + +class Exp(FeatureMapAct): + """ + Exp activation as in https://arxiv.org/abs/2402.04347 + """ + def forward(self, x: torch.Tensor, *args: any, **kwargs: any): + x_max = torch.amax(x, dim=-1, keepdim=True) + x_min = torch.amin(x, dim=-1, keepdim=True) + return torch.cat([ + torch.exp(x - x_max), torch.exp(-x + x_min) + ], dim=-1).clamp(min=self.eps) + + +class ExpHalfspace(FeatureMapAct): + """ + Exp activation as in https://arxiv.org/abs/2402.04347 + """ + def forward(self, x: torch.Tensor, *args: any, **kwargs: any): + x_max = torch.amax(x, dim=-1, keepdim=True) + return torch.exp(x - 
x_max).clamp(min=self.eps) + + +# ---------------- +# Feature map MLPs +# ---------------- + +class FeatureMapMLP(nn.Module): + """ + Learnable MLP in feature map. + + Full feature map is like f(xW + b) + -> This is the `W` and (optional) `b` part + """ + def __init__(self, + num_heads: int, + head_dim: int, # input dim + feature_dim: int, # output dim + dtype: torch.dtype, + device: torch.device, + skip_connection: bool = False, + bias: bool = False, + zero_init: bool = False, + normal_init: bool = False,): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.feature_dim = feature_dim + self.dtype = dtype + self.device = device + self.skip_connection = skip_connection + self.bias = bias + self.zero_init = zero_init + self.normal_init = normal_init + self.init_weights_() + + if self.zero_init: # Zero-out weights or set as identity post-initialization + self.zero_init_with_skip_() if self.skip_connection else self.zero_init_() + + if self.normal_init: + with torch.no_grad(): + nn.init.normal_(self.layer) + + if self.skip_connection: + assertion_fail = f'If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}' + assert self.head_dim == self.feature_dim, assertion_fail + + def init_weights_(self): + """ + Initialize (W)eights and (b)iases + """ + self.layer = nn.Parameter(torch.zeros( + (self.num_heads, self.head_dim, self.feature_dim), + dtype=self.dtype, device=self.device, + )) + nn.init.kaiming_uniform_(self.layer) + + if self.bias: + self.bias = nn.Parameter(torch.zeros( + (1, self.num_heads, 1, 1), # self.feature_dim), + dtype=self.dtype, device=self.device, + )) + nn.init.kaiming_uniform_(self.bias) + else: + self.bias = 0. # hack + + def zero_init_with_skip_(self): + """ + Initialize weights to zero matrix if skip connection + """ + with torch.no_grad(): + nn.init.zeros_(self.layer) + + def zero_init_(self): + """ + Initialize weights to identity matrix if no skip connection + """ + with torch.no_grad(): + for i in range(self.layer.shape[0]): + try: + nn.init.eye_(self.layer[i]) + except RuntimeError: + with torch.no_grad(): + dtype = self.layer[i].dtype + weight = torch.eye(*self.layer[i].shape, + requires_grad=self.layer[i].requires_grad, + device=self.layer[i].device) + self.layer[i] = weight.to(dtype=dtype) + + def forward(self, x: torch.Tensor): + """ + Assume x.shape is (batch_size, num_heads, seq_len, head_dim) + """ + _x = torch.einsum('hdf,bhld->bhlf', self.layer, x) + self.bias + return x + _x if self.skip_connection else _x + + +class FeatureMapAdapter(FeatureMapMLP): + """ + Learnable Feature map with bottleneck adapter + as in https://arxiv.org/abs/1902.00751 + + We don't use but could be fun to try + """ + def __init__(self, hidden_dim: int, *args, **kwargs): + kwargs['skip_connection'] = True + kwargs['bias'] = True + kwargs['zero_init'] = True + self.hidden_dim = hidden_dim + super().__init__(*args, **kwargs) + + def init_weights_(self): + """ + Initialize (W)eights and (b)iases + """ + kwargs = {'dtype': self.dtype, 'device': self.device} + self.layer0 = nn.Parameter( + torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs) + ) + self.layer1 = nn.Parameter( + torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs) + ) + nn.init.kaiming_uniform_(self.layer0) + nn.init.kaiming_uniform_(self.layer1) + + self.bias0 = nn.Parameter(torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs)) + self.bias1 = 
nn.Parameter(torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs)) + nn.init.kaiming_uniform_(self.bias0) + nn.init.kaiming_uniform_(self.bias1) + + def zero_init_with_skip_(self): + with torch.no_grad(): + nn.init.zeros_(self.layer0) + nn.init.zeros_(self.layer1) + nn.init.zeros_(self.bias0) + nn.init.zeros_(self.bias1) + + def zero_init_(self): + assert NotImplementedError + + def forward(self, x: torch.Tensor): + """ + Assume x.shape is (batch_size, num_heads, seq_len, head_dim) + -> Down-project, apply nonlinearity, up-project; add skip connection + """ + _x = torch.einsum('hde,bhld->bhle', self.layer0, x) + self.bias0 + _x = F.relu(_x) + _x = torch.einsum('hef,bhle->bhlf', self.layer1, _x) + self.bias1 + return x + _x if self.skip_connection else _x diff --git a/src/model/linear_attention/__init__.py b/src/model/linear_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2482d284080da3c880d341b63388d6e01a15bfc1 --- /dev/null +++ b/src/model/linear_attention/__init__.py @@ -0,0 +1,23 @@ +""" +Linear and linear attention + sliding window classes +""" +from .linear_attention import ( + LolcatsLinearAttention, LinearAttentionState +) +from .linear_window_attention_tk import ( + LolcatsTKWindowAttention, LinearAttentionTKWindowCache +) +from .linear_window_attention_sw import ( + LolcatsSlidingWindowAttention, LinearAttentionSlidingWindowCache +) +# Experimental chunk linear attentions +from .linear_window_attention_tk_long import ( + LolcatsTKWindowLongAttention, +) +from .linear_window_attention_sw_long import ( + LolcatsSlidingWindowLongAttention, +) +from .linear_window_attention_tk_gen import ( + LolcatsWindowAttentionTKGen, + LinearAttentionTKWindowGenerationCache +) diff --git a/src/model/linear_attention/linear_attention.py b/src/model/linear_attention/linear_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..727e15ae40eb421c80e0a673f68f1967d016a3f8 --- /dev/null +++ b/src/model/linear_attention/linear_attention.py @@ -0,0 +1,459 @@ +""" +Linear attention classes +""" +from typing import List, Tuple, Optional +import copy +import torch +import torch.nn as nn +from omegaconf import OmegaConf, DictConfig + +from transformers.cache_utils import Cache # starting at Transformers v4.36 + +# Causal linear attention dot product CUDA kernel from fast-transformers +try: + from csrc import causal_dot_product as fast_causal_dot_product +except ImportError: + fast_causal_dot_product = None + +from src.model.feature_map import init_feature_map, init_learned_kernel +from src.model.rotary import get_rotary_embeddings, apply_rotary_pos_emb +from .utils import repeat_kv + + +# ------------------- +# Attention functions +# ------------------- + +def causal_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """ + Causal linear attention dot product + - If available, use CUDA kernel from fast-transformers + """ + if fast_causal_dot_product is None: + kv = torch.einsum('bhlf,bhld->bhlfd', k, v) + return torch.einsum('bhlf,bhlfd->bhld', q, kv.cumsum(dim=2)) + return fast_causal_dot_product(q, k, v) + +def linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + fp32_attention: bool = False, eps: float = 1e-12, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Compute linear attention with CUDA kernel implementation from fast-transformers + - https://github.com/idiap/fast-transformers + - Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); + 
v is shape (b, h, l, head_dim) + """ + dtype = q.dtype + # Causal mask already applied + y = causal_dot_product(q.contiguous().to(dtype=torch.float32), + k.contiguous().to(dtype=torch.float32), + v.contiguous().to(dtype=torch.float32)) + if fp32_attention: + y = (y / (torch.einsum( + "bhld,bhld->bhl", q.float(), k.float().cumsum(dim=2) + ) + eps)[..., None]).to(dtype=dtype) + else: + y = y.to(dtype=dtype) + k = k.float().cumsum(dim=2).to(dtype=dtype) + y = y / (torch.einsum("bhld,bhld->bhl", q, k) + eps)[..., None] + return y, None, None + + +def softmax_attention(q: torch.Tensor, k: torch.Tensor, v: Optional[torch.Tensor] = None, + causal: bool = True, fp32_attention: bool = True, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Standard softmax attention; only compute outputs if v is not None + -> Assume q, k, v are shape (batch_size, num_heads, seq_len, head_dim) + """ + y = None + a = torch.einsum('bhmd,bhnd->bhmn', q, k) * (k.shape[-1] ** -0.5) + if causal: # Apply causal mask + m, n = a.shape[-2:] + causal_mask = torch.ones((m, n), device = a.device, dtype = torch.bool).triu(n - m + 1) + a = a.masked_fill(causal_mask, -torch.finfo(a.dtype).max) + if fp32_attention: + a = torch.softmax(a, dim=-1, dtype=torch.float32).to(q.dtype) + else: + a = torch.softmax(a, dim=-1) + if v is not None: + y = torch.einsum('bhmn,bhnd->bhmd', a, v) + return y, a, None + + +def quadratic_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor = None, + causal: bool = True, fp32_attention: bool = False, eps: float = 1e-12, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Compute attention with feature maps by instantiating L x L matrix of attention weights + -> Use for attention distillation + -> Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); v is shape (b, h, l, head_dim) + """ + y = None + dtype = q.dtype + if fp32_attention: + q, k = q.float(), k.float() + a = torch.einsum('bhmd,bhnd->bhmn', q, k) # note we don't scale, tho we could + if causal: # Apply causal mask + m, n = a.shape[-2:] + causal_mask = torch.ones((m, n), device = a.device, dtype = torch.bool).triu(n - m + 1) + a = a.masked_fill(causal_mask, 0) + # Normalize to compute attention + a = a / (a.sum(dim=-1, keepdim=True) + eps) + a = a.to(dtype=dtype) if fp32_attention else a + if torch.isnan(a).sum() > 0: + breakpoint() + if v is not None: + y = torch.einsum('bhmn,bhnd->bhmd', a, v) + return y, a, None + + +# --------------------- +# Attention layer class +# --------------------- + +class LolcatsLinearAttention(nn.Module): + """ + LoLCATs attention implementation initialized from a + `LlamaAttention` or `MistralAttention` object (base_attn) + + Most of the arguments are directly tied to argparse args + - For now we don't support padding. 
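+
+    Minimal construction sketch (argument values below are illustrative
+    assumptions, not project defaults):
+
+        attn = LolcatsLinearAttention(
+            base_attn=layer.self_attn,            # pretrained LlamaAttention module
+            feature_map='softmax_dim',            # final activation, see src.model.feature_map
+            feature_map_kwargs={},
+            learned_kernel='untied_head_einsum',  # learnable FeatureMapMLP
+            learned_kernel_kwargs={'feature_dim': 64, 'skip_connection': False},
+            train_attention=True,                 # True while distilling attentions
+        )
+        layer.self_attn = attn
+
+    In practice layers are converted via `convert_attention` in src.model.convert_model.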
+ """ + def __init__(self, + base_attn: nn.Module, # like LlamaAttention + feature_map: str, + feature_map_kwargs: dict, + layer_idx: Optional[int] = None, + max_layer_idx: Optional[int] = None, + learned_kernel: Optional[str] = None, + learned_kernel_kwargs: Optional[dict] = None, + tie_qk_kernels: Optional[bool] = False, + rotary_config: Optional[dict] = None, + train_attention: Optional[bool] = False, + remove_base_attn: Optional[bool] = True, + attention_type: Optional[str] = 'lolcats_llama', + mask_value: int = 0, + eps: float = 1e-12, + fp32_attention: bool = False, + track_state_grads: bool = False, + rank: Optional[int] = 0, + **kwargs: any) -> None: + super().__init__() + self.base_config = getattr(base_attn, 'config', None) + if self.base_config is not None: + self.base_config = self.base_config.to_dict() + self.attention_type = attention_type + self.mask_value = mask_value + self.eps = eps + self.layer_idx = (layer_idx if layer_idx is not None else base_attn.layer_idx) + self.max_layer_idx = max_layer_idx + self.tie_qk_kernels = tie_qk_kernels + self.train_attention = train_attention + self.base_inference = False + self.fp32_attention = fp32_attention + self.track_state_grads = track_state_grads + if rank == 0: # multi-gpu + if fp32_attention and layer_idx == 0: + print(f'-> fp32_attention is {fp32_attention}') + if layer_idx == 0 and feature_map_kwargs is not None: + for k, v in feature_map_kwargs.items(): + print(f'-> {k}: {v}') + if layer_idx == 0 and learned_kernel_kwargs is not None: + for k, v in learned_kernel_kwargs.items(): + print(f'-> {k}: {v}') + + self.remove_base_attn = remove_base_attn + + # Rotary embeddings (patch for Llama 3.1, Transformer v4.43.0) + self.rotary_config = rotary_config + if isinstance(self.rotary_config, DictConfig): # ensure dict + self.rotary_config = OmegaConf.to_container(self.rotary_config) + + self.rotary_emb = None + if self.base_config is not None and self.rotary_config is None: + self.rotary_emb = base_attn.rotary_emb + + self.init_weights_(base_attn, remove_base_attn) + self.init_feature_map_(feature_map, feature_map_kwargs, + learned_kernel, learned_kernel_kwargs) + + def init_feature_map_(self, + feature_map: str, + feature_map_kwargs: dict, + learned_kernel: str = None, + learned_kernel_kwargs: dict = None): + """ + Initialize MLP-based feature map + """ + self.fmap_gqa = False # Turn True if specified below + if learned_kernel is not None: + # Ensure dict + learned_kernel_kwargs = {k: v for k, v in learned_kernel_kwargs.items()} + learned_kernel_kwargs['num_heads'] = self.num_heads + learned_kernel_kwargs['head_dim'] = self.head_dim + learned_kernel_kwargs['dtype'] = self.q_proj.weight.dtype + learned_kernel_kwargs['device'] = self.q_proj.weight.device + # Create MLP + mlp_learned_kernel = init_learned_kernel(learned_kernel, **learned_kernel_kwargs) + # Add "activation"; see src.models.feature_map.py + self.feature_map_q = init_feature_map(name=feature_map, + mlp=mlp_learned_kernel, + **feature_map_kwargs) + if self.tie_qk_kernels: # tie mlp weights for query and key feature maps + self.feature_map_k = self.feature_map_q + else: + self.feature_map_k = copy.deepcopy(self.feature_map_q) + + def init_weights_(self, base_attn: nn.Module, remove_base_attn: bool = True): + """ + Initialize module layers, weights, positional dependencies, etc. 
+ from original softmax attention layer (base_attn) + """ + # Make other attributes accessible + self.attention_dropout = 0 # We don't use dropout + self.hidden_size = base_attn.hidden_size + self.num_heads = base_attn.num_heads + self.head_dim = base_attn.head_dim + self.num_key_value_heads = base_attn.num_key_value_heads + self.num_key_value_groups = base_attn.num_key_value_groups + + self.q_shape = [self.num_heads, self.head_dim] + self.k_shape = [self.num_key_value_heads, self.head_dim] + self.v_shape = [self.num_key_value_heads, self.head_dim] + device = base_attn.q_proj.weight.device + # Rotary embeddings + if self.rotary_emb is None: + self.max_position_embeddings = base_attn.max_position_embeddings + scaling_factor = getattr(base_attn.rotary_emb, 'scaling_factor', 1.) + if self.rotary_config is None: + self.rotary_emb = get_rotary_embeddings( + rope_scaling_type=None, + head_dim=self.head_dim, + max_position_embeddings=self.max_position_embeddings, # base_attn.rotary_emb.max_position_embeddings, + rope_theta=base_attn.rotary_emb.base, + rope_scaling_factor=scaling_factor, # base_attn.rotary_emb.scaling_factor, + device=device, + ) + else: + if 'device' not in self.rotary_config: + self.rotary_config['device'] = device + self.rotary_emb = get_rotary_embeddings(**self.rotary_config) + + # Copy original model projection layers + self.q_proj = base_attn.q_proj + self.k_proj = base_attn.k_proj + self.v_proj = base_attn.v_proj + self.o_proj = base_attn.o_proj + try: # If wanting to use FA2 for ground-truth inference + self._flash_attn_uses_top_left_mask = base_attn._flash_attn_uses_top_left_mask + except AttributeError: + pass + + if self.remove_base_attn or remove_base_attn: + del base_attn # We don't need to keep these around + else: + self.base_attn = base_attn # For some training runs helpful to just call + + def process_qkv(self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None,): # "legacy" cache approach + """ + Compute queries, keys, and values + """ + b, l, _ = hidden_states.size() + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + kv_seq_len = k.shape[-2] + + # Shape is (batch_size, seq_len, num_heads, head_dim) + q = q.view(b, l, *self.q_shape).transpose(1, 2) + k = k.view(b, l, *self.k_shape).transpose(1, 2) + v = v.view(b, l, *self.v_shape).transpose(1, 2) + + if past_key_value is not None: # and k.shape[2] > q.shape[2]: # e.g., when generating + past_key_value.window_size = getattr(self, 'decode_window_size', None) # self.decode_window_size + if isinstance(past_key_value, Cache): # In Transformers v4.36+ this is a DynamicCache object + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + else: + kv_seq_len += past_key_value[0].shape[-2] + + # Apply rotary embeddings and repeat for GQA + if position_ids is not None and kv_seq_len <= position_ids[0, -1]: + kv_seq_len = position_ids[0, -1] + 1 # hack for adjusting position ids + try: # As in Transformers v4.36 + cos, sin = self.rotary_emb(k, seq_len=kv_seq_len) + q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) + except TypeError: # As in Transformers v4.39+ + cos, sin = self.rotary_emb(v, position_ids) + q, k = apply_rotary_pos_emb(q, k, cos, sin) + + k = repeat_kv(k, self.num_key_value_groups) + v = repeat_kv(v, self.num_key_value_groups) + return q, k, v, kv_seq_len + + def forward(self, + 
hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None, # "legacy" cache approach + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Forward pass modified from transformers.models.mistral.modeling_mistral (v4.36) + - Consistent with HuggingFace Transformers for easy use with their pretrained models + """ + b, l, _ = hidden_states.size() + q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask, + position_ids, past_key_value) + if self.base_inference: + with torch.no_grad(): + # 1. Compute "ground-truth" attention output and weights + y_true, _, _ = softmax_attention(q, k, v, causal=True) + y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + y_true = self.o_proj(y_true) + attn_weights = (None, None) + + elif self.train_attention: # Distilling / learning attentions + # Note for now we assume no padding when distilling; attention masks only enforce causality + assert output_attentions is True, f'When training feature maps, output_attentions should be True but is {output_attentions}' + with torch.no_grad(): + # 1. Compute "ground-truth" attention output and weights + _y_true, attn_true, _ = softmax_attention(q, k, v, causal=True) + y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + y_true = self.o_proj(y_true) + + # 2. Compute "predicted" attention (just weights) + q, k = self.feature_map_q.q_map(q), self.feature_map_k.k_map(k) + y_pred, attn_pred, _ = quadratic_attention(q, k, v, causal=True) + attn_weights = ((attn_pred, attn_true), (y_pred, _y_true)) # Save both attention weights so we can supervise. 
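+            # (Assumed downstream use, not enforced here: a distillation trainer can
+            #  compare attn_pred vs. attn_true (e.g., a cross-entropy-style loss over
+            #  attention rows) and y_pred vs. _y_true (e.g., an MSE loss on outputs);
+            #  this forward pass only computes and returns both pairs.)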
+ + else: # Finetuning + q, k = self.feature_map_q(q), self.feature_map_k(k) + # Apply prefill mask + if attention_mask is not None and q.shape[2] > 1: + if len(attention_mask.shape) == 4: + lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][..., None] # b, 1, k_len, 1 + else: + lin_attn_mask = attention_mask[:, None, :, None] # b, 1, k_len, 1 + k = k.masked_fill(~lin_attn_mask, 0) + + if past_key_value is not None: # Initialize states + if len(past_key_value.kv_states) == self.layer_idx: + b, h, _, f = k.shape + past_key_value.kv_states.append( + torch.zeros(b, h, f, self.head_dim, dtype=q.dtype, device=q.device) + ) + past_key_value.k_states.append( + torch.zeros(b, h, 1, f, dtype=q.dtype, device=q.device) + ) + # Generating + if q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None: + assert use_cache is True + kv_state, k_state = past_key_value.update(k, v, self.layer_idx, + accumulate_in_fp32=self.fp32_attention) + if self.fp32_attention: + q = q.float() + y_true = (torch.einsum('bhlf,bhfd->bhld', q, kv_state.float()) / + torch.einsum('bhlf,bhlf->bhl', q, k_state.float())[..., None]).to(dtype=k.dtype) + else: + y_true = (torch.einsum('bhlf,bhfd->bhld', q, kv_state) / + torch.einsum('bhlf,bhlf->bhl', q, k_state)[..., None]) + else: + kv_state = past_key_value.kv_states[self.layer_idx] + k_state = past_key_value.k_states[self.layer_idx] + y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps) # Ordinarily the states are ignored + past_key_value.update(k.detach(), v.detach(), self.layer_idx, + accumulate_in_fp32=self.fp32_attention) + # doing some unnecessary recomputation here + else: + y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps) + + # Concatenate heads and apply output projection + y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + y_true = self.o_proj(y_true) + attn_weights = None + + return y_true, attn_weights, past_key_value + + +class LinearAttentionState(Cache): + """ + Handle the KV and K states for linear attention + - Adopts HF Transformers `past_key_values` convention + - Inherits from `Cache` class + - Modified from transformers.cache_utils.DynamicCache (v4.36) + """ + def __init__(self) -> None: + self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36 + self._seen_tokens_by_layer: List[int] = [] + self.kv_states: List[torch.Tensor] = [] + self.k_states: List[torch.Tensor] = [] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """ + Returns the sequence length of the cached states. A layer index can be optionally passed. + """ + if len(self._seen_tokens_by_layer) <= layer_idx: # Initializing kv and k states + self._seen_tokens_by_layer.append(0) + return self._seen_tokens_by_layer[layer_idx] + + def get_max_length(self) -> Optional[int]: + """ + Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length. 
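+        The same is true here: the per-layer `kv_states` and `k_states` kept by
+        LinearAttentionState are fixed-size running sums over keys and values
+        (see `update` below), so the cache does not grow with sequence length.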
+ """ + return None + + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_length() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + def update(self, key_states: torch.Tensor, value_states: torch.Tensor, + layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None, + accumulate_in_fp32: bool = True, **kwargs: any, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + with torch.no_grad (): + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + dtype = key_states.dtype + if accumulate_in_fp32: + key_states, value_states = key_states.float(), value_states.float() + + kv_state = torch.einsum('bhlf,bhld->bhfd', key_states, value_states).detach() + k_state = key_states.sum(dim=-2, keepdim=True).detach() # b, h, 1, f; note the 1 + # Update the cache + if len(self.k_states) <= layer_idx: # Initializing kv and k states + print('if len(self.k_states) <= layer_idx: # Initializing kv and k states') + self.kv_states.append(kv_state.to(dtype)) + self.k_states.append(k_state.to(dtype)) + else: + kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype) + k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype) + self.kv_states[layer_idx] = kv_state + self.k_states[layer_idx] = k_state + self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2] + return self.kv_states[layer_idx], self.k_states[layer_idx] + + def to_legacy_cache(self): + """Hack, but just return self""" + return self + + def reorder_cache(self, beam_idx: torch.LongTensor): + """ + Reorders the cache for beam search, given the selected beam indices. 
+ -> Copied from transformers/src/transformers/cache_utils.py + """ + raise NotImplementedError('Reordering cache not implemented for LinearAttentionState') diff --git a/src/model/linear_attention/linear_window_attention_sw.py b/src/model/linear_attention/linear_window_attention_sw.py new file mode 100644 index 0000000000000000000000000000000000000000..1a9a77129bb16a49c214a7fe14de20fb3a8501f2 --- /dev/null +++ b/src/model/linear_attention/linear_window_attention_sw.py @@ -0,0 +1,339 @@ +""" +Subquadratic attention combining sliding window and linear attentions +- Using "standard" sliding windows +- Didactically computes outputs with n^2 attention weights for now +- Copied + adapted from linear_window_attention_tk.py for single-file reference + +For each layer: +- We first compute (softmax) attention over sliding windows +- We then compute standard linear attention to "fill in" the earlier parts +- We combine to model the entire sequence +""" +from typing import List, Tuple, Optional, Callable +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.cache_utils import Cache + +from .linear_attention import ( + LolcatsLinearAttention, LinearAttentionState, + softmax_attention +) + +# ---------------------- +# Sliding window helpers +# ---------------------- +def get_masks(window_size: int, q_len: int, k_len: int, + device: torch.device) -> tuple[torch.Tensor]: + """ + Return masks for softmax and linear attention terms + -> 1 is include, 0 is ignore + """ + kwargs = {'device': device, 'dtype': int} + causal_mask = torch.ones((q_len, k_len), **kwargs).tril(k_len - q_len) + linear_mask = torch.ones((q_len, k_len), **kwargs).tril(k_len - q_len - window_size) + window_mask = causal_mask - linear_mask + # Return softmax mask (window), linear attention mask + # -> shapes broadcast over (b, h, q_len, k_len) + return window_mask[None, None, ...], linear_mask[None, None, ...] + + +def hybrid_attention_quadratic(q: torch.Tensor, k: torch.Tensor, + f_q: torch.Tensor, f_k: torch.Tensor, + v: torch.Tensor, + window_factor: torch.Tensor, + linear_factor: torch.Tensor, + window_size: int, + kv_state: torch.Tensor = None, + k_state: torch.Tensor = None, + eps: float = 1e-12, + mask_value: float=-1e8): + """ + Hybrid attention combining sliding window and linear attentions + """ + + mask_window, mask_linear = get_masks(window_size, q.shape[-2], k.shape[-2], q.device) + + # 1. Sliding window (softmax attention) + a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5) + a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value) + # torch.softmax(a_sm, dim=-1), but we account for the max when combining + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factor * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # 2. Under window (linear attention) + a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float()) + a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0) + sum_ln = a_ln.sum(dim=-1, keepdim=True) + + # 3. 
Combine + a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights + # Allow outputs to also depend on prior kv_state and k_state + y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float()) + if kv_state is not None: # Combine with prior kv_state and k_state + y += linear_factor * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float()) + sum_ln += linear_factor * torch.einsum( + 'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None] + y = (y / (sum_sm + sum_ln)).to(q.dtype) + return y, a # attention weights only for the last chunk + + +# --------------------- +# Attention layer class +# --------------------- +class LolcatsSlidingWindowAttention(LolcatsLinearAttention): + """ + Lolcats attention combining sliding window and linear attention + """ + def __init__(self, + window_size: int = 64, + decode_window_size: int = None, + affine_attention_factors: bool = False, + init_window_factor: float = 0, + train_window_factor: bool = True, + state_grad_enabled: bool = False, + **kwargs): + self.window_size = window_size + self.decode_window_size = ( + decode_window_size if decode_window_size is not None else window_size + ) + self.window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1} + super().__init__(**kwargs) + self.attention_type = kwargs['attention_type'] # 'hedgehog_llama_window_sw' + # Determine how we compute attentions + self.quadratic_attention = hybrid_attention_quadratic + self.attention_type = kwargs['attention_type'] # 'hedgehog_long_llama_window_sw' + # Learnable factor for combining attentions + self.affine_attention_factors = affine_attention_factors + device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype + if train_window_factor: + self.window_factors = nn.Parameter( + init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)) + else: + self.register_buffer( + "window_factors", init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype) + ) + # Whether we use original flash attention 2 inference (use during attention transfer) + self.base_inference = False + self.state_grad_enabled = state_grad_enabled + + def forward(self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Forward pass with the option to compute attention weights multiple ways + if self.train_attention is True + -> Consistent with HuggingFace Transformers for easy use with their pretrained models + """ + b, l, _ = hidden_states.size() + q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask, + position_ids, past_key_value) + f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) # Have to do after repeat for grouped-query attn if we use same fmap + + if self.train_attention: + # 1. Compute "ground-truth" attention output and weights + with torch.no_grad(): + _y_true, a_true = softmax_attention(q, k, v)[:2] + y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + y_true = self.o_proj(y_true) + + # 2. 
Compute "predicted" attention outputs + # compute attn weights under sliding window + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, + window_factors, linear_factors, + window_size=self.window_size) + attn_weights = ((a_pred, a_true), (y_pred, _y_true)) + else: + attn_weights = None + # attention_mask = None # For now this is always True + if past_key_value is None: # Regular training + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + y_true, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, + window_factors, linear_factors, + window_size=self.window_size) + attn_weights = a_pred + else: + past_key_value.window_size = self.decode_window_size + if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating + assert use_cache is True + _kv = past_key_value.update_for_decoding(k, v, self.layer_idx, + self.feature_map_k, + dtype=q.dtype) + k_cache, v_cache, f_kv_state, f_k_state = _kv + + # Sliding window + linear attention decode + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + + # Softmax attention terms + a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5) + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factors * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # Combine with linear attention terms + y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float()) + + linear_factors * torch.einsum('bhlf,bhfd->bhld', f_q.float(), f_kv_state.float())) + sum_ln = linear_factors * torch.einsum( + 'bhlf,bhnf->bhl', f_q.float(), f_k_state.float())[..., None] + y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype) + + else: # Stateful training + try: + kv_state = past_key_value.kv_states[self.layer_idx] + k_state = past_key_value.k_states[self.layer_idx] + except IndexError: + kv_state, k_state = None, None + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + y_true, _ = self.quadratic_attention(q, k, f_q, f_k, v, + window_factors, linear_factors, + window_size=self.window_size, + kv_state=kv_state, + k_state=k_state) + # Save and update KV cache and states + # past_key_value.update(k, v.detach(), self.layer_idx, + # fmap_key_states=f_k.detach(), + # accumulate_in_fp32=True) + past_key_value.update(k, v, self.layer_idx, + fmap_key_states=f_k, + accumulate_in_fp32=True) + # Concatenate heads and apply output projection + y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + y_true = self.o_proj(y_true) + return y_true, attn_weights, past_key_value + + +class LinearAttentionSlidingWindowCache(LinearAttentionState): + """ + Class for `past_key_values` + -> Alternative to KV cache; here we only maintain a "KV state" and "K state" + -> Modified from transformers.cache_utils.DynamicCache (v4.36) + """ + def __init__(self, window_size: int = 64) -> None: + super().__init__() + self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36 + self._seen_tokens_by_layer: List[int] = [] + self.kv_states: List[torch.Tensor] = [] + self.k_states: List[torch.Tensor] = [] + + # Account for sliding windows + self.decode_kv_states: List[torch.Tensor] = [] + self.decode_k_states: List[torch.Tensor] = [] + 
self.k_cache: List[torch.Tensor] = [] + self.v_cache: List[torch.Tensor] = [] + self.window_size = window_size + + def update(self, key_states: torch.Tensor, value_states: torch.Tensor, + layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None, + accumulate_in_fp32: bool = False, + fmap_key_states: torch.Tensor = None, # should not be None + grad_enabled: bool = False, + **kwargs: any, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update KV, K states; and KV cache during training + - For decoding, use `self.decode_kv_states` to keep track of KV states + up to sliding window terms + - For (chunked) training, use `self.kv_states` to keep track of KV states + up to end of sequence + - Likewise for `self.decode_k_states` and `self.k_states` + """ + with torch.set_grad_enabled(grad_enabled): + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + dtype = key_states.dtype + if accumulate_in_fp32: + # key_states = key_states.float() + fmap_key_states = fmap_key_states.float() + value_states = value_states.float() + + # Decoding KV state (KV terms up to last window_size) + decode_kv_state = torch.einsum( + 'bhlf,bhld->bhfd', fmap_key_states[:, :, :-self.window_size], value_states[:, :, :-self.window_size] + ) + # KV state + kv_state = decode_kv_state + torch.einsum( + 'bhlf,bhld->bhfd', fmap_key_states[:, :, -self.window_size:], value_states[:, :, -self.window_size:] + ) + # shape is b, h, 1, f; note the 1 + decode_k_state = fmap_key_states[:, :, :-self.window_size].sum(dim=-2, keepdim=True) + k_state = (decode_k_state + fmap_key_states[:, :, -self.window_size:].sum(dim=-2, keepdim=True)) + + # Update the cache + if len(self.k_states) <= layer_idx: # Initializing kv and k states + self.kv_states.append(kv_state.to(dtype)) + self.k_states.append(k_state.to(dtype)) + + self.decode_kv_states.append(decode_kv_state.to(dtype)) + self.decode_k_states.append(decode_k_state.to(dtype)) + + self.k_cache.append(key_states[:, :, -self.window_size:, :]) + self.v_cache.append(value_states[:, :, -self.window_size:, :].to(dtype)) + # self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2]) + else: + # Update kv and k states recurrently + kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype) + k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype) + self.kv_states[layer_idx] = kv_state + self.k_states[layer_idx] = k_state + + decode_kv_state = (self.decode_kv_states[layer_idx].to(kv_state.dtype) + + decode_kv_state).to(dtype) + decode_k_state = (self.decode_k_states[layer_idx].to(kv_state.dtype) + + decode_k_state).to(dtype) + self.decode_kv_states[layer_idx] = decode_kv_state + self.decode_k_states[layer_idx] = decode_k_state + + self.k_cache[layer_idx] = key_states[:, :, -self.window_size:, :] + self.v_cache[layer_idx] = value_states[:, :, -self.window_size:, :] + self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2] + + return self.kv_states[layer_idx], self.k_states[layer_idx] + + def update_for_decoding(self, keys: torch.Tensor, values: torch.Tensor, + layer_idx: int, feature_map_k: Callable, dtype: torch.dtype): + """ + Update the decoding KV and K states, and KV cache, during decodeing + """ + with torch.no_grad(): + k_cache = self.k_cache[layer_idx] + v_cache = self.v_cache[layer_idx] + + if k_cache.shape[-2] < self.window_size: # build window-size cache + self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2) + self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2) + else: + # MZ 6/3: handle short 
inputs; zero-out padding when initial k.shape[2] < self.window_size + # if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache + # f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device) + # else: + # f_k_state = feature_map_k(k_cache[:, :, :1, :]) + # -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation + k_state = feature_map_k(k_cache[:, :, :1, :]) + v_state = v_cache[:, :, :1, :] + kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d + self.decode_kv_states[layer_idx] += kv_state + self.decode_k_states[layer_idx] += k_state + + self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2) + self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], values], dim=-2) + + if layer_idx == 0: + self._seen_tokens += keys.shape[-2] + self._seen_tokens_by_layer[layer_idx] += keys.shape[-2] + return (self.k_cache[layer_idx], self.v_cache[layer_idx], + self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx]) diff --git a/src/model/linear_attention/linear_window_attention_sw_linear.py b/src/model/linear_attention/linear_window_attention_sw_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..ea52584212139fd4be34109841417a1e5ad31ee0 --- /dev/null +++ b/src/model/linear_attention/linear_window_attention_sw_linear.py @@ -0,0 +1,522 @@ +""" +Subquadratic attention combining sliding window and linear attentions +- Using "standard" sliding windows +- Didactically computes outputs with n^2 attention weights for now +- Copied + adapted from linear_window_attention_tk.py for single-file reference + +For each layer: +- We first compute (softmax) attention over sliding windows +- We then compute standard linear attention to "fill in" the earlier parts +- We combine to model the entire sequence +""" +from typing import List, Tuple, Optional, Callable +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.cache_utils import Cache +try: + from transformers.modeling_flash_attention_utils import _flash_attention_forward +except ModuleNotFoundError: + _flash_attention_forward = None # Transformers v4.36 + +# Causal linear attention dot product CUDA kernel from fast-transformers +from csrc import causal_dot_product + +from src.model.rotary import apply_rotary_pos_emb +from .linear_attention import ( + LolcatsLinearAttention, LinearAttentionState, + softmax_attention +) + +# ---------------------- +# Sliding window helpers +# ---------------------- +def get_masks(window_size: int, q_len: int, k_len: int, + device: torch.device) -> tuple[torch.Tensor]: + """ + Return masks for softmax and linear attention terms + -> 1 is include, 0 is ignore + """ + kwargs = {'device': device, 'dtype': int} + causal_mask = torch.ones((q_len, k_len), **kwargs).tril(max(k_len - q_len, 0)) + linear_mask = torch.ones((q_len, k_len), **kwargs).tril(max(k_len - q_len, 0) - window_size) + window_mask = causal_mask - linear_mask + # Return softmax mask (window), linear attention mask + # -> shapes broadcast over (b, h, q_len, k_len) + return window_mask[None, None, ...], linear_mask[None, None, ...] 
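[Reviewer sketch, not part of this diff] A small worked example of what get_masks returns, assuming window_size=2 and q_len=k_len=4: the window mask is the causal band of width window_size handled by softmax attention, the linear mask covers everything strictly older than the window, and the two masks are disjoint and sum to the full causal mask.

import torch

window_size, q_len, k_len = 2, 4, 4
kwargs = {'device': 'cpu', 'dtype': int}
causal_mask = torch.ones((q_len, k_len), **kwargs).tril(max(k_len - q_len, 0))
linear_mask = torch.ones((q_len, k_len), **kwargs).tril(max(k_len - q_len, 0) - window_size)
window_mask = causal_mask - linear_mask

print(window_mask)
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [0, 1, 1, 0],
#         [0, 0, 1, 1]])   <- softmax attention over the last window_size keys
print(linear_mask)
# tensor([[0, 0, 0, 0],
#         [0, 0, 0, 0],
#         [1, 0, 0, 0],
#         [1, 1, 0, 0]])   <- linear attention over everything older than the window
assert torch.equal(window_mask + linear_mask, causal_mask)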
+ + +def hybrid_attention_quadratic(q: torch.Tensor, k: torch.Tensor, + f_q: torch.Tensor, f_k: torch.Tensor, + v: torch.Tensor, + window_factor: torch.Tensor, + linear_factor: torch.Tensor, + window_size: int, + kv_state: torch.Tensor = None, + k_state: torch.Tensor = None, + eps: float = 1e-12, + mask_value: float=-1e8): + """ + Hybrid attention combining sliding window and linear attentions + """ + + mask_window, mask_linear = get_masks(window_size, q.shape[-2], k.shape[-2], q.device) + + # 1. Sliding window (softmax attention) + a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5) + a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value) + # torch.softmax(a_sm, dim=-1), but we account for the max when combining + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factor * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # 2. Under window (linear attention) + a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float()) + a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0) + sum_ln = a_ln.sum(dim=-1, keepdim=True) + + # 3. Combine + a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights + # Allow outputs to also depend on prior kv_state and k_state + y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float()) + if kv_state is not None: # Combine with prior kv_state and k_state + y += linear_factor * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float()) + sum_ln += linear_factor * torch.einsum( + 'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None] + y = (y / (sum_sm + sum_ln)).to(q.dtype) + return y, a # attention weights only for the last chunk + + +# ------------------------------ +# Hybrid window attention linear +# ------------------------------ +def under_window_linear_attention(f_q: torch.Tensor, f_k: torch.Tensor, v: torch.Tensor, + window_size: int, linear_factor: float, eps: float=1e-12): + """Compute hybrid window attention dot product with linear complexity in q_len""" + dtype = f_q.dtype + w = window_size + f_k = F.pad(f_k, (0, 0, w, 0), value=0)[:, :, :-w, :] + v = F.pad(v, (0, 0, w, 0), value=0)[:, :, :-w, :] + qkv = linear_factor * causal_dot_product(f_q.contiguous().to(dtype=torch.float32), + f_k.contiguous().to(dtype=torch.float32), + v.contiguous().to(dtype=torch.float32)).to(dtype=dtype) + sum_f_k = f_k.float().cumsum(dim=2).to(dtype=dtype) + sum_qk = linear_factor * torch.einsum("bhld,bhld->bhl", f_q, sum_f_k)[..., None] + sum_qk[sum_qk == 0] += eps + return qkv, sum_qk + + +def sliding_window_softmax_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + window_size: int, window_factor: float, mask_value: float=-1e8): + """ + Compute sliding window softmax attention without materializing + O(seq_len^2) attention weights + """ + d = q.shape[-1] + # Compute windows for keys + window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1} + k = F.pad(k, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs) + v = F.pad(v, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs) + + # Compute windowed_softmax(qk); causal in its construction + a_sm = torch.einsum('bhld,bhldw->bhlw', q, k) * (d ** -0.5) + a_sm[a_sm == 0] = -torch.finfo(q.dtype).max # heuristic for zeroing out padding above + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factor * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + return torch.einsum('bhlw,bhldw->bhld', a_sm, v), sum_sm + # return 
torch.einsum('bhlw,bhldw->bhld', torch.softmax(qk, dim=-1), v) + + +def hybrid_attention_linear(q: torch.Tensor, k: torch.Tensor, + f_q: torch.Tensor, f_k: torch.Tensor, + v: torch.Tensor, + window_factor: torch.Tensor = None, + linear_factor: torch.Tensor = None, + window_size: int = 64, + kv_state: torch.Tensor = None, + k_state: torch.Tensor = None, + eps: float = 1e-12, + mask_value: float=-1e8): + """ + Alternative hybrid attention combining sliding window and linear attentions + -> Uses O(n) memory if n is sequence length by padding and unfolding windows + """ + window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1} + # 1. Sliding window (softmax attention) + with torch.no_grad(): + qkv_sm, sum_qk_sm = sliding_window_softmax_attention(q, k, v, window_size, window_factor, mask_value) + + # 2. Under window (linear attention) + qkv_ln, sum_qk_ln = under_window_linear_attention(f_q, f_k, v, window_size, linear_factor, eps) + + # 3. Combine + y = (qkv_sm + qkv_ln) / (sum_qk_sm + sum_qk_ln) + return y, None + + +# --------------------- +# Attention layer class +# --------------------- +class LolcatsLinearSlidingWindowAttention(LolcatsLinearAttention): + """ + Lolcats attention combining sliding window and linear attention + """ + def __init__(self, + window_size: int = 64, + decode_window_size: int = None, + affine_attention_factors: bool = False, + init_window_factor: float = 0, + train_window_factor: bool = True, + state_grad_enabled: bool = False, + **kwargs): + self.window_size = window_size + self.decode_window_size = ( + decode_window_size if decode_window_size is not None else window_size + ) + self.window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1} + super().__init__(**kwargs) + # Determine how we compute attentions + self.linear_attention = hybrid_attention_linear + self.attention_type = 'lolcats_llama_window_sw' + # Learnable factor for combining attentions + self.affine_attention_factors = affine_attention_factors + device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype + if train_window_factor: + self.window_factors = nn.Parameter( + init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)) + else: + self.register_buffer( + "window_factors", init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype) + ) + # Whether we use original flash attention 2 inference (use during attention transfer) + self.base_inference = False + self.state_grad_enabled = state_grad_enabled + + def forward(self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Forward pass with the option to compute attention weights multiple ways + if self.train_attention is True + -> Consistent with HuggingFace Transformers for easy use with their pretrained models + """ + b, l, _ = hidden_states.size() + + if self.train_attention and self.base_inference: + with torch.no_grad(): + _y_true = flash_attention_2(self, # self.base_attn, + hidden_states=hidden_states, + attention_mask=None, + position_ids=position_ids, + past_key_value=None, + output_attentions=False, + use_cache=False)[0] + # _y_true.shape is (batch_size, seq_len, num_heads, head_dim) + y_true = _y_true.reshape(b, l, -1).contiguous() + y_true = self.o_proj(y_true) + # layer_io = 
(hidden_states, _y_true) # hack + layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack + return y_true, layer_io, None + + else: + q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask, + position_ids, past_key_value) + f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) # Have to do after repeat for grouped-query attn if we use same fmap + + attn_weights = None + # attention_mask = None # For now this is always True + if past_key_value is None: # Regular training + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + y_true, a_pred = self.linear_attention(q, k, f_q, f_k, v, + window_factors, linear_factors, + window_size=self.window_size) + attn_weights = a_pred + else: + past_key_value.window_size = self.decode_window_size + if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating + assert use_cache is True + _kv = past_key_value.update_for_decoding(k, v, self.layer_idx, + self.feature_map_k, + dtype=q.dtype) + k_cache, v_cache, f_kv_state, f_k_state = _kv + + # Sliding window + linear attention decode + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + + # Softmax attention terms + a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5) + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factors * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # Combine with linear attention terms + y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float()) + + linear_factors * torch.einsum('bhlf,bhfd->bhld', f_q.float(), f_kv_state.float())) + sum_ln = linear_factors * torch.einsum( + 'bhlf,bhnf->bhl', f_q.float(), f_k_state.float())[..., None] + y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype) + + else: # Stateful training + try: + kv_state = past_key_value.kv_states[self.layer_idx] + k_state = past_key_value.k_states[self.layer_idx] + except IndexError: + kv_state, k_state = None, None + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + y_true, _ = self.linear_attention(q, k, f_q, f_k, v, + window_factors, linear_factors, + window_size=self.window_size, + kv_state=kv_state, + k_state=k_state) + # Save and update KV cache and states + # past_key_value.update(k, v.detach(), self.layer_idx, + # fmap_key_states=f_k.detach(), + # accumulate_in_fp32=True) + past_key_value.update(k, v, self.layer_idx, + fmap_key_states=f_k, + accumulate_in_fp32=True) + # Concatenate heads and apply output projection + _y_true = y_true.transpose(1, 2).contiguous() + y_true = self.o_proj(_y_true.view(b, l, self.hidden_size)) + + if self.train_attention: + attn_weights = _y_true # flash_attn outputs are shape (b, l, h, d) + return y_true, attn_weights, past_key_value + + +class LinearAttentionSlidingWindowCache(LinearAttentionState): + """ + Class for `past_key_values` + -> Alternative to KV cache; here we only maintain a "KV state" and "K state" + -> Modified from transformers.cache_utils.DynamicCache (v4.36) + """ + def __init__(self, window_size: int = 64) -> None: + super().__init__() + self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36 + self._seen_tokens_by_layer: List[int] = [] + self.kv_states: List[torch.Tensor] = [] + self.k_states: List[torch.Tensor] = [] + + # Account for sliding windows + self.decode_kv_states: List[torch.Tensor] = [] + 
self.decode_k_states: List[torch.Tensor] = [] + self.k_cache: List[torch.Tensor] = [] + self.v_cache: List[torch.Tensor] = [] + self.window_size = window_size + + def update(self, key_states: torch.Tensor, value_states: torch.Tensor, + layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None, + accumulate_in_fp32: bool = False, + fmap_key_states: torch.Tensor = None, # should not be None + grad_enabled: bool = False, + **kwargs: any, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update KV, K states; and KV cache during training + - For decoding, use `self.decode_kv_states` to keep track of KV states + up to sliding window terms + - For (chunked) training, use `self.kv_states` to keep track of KV states + up to end of sequence + - Likewise for `self.decode_k_states` and `self.k_states` + """ + with torch.set_grad_enabled(grad_enabled): + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + dtype = key_states.dtype + if accumulate_in_fp32: + # key_states = key_states.float() + fmap_key_states = fmap_key_states.float() + value_states = value_states.float() + + # Decoding KV state (KV terms up to last window_size) + decode_kv_state = torch.einsum( + 'bhlf,bhld->bhfd', fmap_key_states[:, :, :-self.window_size], value_states[:, :, :-self.window_size] + ) + # KV state + kv_state = decode_kv_state + torch.einsum( + 'bhlf,bhld->bhfd', fmap_key_states[:, :, -self.window_size:], value_states[:, :, -self.window_size:] + ) + # shape is b, h, 1, f; note the 1 + decode_k_state = fmap_key_states[:, :, :-self.window_size].sum(dim=-2, keepdim=True) + k_state = (decode_k_state + fmap_key_states[:, :, -self.window_size:].sum(dim=-2, keepdim=True)) + + # Update the cache + if len(self.k_states) <= layer_idx: # Initializing kv and k states + self.kv_states.append(kv_state.to(dtype)) + self.k_states.append(k_state.to(dtype)) + + self.decode_kv_states.append(decode_kv_state.to(dtype)) + self.decode_k_states.append(decode_k_state.to(dtype)) + + self.k_cache.append(key_states[:, :, -self.window_size:, :]) + self.v_cache.append(value_states[:, :, -self.window_size:, :].to(dtype)) + # self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2]) + else: + # Update kv and k states recurrently + kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype) + k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype) + self.kv_states[layer_idx] = kv_state + self.k_states[layer_idx] = k_state + + decode_kv_state = (self.decode_kv_states[layer_idx].to(kv_state.dtype) + + decode_kv_state).to(dtype) + decode_k_state = (self.decode_k_states[layer_idx].to(kv_state.dtype) + + decode_k_state).to(dtype) + self.decode_kv_states[layer_idx] = decode_kv_state + self.decode_k_states[layer_idx] = decode_k_state + + self.k_cache[layer_idx] = key_states[:, :, -self.window_size:, :] + self.v_cache[layer_idx] = value_states[:, :, -self.window_size:, :] + self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2] + + return self.kv_states[layer_idx], self.k_states[layer_idx] + + def update_for_decoding(self, keys: torch.Tensor, values: torch.Tensor, + layer_idx: int, feature_map_k: Callable, dtype: torch.dtype): + """ + Update the decoding KV and K states, and KV cache, during decodeing + """ + with torch.no_grad(): + k_cache = self.k_cache[layer_idx] + v_cache = self.v_cache[layer_idx] + + if k_cache.shape[-2] < self.window_size: # build window-size cache + self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2) + self.v_cache[layer_idx] = torch.cat([v_cache, 
values], dim=-2) + else: + # MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size + # if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache + # f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device) + # else: + # f_k_state = feature_map_k(k_cache[:, :, :1, :]) + # -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation + k_state = feature_map_k(k_cache[:, :, :1, :]) + v_state = v_cache[:, :, :1, :] + kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d + self.decode_kv_states[layer_idx] += kv_state + self.decode_k_states[layer_idx] += k_state + + self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2) + self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], values], dim=-2) + + if layer_idx == 0: + self._seen_tokens += keys.shape[-2] + self._seen_tokens_by_layer[layer_idx] += keys.shape[-2] + return (self.k_cache[layer_idx], self.v_cache[layer_idx], + self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx]) + + +# ----------------- +# Flash Attention 2 +# ----------------- + +def flash_attention_2(self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Wrapper for LlamaFlashAttention2 + Copied and modified from HF Transformers v4.36 and v4.43 implementations + - (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402 + - (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456 + """ + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + try: # As in Transformers v4.36 + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + except: # As in Transformers v4.39 + cos, sin = self.rotary_emb(key_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are 
quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + if getattr(self, '_flash_attention_forward', False): + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, + is_causal=True, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=0, # dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=True, + ) + return attn_output, past_key_value diff --git a/src/model/linear_attention/linear_window_attention_sw_long.py b/src/model/linear_attention/linear_window_attention_sw_long.py new file mode 100644 index 0000000000000000000000000000000000000000..952abfa9be75c64ceb2b4aef8b37112a6a70f028 --- /dev/null +++ b/src/model/linear_attention/linear_window_attention_sw_long.py @@ -0,0 +1,23 @@ +""" +LoLCATs attention combining sliding window and linear attentions +- Using standard sliding window arrangement +- Training over long sequences with fixed memory with recurrent view +- During attention transfer, use Flash Attention to compute softmax attention outputs + +For each layer: +- We first compute (softmax) attention over sliding windows +- We then compute standard linear attention to "fill in" the earlier parts +- We combine to model the entire sequence +""" +from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention +from .linear_window_attention_sw import hybrid_attention_quadratic + + +class LolcatsSlidingWindowLongAttention(LolcatsTKWindowLongAttention): + """ + Lolcats attention combining sliding window and linear attention + """ + def __init__(self, remove_base_attn=True, **kwargs): + # keep self.base_attn for Flash Attention inference + super().__init__(remove_base_attn=True, **kwargs) + self.quadratic_attention = hybrid_attention_quadratic diff --git a/src/model/linear_attention/linear_window_attention_tk.py 
b/src/model/linear_attention/linear_window_attention_tk.py new file mode 100644 index 0000000000000000000000000000000000000000..37e81dc66b14e33baa3478f5ed7bc5dba62b7165 --- /dev/null +++ b/src/model/linear_attention/linear_window_attention_tk.py @@ -0,0 +1,342 @@ +""" +Subquadratic attention combining sliding window and linear attentions +- Using the TK "terracing" arrangement + +For each layer: +- We first compute (softmax) attention over sliding windows +- We then compute standard linear attention to "fill in" the earlier parts +- We combine to model the entire sequence +""" +from typing import List, Tuple, Optional, Callable +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.cache_utils import Cache + +from .linear_attention import ( + LolcatsLinearAttention, LinearAttentionState, softmax_attention +) + +# ---------------------- +# Sliding window helpers +# ---------------------- +def get_masks(window_size: int, q_len: int, k_len: int, + device: torch.device) -> tuple[torch.Tensor]: + """ + Return masks for softmax and linear attention terms + -> 1 is include, 0 is ignore + """ + kwargs = {'device': device, 'dtype': int} + l = window_size + m = math.ceil(max(q_len, k_len) / window_size) + # Creates an n x n mask where n = window_size^2 + mask = torch.block_diag(*[torch.ones((l, l), )] * m) + mask += torch.roll(mask, -l, -1) # this adds the terracing + if mask.shape[0] > q_len: + mask = mask[-q_len:] + if mask.shape[1] > k_len: + mask = mask[:, -k_len:] + # Return softmax mask (window), linear attention mask + mask = mask[None, None, ...] # b, h, q_len, k_len + return torch.tril(mask).to(**kwargs), torch.tril(1 - mask).to(**kwargs) + + +def hybrid_attention_quadratic(q: torch.Tensor, k: torch.Tensor, + f_q: torch.Tensor, f_k: torch.Tensor, + v: torch.Tensor, + window_factor: torch.Tensor, + linear_factor: torch.Tensor, + window_size: int, + kv_state: torch.Tensor = None, + k_state: torch.Tensor = None, + eps: float = 1e-12, + mask_value: float=-1e8): + """ + Hybrid attention combining sliding window and linear attentions + """ + + mask_window, mask_linear = get_masks(window_size, q.shape[-2], k.shape[-2], q.device) + + # 1. Sliding window (softmax attention) + a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5) + a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value) + # torch.softmax(a_sm, dim=-1), but we account for the max when combining + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factor * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # 2. Under window (linear attention) + a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float()) + a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0) + sum_ln = a_ln.sum(dim=-1, keepdim=True) + + # 3. 
Combine + a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights + # Allow outputs to also depend on prior kv_state and k_state + y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float()) + if kv_state is not None: # Combine with prior kv_state and k_state + y += linear_factor * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float()) + sum_ln += linear_factor * torch.einsum( + 'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None] + y = (y / (sum_sm + sum_ln)).to(q.dtype) + return y, a # attention weights only for the last chunk + + +# --------------------- +# Attention layer class +# --------------------- +class LolcatsTKWindowAttention(LolcatsLinearAttention): + """ + Lolcats attention combining sliding window and linear attention + """ + def __init__(self, + window_size: int = 64, + decode_window_size: int = None, + affine_attention_factors: bool = False, + init_window_factor: float = 0, + train_window_factor: bool = True, + state_grad_enabled: bool = False, + **kwargs): + self.window_size = window_size + self.decode_window_size = ( + decode_window_size if decode_window_size is not None else window_size + ) + self.window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1} + super().__init__(**kwargs) + self.attention_type = kwargs['attention_type'] # 'hedgehog_llama_window_tk' + # Determine how we compute attentions + self.quadratic_attention = hybrid_attention_quadratic + self.attention_type = kwargs['attention_type'] # 'hedgehog_long_llama_window_tk' + # Learnable factor for combining attentions + self.affine_attention_factors = affine_attention_factors + device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype + if train_window_factor: + self.window_factors = nn.Parameter( + init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)) + else: + self.register_buffer( + "window_factors", init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype) + ) + # Whether we use original flash attention 2 inference (use during attention transfer) + self.base_inference = False + self.state_grad_enabled = state_grad_enabled + self.window_factor = self.window_factors # legacy naming support + + def forward(self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Forward pass with the option to compute attention weights multiple ways + if self.train_attention is True + -> Consistent with HuggingFace Transformers for easy use with their pretrained models + """ + b, l, _ = hidden_states.size() + q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask, + position_ids, past_key_value) + f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) # Have to do after repeat for grouped-query attn if we use same fmap + + if self.train_attention: + # 1. Compute "ground-truth" attention output and weights + with torch.no_grad(): + _y_true, a_true = softmax_attention(q, k, v)[:2] + y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + y_true = self.o_proj(y_true) + + # 2. 
Compute "predicted" attention outputs + # compute attn weights under sliding window + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, + window_factors, linear_factors, + window_size=self.window_size) + attn_weights = ((a_pred, a_true), (y_pred, _y_true)) + else: + attn_weights = None + # attention_mask = None # For now this is always True + if past_key_value is None: # Regular training + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + y_true, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, + window_factors, linear_factors, + window_size=self.window_size) + attn_weights = a_pred + else: + past_key_value.window_size = self.decode_window_size + if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating + assert use_cache is True + _kv = past_key_value.update_for_decoding(k, v, self.layer_idx, + self.feature_map_k, + dtype=q.dtype) + k_cache, v_cache, f_kv_state, f_k_state = _kv + + # Sliding window + linear attention decode + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + + # Softmax attention terms + a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5) + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factors * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # Combine with linear attention terms + y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float()) + + linear_factors * torch.einsum('bhlf,bhfd->bhld', f_q.float(), f_kv_state.float())) + sum_ln = linear_factors * torch.einsum( + 'bhld,bhnd->bhl', f_q.float(), f_k_state.float())[..., None] + y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype) + + else: # Stateful training + try: + kv_state = past_key_value.kv_states[self.layer_idx] + k_state = past_key_value.k_states[self.layer_idx] + except IndexError: + kv_state, k_state = None, None + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + y_true, _ = self.quadratic_attention(q, k, f_q, f_k, v, + window_factors, linear_factors, + window_size=self.window_size, + kv_state=kv_state, + k_state=k_state) + # Save and update KV cache and states + # past_key_value.update(k, v.detach(), self.layer_idx, + # fmap_key_states=f_k.detach(), + # accumulate_in_fp32=True) + past_key_value.update(k, v, self.layer_idx, + fmap_key_states=f_k, + accumulate_in_fp32=True) + # Concatenate heads and apply output projection + y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + y_true = self.o_proj(y_true) + return y_true, attn_weights, past_key_value + + +class LinearAttentionTKWindowCache(LinearAttentionState): + """ + Class for `past_key_values` + -> Alternative to KV cache; here we only maintain a "KV state" and "K state" + -> Modified from transformers.cache_utils.DynamicCache (v4.36) + """ + def __init__(self, window_size: int = 64) -> None: + super().__init__() + self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36 + self._seen_tokens_by_layer: List[int] = [] + self.kv_states: List[torch.Tensor] = [] + self.k_states: List[torch.Tensor] = [] + + # Account for sliding windows + self.decode_kv_states: List[torch.Tensor] = [] + self.decode_k_states: List[torch.Tensor] = [] + self.k_cache: 
List[torch.Tensor] = [] + self.v_cache: List[torch.Tensor] = [] + self.window_size = window_size + + def update(self, key_states: torch.Tensor, value_states: torch.Tensor, + layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None, + accumulate_in_fp32: bool = False, + fmap_key_states: torch.Tensor = None, # should not be None + grad_enabled: bool = False, + **kwargs: any, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update KV, K states; and KV cache during training + - For decoding, use `self.decode_kv_states` to keep track of KV states + up to sliding window terms + - For (chunked) training, use `self.kv_states` to keep track of KV states + up to end of sequence + - Likewise for `self.decode_k_states` and `self.k_states` + """ + with torch.set_grad_enabled(grad_enabled): + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + dtype = key_states.dtype + if accumulate_in_fp32: + # key_states = key_states.float() + fmap_key_states = fmap_key_states.float() + value_states = value_states.float() + + # Decoding KV state (KV terms up to last window_size) + decode_kv_state = torch.einsum( + 'bhlf,bhld->bhfd', + fmap_key_states[:, :, :-self.window_size], + value_states[:, :, :-self.window_size] + ) + # KV state + kv_state = decode_kv_state + torch.einsum( + 'bhlf,bhld->bhfd', + fmap_key_states[:, :, -self.window_size:], + value_states[:, :, -self.window_size:] + ) + # shape is b, h, 1, f; note the 1 + decode_k_state = fmap_key_states[:, :, :-self.window_size].sum(dim=-2, keepdim=True) + k_state = (decode_k_state + + fmap_key_states[:, :, -self.window_size:].sum(dim=-2, keepdim=True)) + + # Update the cache + if len(self.k_states) <= layer_idx: # Initializing kv and k states + self.kv_states.append(kv_state.to(dtype)) + self.k_states.append(k_state.to(dtype)) + + self.decode_kv_states.append(decode_kv_state.to(dtype)) + self.decode_k_states.append(decode_k_state.to(dtype)) + + self.k_cache.append(key_states[:, :, -self.window_size:, :]) + self.v_cache.append(value_states[:, :, -self.window_size:, :].to(dtype)) + # self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2]) + else: + # Update kv and k states recurrently + kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype) + k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype) + self.kv_states[layer_idx] = kv_state + self.k_states[layer_idx] = k_state + + decode_kv_state = (self.decode_kv_states[layer_idx].to(kv_state.dtype) + + decode_kv_state).to(dtype) + decode_k_state = (self.decode_k_states[layer_idx].to(kv_state.dtype) + + decode_k_state).to(dtype) + self.decode_kv_states[layer_idx] = decode_kv_state + self.decode_k_states[layer_idx] = decode_k_state + + self.k_cache[layer_idx] = key_states[:, :, -self.window_size:, :] + self.v_cache[layer_idx] = value_states[:, :, -self.window_size:, :] + self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2] + + return self.kv_states[layer_idx], self.k_states[layer_idx] + + def update_for_decoding(self, keys: torch.Tensor, values: torch.Tensor, + layer_idx: int, feature_map_k: Callable, dtype: torch.dtype): + """ + Update the decoding KV and K states, and KV cache, during decodeing + """ + with torch.no_grad(): + k_cache = self.k_cache[layer_idx] + v_cache = self.v_cache[layer_idx] + + if k_cache.shape[-2] < self.window_size: # build window-size cache + self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2) + self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2) + else: + k_state = 
feature_map_k(k_cache[:, :, :1, :]) + v_state = v_cache[:, :, :1, :] + kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d + self.decode_kv_states[layer_idx] += kv_state + self.decode_k_states[layer_idx] += k_state + + self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2) + self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], values], dim=-2) + + if layer_idx == 0: + self._seen_tokens += keys.shape[-2] + self._seen_tokens_by_layer[layer_idx] += keys.shape[-2] + return (self.k_cache[layer_idx], self.v_cache[layer_idx], + self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx]) diff --git a/src/model/linear_attention/linear_window_attention_tk_gen.py b/src/model/linear_attention/linear_window_attention_tk_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..ffc2abf26b8a325873b5bb853842947fe99c25fe --- /dev/null +++ b/src/model/linear_attention/linear_window_attention_tk_gen.py @@ -0,0 +1,166 @@ +""" +LoLCATs + ThunderKittens linear attention + sliding window for generation +""" +from typing import Optional, Tuple, List +import torch +import torch.nn.functional as F + +try: + from thunderkittens import hedgehog as tk_window_hedgehog_attention + print(f"Successfully imported ThunderKittens for TK window attention") +except: + print(f"Failed to import ThunderKittens for TK window attention") + +from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention +from .linear_attention import LinearAttentionState + +class LolcatsWindowAttentionTKGen(LolcatsTKWindowLongAttention): + def __init__(self, *args, window_size: int = 64, **kwargs): + super().__init__(*args, **kwargs) + self.train_attention = False + self.base_inference = False + self.window_size = 64 # hard-coded support for TK kernel + self.decode_window_size = 64 + + b, h, l, d = 1, 32, 8192, 128 + self.y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device='cuda') + self.kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device='cuda') + self.k_state = torch.zeros(b, h, d, dtype=torch.float32, device='cuda') + + def forward(self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None, # “legacy” cache approach + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Forward pass with the option to compute attention weights multiple ways + if self.train_attention is True + -> Consistent with HuggingFace Transformers for easy use with their pretrained models + """ + b, l, _ = hidden_states.size() + assert past_key_value is not None, "past_key_value must be provided for generation" + assert self.train_attention is False, "train_attention is not supported for generation" + assert self.base_inference is False, "base_inference is not supported for generation" + assert use_cache is True, "use_cache must be True for generation" + past_key_value.window_size = self.decode_window_size + q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask, + position_ids, past_key_value) + if q.shape[2] == 1 and kv_seq_len > 1: # Generating after prefill + f_q = self.feature_map_q(q) + _kv = past_key_value.update_for_decoding(k, v, self.layer_idx, + self.feature_map_k) + k_cache, v_cache, kv_state, k_state = _kv + # Sliding window + linear attention decode + 
window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + + # Softmax attention terms + a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5) + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factors * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # Combine with linear attention terms + y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float()) + + linear_factors * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float())) + sum_ln = linear_factors * torch.einsum( + 'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None] + self.y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype) + + else: # Process prefill + # Use TK-implemented linear + terrace window attention + b, h, l, d = q.shape + device = q.device + # tk.hedgehog arguments + # y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device=device) + # kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device=device) + # k_state = torch.zeros(b, h, d, dtype=torch.float32, device=device) + betas = F.sigmoid(self.window_factors[0, :, 0, 0].to(dtype=torch.float32)) + alphas = (1 - betas if self.affine_attention_factors else + torch.ones(betas.shape, dtype=torch.float32, device=device)) + q_map = self.feature_map_q.mlp.layer + k_map = self.feature_map_k.mlp.layer + # Saves outputs to y_pred, k_state, kv_state, where we fuse: + # 1. f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) + # 2. y_pred = attention(q, k, f_q, f_k, v) # b, h, l, d + # 3. kv_state = torch.einsum(‘bhlf,bhld->bhfd’, + # f_k[:, :, :-self.window_size], + # v[:, :, :-self.window_size]) # b, h, f, d + # 4. k_state = f_k[:, :, :-self.window_size].sum(dim=-2) # b, h, d + + tk_window_hedgehog_attention(q.contiguous(), k.contiguous(), v.contiguous(), + self.y_true, self.k_state, self.kv_state, + q_map, k_map, alphas, betas) + + past_key_value.update_with_kv(self.kv_state, self.k_state.unsqueeze(-2), k, v, self.layer_idx) + + # Concatenate heads and apply output projection + y_true = self.y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + y_true = self.o_proj(y_true) + return y_true, None, past_key_value + + +class LinearAttentionTKWindowGenerationCache(LinearAttentionState): + """ + Class for `past_key_values` + -> Alternative to KV cache; here we only maintain a “KV state” and “K state” + -> Modified from transformers.cache_utils.DynamicCache (v4.36) + """ + def __init__(self, window_size: int = 64) -> None: + super().__init__() + self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36 + self._seen_tokens_by_layer: List[int] = [] + self.window_size = window_size + + self.decode_kv_states: List[torch.Tensor] = [] + self.decode_k_states: List[torch.Tensor] = [] + self.k_cache: List[torch.Tensor] = [] + self.v_cache: List[torch.Tensor] = [] + + def update_with_kv(self, + kv_state: torch.Tensor, k_state: torch.Tensor, + k: torch.Tensor, v: torch.Tensor, + layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update the cache with new KV and K states + """ + if layer_idx == 0: + self._seen_tokens += k.shape[2] + self._seen_tokens_by_layer.append(k.shape[2]) + + # Initialize KV and K states + if len(self.decode_k_states) <= layer_idx: + self.decode_kv_states.append(kv_state) + self.decode_k_states.append(k_state) + else: # Update KV and K states + self.decode_kv_states[layer_idx] = self.decode_kv_states[layer_idx] + kv_state + self.decode_k_states[layer_idx] = 
self.decode_k_states[layer_idx] + k_state + + self.k_cache.append(k[:, :, -self.window_size:, :]) + self.v_cache.append(v[:, :, -self.window_size:, :]) + + def update_for_decoding(self, k: torch.Tensor, v: torch.Tensor, + layer_idx: int, feature_map_k: callable) -> None: + """ + Update the cache for decoding + """ + k_cache = self.k_cache[layer_idx] + v_cache = self.v_cache[layer_idx] + k_state = feature_map_k(k_cache[:, :, :1, :]) + v_state = v_cache[:, :, :1, :] + kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(k.dtype) + + self.decode_kv_states[layer_idx] += kv_state + self.decode_k_states[layer_idx] += k_state + + self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], k], dim=-2) + self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], v], dim=-2) + if layer_idx == 0: + self._seen_tokens += k.shape[-2] + self._seen_tokens_by_layer[layer_idx] += k.shape[-2] + return (self.k_cache[layer_idx], self.v_cache[layer_idx], + self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx]) \ No newline at end of file diff --git a/src/model/linear_attention/linear_window_attention_tk_long.py b/src/model/linear_attention/linear_window_attention_tk_long.py new file mode 100644 index 0000000000000000000000000000000000000000..47ffb40bdafdebfbc58cda54c5fe621cad6b82a8 --- /dev/null +++ b/src/model/linear_attention/linear_window_attention_tk_long.py @@ -0,0 +1,240 @@ +""" +LoLCATs attention combining sliding window and linear attentions +- Using the TK "terracing" arrangement +- Training over long sequences with fixed memory with recurrent view +- During attention transfer, use Flash Attention to compute softmax attention outputs + +For each layer: +- We first compute (softmax) attention over sliding windows +- We then compute standard linear attention to "fill in" the earlier parts +- We combine to model the entire sequence +""" +from typing import Optional, Tuple +import torch +import torch.nn.functional as F + +from transformers.cache_utils import Cache +try: + from transformers.modeling_flash_attention_utils import _flash_attention_forward +except ModuleNotFoundError: + _flash_attention_forward = None # Transformers v4.36 + +from src.model.rotary import apply_rotary_pos_emb +from .linear_window_attention_tk import LolcatsTKWindowAttention +from .linear_attention import softmax_attention + + +class LolcatsTKWindowLongAttention(LolcatsTKWindowAttention): + """ + Lolcats attention combining sliding window and linear attention + """ + def __init__(self, remove_base_attn=True, **kwargs): + # keep self.base_attn for Flash Attention inference + super().__init__(remove_base_attn=True, **kwargs) + + def forward(self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Forward pass with the option to compute attention weights multiple ways + if self.train_attention is True + -> Consistent with HuggingFace Transformers for easy use with their pretrained models + """ + b, l, _ = hidden_states.size() + if self.train_attention and self.base_inference: + with torch.no_grad(): + # print(hidden_states.shape) + _y_true = flash_attention_2(self, # self.base_attn, + hidden_states=hidden_states, + attention_mask=None, + position_ids=position_ids, + past_key_value=None, + output_attentions=False, + 
# output_hidden_states=False, + use_cache=False)[0] + # _y_true.shape is (batch_size, seq_len, num_heads, head_dim) + y_true = _y_true.reshape(b, l, -1).contiguous() + y_true = self.o_proj(y_true) + layer_io = (hidden_states, _y_true) # hack + # layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack + return y_true, layer_io, None + + q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask, + position_ids, past_key_value) + f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) + + # attention_mask = None # For now this is always True + if past_key_value is None: # Regular training + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, + window_factors, linear_factors, + window_size=self.window_size,) + else: + past_key_value.window_size = self.decode_window_size + if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating + assert use_cache is True + _kv = past_key_value.update_for_decoding(k, v, self.layer_idx, + self.feature_map_k, + dtype=q.dtype) + k_cache, v_cache, f_kv_state, f_k_state = _kv + + # Sliding window + linear attention decode + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + + a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5) + # a_sm = torch.softmax(a_sm, dim=-1) + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factors * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + y_pred = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float()) + + linear_factors * torch.einsum('bhlf,bhfd->bhld', f_q.float(), f_kv_state.float())) + sum_ln = linear_factors * torch.einsum('bhlf,bhnf->bhl', f_q.float(), f_k_state.float())[..., None] + y_pred = (y_pred / (sum_sm + sum_ln)).to(q.dtype) + + else: # Stateful training + if self.state_grad_enabled and self.layer_idx == 0: + print(f'\n position_ids: [{position_ids[0, 0]}, {position_ids[0, -1]}]') + print(f'q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}') + try: + kv_state = past_key_value.kv_states[self.layer_idx] + k_state = past_key_value.k_states[self.layer_idx] + except IndexError: + kv_state, k_state = None, None + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, + window_factors, linear_factors, + window_size=self.window_size, + kv_state=kv_state, + k_state=k_state,) + # Save and update KV cache and states + # past_key_value.update(k, v.detach(), self.layer_idx, + # fmap_key_states=f_k.detach(), + # accumulate_in_fp32=True) + past_key_value.update(k, v, self.layer_idx, + fmap_key_states=f_k, + accumulate_in_fp32=True) + + # Concatenate heads and apply output projection + _y_pred = y_pred.transpose(1, 2).contiguous() + y_pred = self.o_proj(_y_pred.view(b, l, self.hidden_size)) + + if self.train_attention: + with torch.no_grad(): + a_true = softmax_attention(q, k, None, causal=True)[1] + attn_weights = (_y_pred, (a_pred, a_true)) + else: + attn_weights = _y_pred # flash_attn outputs are shape (b, l, h, d) + return y_pred, attn_weights, past_key_value + + +# ----------------- +# Flash Attention 2 +# ----------------- + +def flash_attention_2(self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: 
Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Wrapper for LlamaFlashAttention2 + Copied and modified from HF Transformers v4.36 and v4.43 implementations + - (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402 + - (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456 + """ + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + try: # As in Transformers v4.36 + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + except: # As in Transformers v4.39 + cos, sin = self.rotary_emb(key_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + if getattr(self, '_flash_attention_forward', False): + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, + is_causal=True, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=0, # dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=True, + ) + return attn_output, past_key_value \ No newline at end of file diff --git a/src/model/linear_attention/utils.py b/src/model/linear_attention/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e52c7bbc63250853bb15aff8f3d7ba9e301e8757 --- /dev/null +++ b/src/model/linear_attention/utils.py @@ -0,0 +1,31 @@ +""" +Shared attention helpers +""" +import torch + + +# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36) +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
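    For example, with n_rep = 4 a key/value tensor of shape (batch, 8, seqlen, head_dim)
    is expanded to (batch, 32, seqlen, head_dim) to match the number of query heads.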
+ The hidden states go from: + (batch, num_key_value_heads, seqlen, head_dim) to + (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def mask_attention(qk_dot: torch.Tensor, attn_mask: torch.tensor, + mask_value: float = -10000) -> torch.Tensor: + """ + Apply attention mask (e.g., for padding) + """ + if len(attn_mask.shape) == 4: # attn_mask either (b, h, l, d) or (b, l) + return qk_dot.masked_fill(~attn_mask.bool(), mask_value) + else: + return qk_dot.masked_fill(~attn_mask[:, None, None, :].bool(), mask_value) \ No newline at end of file diff --git a/src/model/load_model.py b/src/model/load_model.py new file mode 100644 index 0000000000000000000000000000000000000000..6f7b978aa5091a6c6fdb5781538534de3296b7a6 --- /dev/null +++ b/src/model/load_model.py @@ -0,0 +1,173 @@ +""" +Helpers to load checkpoints for learned feature maps (attentions) or other parameters +""" +import torch +import torch.nn as nn +from omegaconf import OmegaConf + +from src.utils.logging import print_header, _format_arg +from .convert_model import convert_attention +from .peft import create_peft_config + + +def load_and_convert_attns(model: nn.Module, + model_config: dict, + attention_type: str = None, + checkpoint_path: str = None, + print_model: bool = False, + merge_loras: bool = False, + train_converted: bool = True, # Should be false if loading distill checkpoint by default + peft_gradient_checkpointing: bool = None, + train_attention: bool = False, # Should be true if converting attentions for first time, + freeze_weights: bool = True, + rank: int = 0, + remove_base_attn: bool = True, + ) -> nn.Module: + """ + Load trained attention kernel parameter weights + """ + if freeze_weights: + for p in model.parameters(): + p.requires_grad = False + + if attention_type is not None: # override default + model_config['attention']['attention_type'] = attention_type + model_config['attention']['rank'] = rank # multi-gpu debugging + + model = convert_attention(model, model_config['attention'], + train_attention, remove_base_attn) + + # Add low-rank adapters + peft_key = 'peft' # inconsistency across configs... 
why do this to myself + if 'peft_config' in model_config['attention']: + peft_key = 'peft_config' + if peft_key in model_config['attention']: + peft_config = model_config['attention'][peft_key] + model, peft_config = create_peft_config(model, peft_config, + model_config['model']['torch_dtype'], + preserve_requires_grad=train_converted, + use_gradient_checkpointing=peft_gradient_checkpointing) + else: + peft_config = None + + if print_model and rank == 0: # Look at model + print_header('*** Model before checkpoint load ***') + print(model) + + # Load any trained attentions + if checkpoint_path is not None: + print(f'Loading weights from {checkpoint_path}...') + state_dict = torch.load(checkpoint_path)['model_state_dict'] + _keys = model.load_state_dict(state_dict, strict=False) + try: + assert len(_keys.unexpected_keys) == 0 + if rank == 0: + print_header('*** All expected keys matched successfully ***') + if print_model: + for k in state_dict.keys(): + print(k) + except Exception as e: + if rank == 0: + print(e) + print_header('*** Error: unexpected keys in checkpoint ***') + print('Unexpected keys:') + for k in _keys.unexpected_keys: + print(k) + if print_model and rank == 0: # Look at model + print_header('*** Model ***') + print(model) + if merge_loras: + model = model.merge_and_unload() + if print_model and rank == 0: + print_header('*** Model (after merging adapters) ***') + print(model) + if print_model and rank == 0: # Look at model + print_header('*** Trainable Parameters ***') + for n, p in model.named_parameters(): + if p.requires_grad: + print(f'├── {n} (dtype = {p.dtype})') + return model, peft_config + + +def load_and_convert_finetune(model: nn.Module, + finetune_config: dict, + checkpoint_path: str = None, + print_model: bool = False, + merge_loras: bool = False, + peft_gradient_checkpointing: bool = None, + rank: int = 0, + **peft_kwargs: any): + """ + Load trained adapter / model weights + """ + # Add low-rank adapters + peft_config = None + if finetune_config.finetune.method == 'lora': + if getattr(finetune_config.finetune, 'kwargs', None) is not None: + model, peft_config = create_peft_config( + model, finetune_config.finetune, + use_gradient_checkpointing=peft_gradient_checkpointing, + **peft_kwargs, + ) + # Keep specified weights trainable + if 'trainable_weights' in finetune_config.finetune: + for name in finetune_config.finetune['trainable_weights']: + for n, p in model.named_parameters(): + if name in n: + p.requires_grad = True + else: + for p in model.parameters(): + p.requires_grad = False + # Keep specified weights trainable + if 'trainable_weights' in finetune_config.finetune: + for name in finetune_config.finetune['trainable_weights']: + for n, p in model.named_parameters(): + if name in n: + if 'layers_to_ignore' in finetune_config.finetune: + layer = int(n.split('layers.')[-1].split('.')[0]) + if layer not in finetune_config.finetune['layers_to_ignore']: + p.requires_grad = True + else: + p.requires_grad = True + + + # Load weights + if checkpoint_path: + state_dict = torch.load(checkpoint_path)['model_state_dict'] + _keys = model.load_state_dict(state_dict, strict=False) + try: + assert len(_keys.unexpected_keys) == 0 + if rank == 0: + print_header('*** All expected keys matched successfully ***') + except Exception as e: + if rank == 0: + print(e) + print_header('*** Error: unexpected keys in checkpoint ***') + print('Unexpected keys:') + for k in _keys.unexpected_keys: + print(k) + + if print_model and rank == 0: # Look at model + print_header('*** Model 
***') + print(model) + + if merge_loras: + try: + model = model.merge_and_unload() + if print_model and rank == 0: + print_header('*** Model (after merging adapters) ***') + print(model) + except Exception as e: + print(e) + + if print_model and rank == 0: # Look at model + print_header('*** Trainable Parameters ***') + count = 0 + for n, p in model.named_parameters(): + if p.requires_grad: + print(f'├── {n}.requires_grad: {p.requires_grad}') + count += 1 + if count == 0: + print('(none)') + + return model, peft_config diff --git a/src/model/load_model_for_eval.py b/src/model/load_model_for_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..227cd955fc53b1911a852784203130c5eb1eb3ec --- /dev/null +++ b/src/model/load_model_for_eval.py @@ -0,0 +1,273 @@ +""" +Alternative way to load trained models for evaluation +""" +import copy +import sys +from os.path import join +from omegaconf import OmegaConf + +import torch + +from src.utils.logging import print_header, print_config, _format_arg +from .pretrained import get_pretrained_loader +from .peft import create_peft_config +from .load_model import load_and_convert_attns +from .convert_model import remove_base_attention, toggle_attention + +# Helpers +def get_args_from_checkpoint(fname: str): + """ + Get arguments from checkpoint filename + """ + id_to_name = { + 'lk': 'learned_kernel', + 'tqk': 'tie_qk_kernels', + 'tq': 'train_qk', + 'lzi': 'lk_zero_init', + 'lsc': 'lk_skip_connection', + 'pmnop': 'pretrained_model_name_or_path', + } + id_to_type = { + 'lk': str, + 'tqk': bool, + 'tq': bool, + 'lzi': bool, + 'lsc': bool, + 'pmnop': str, + } + args = {v: None for k, v in id_to_name.items()} + args['run_name'] = '' + + for id_val in fname.split('-'): + try: + _id, val = id_val.split('=') + if val[-len('_distill.pt'):] == '_distill.pt': # hardcode hack + val = val[:-len('_distill.pt')] + if _id in id_to_type: + _type = id_to_type[_id] + args[id_to_name[_id]] = _type(val) + except Exception: + pass + return OmegaConf.create(args) + + +def update_model_config_from_args(model_config, args): + """Override default configs""" + # Overall attention distillation + for arg in ['learned_kernel', 'tie_qk_kernels', 'train_qk']: + argval = getattr(args, arg) + if argval is not None: + setattr(model_config['attention'], arg, argval) + args.run_name += f'-{_format_arg(arg)}={argval}' + # Learned kernel + for arg in ['lk_skip_connection', 'lk_zero_init']: + argval = getattr(args, arg) + if argval is not None: + setattr(model_config['attention']['learned_kernel_kwargs'], + arg[len('lk_'):], argval) + args.run_name += f'-{_format_arg(arg)}={argval}' + # Pretrained model + if args.pretrained_model_name_or_path is not None: # if specified + pmnop = args.pretrained_model_name_or_path + model_config.model.pretrained_model_name_or_path = pmnop + args.run_name += f'-pmnop={pmnop.split("/")[-1]}' + return model_config + + + +def get_lm_eval_model(model_kwargs: dict, # model_loader.loading_kwargs + path_to_lm_eval_harness: str, # ../../lm-evaluation-harness + hedgehog_model: bool = False, + long_model: bool = False, + ): + """ + Load model for evaluation using LM Evaluation Harness + """ + lm_kwargs = copy.deepcopy(model_kwargs) + lm_kwargs['pretrained'] = lm_kwargs['pretrained_model_name_or_path'] + lm_kwargs['dtype'] = str(lm_kwargs['torch_dtype']).split('.')[-1] + del lm_kwargs['torch_dtype'] + + # lm_kwargs['use_cache'] = False + lm_kwargs['output_attentions'] = False + lm_kwargs['output_hidden_states'] = False + + print('-> Loading as 
lm-evaluation-harness model') + if hedgehog_model: + if 'mistral' in lm_kwargs['pretrained']: + from lm_eval_harness.models import LolcatsMistralForCausalLM as ModelClass + else: + from lm_eval_harness.models import LolcatsLlamaForCausalLM as ModelClass + lm = ModelClass.create_from_arg_string('', lm_kwargs) + else: + sys.path.append(path_to_lm_eval_harness) + from lm_eval.models import get_model + lm = get_model('hf-causal-experimental').create_from_arg_string('', lm_kwargs) + return lm + + +def load_model_from_config(model_config_name: str, + config_dir: str = './configs', + lm_eval_model: bool = False, + path_to_lm_eval_harness: str = '/juice2/scr2/mzhang/projects/lm-evaluation-harness', + ): + """ + Load model from a config file + """ + # Load model configs + model_config_path = join(config_dir, 'model', f'{model_config_name}.yaml') + model_config = OmegaConf.load(model_config_path) + + model_loader = get_pretrained_loader(**model_config.model) + tokenizer = model_loader.load_tokenizer() + tokenizer.pad_token_id = tokenizer.eos_token_id + tokenizer.padding_side = 'left' + + if lm_eval_model: # Instantiate as lm_eval.base.LM object + lm = get_lm_eval_model(model_loader.loading_kwargs, path_to_lm_eval_harness) + model = lm.model + else: + model = model_loader.load() + + model.eval() + if lm_eval_model: + lm.model = model + model = lm + return model, model_config, tokenizer + + +def load_model_from_checkpoint(attn_mlp_checkpoint_path: str = None, + finetune_checkpoint_path: str = None, + config_dir: str = './configs', + print_model: bool = False, + debug: bool = False, + lm_eval_model: bool = False, + path_to_lm_eval_harness: str = '/juice2/scr2/mzhang/projects/lm-evaluation-harness', + profile_model: bool = False, + ): + """ + Load model architecture from a checkpoint path + -> attn_mlp_checkpoint_path should direct to checkpoint with learned MLPs + -> finetune_checkpoint_path should direct to checkpoint with all other parameters + -> Assumes checkpoint_path stings have names for model_config and finetune_configs + """ + + # Load model configs + if attn_mlp_checkpoint_path is not None: + if len(attn_mlp_checkpoint_path.split('/')) == 4: + model_config = attn_mlp_checkpoint_path.split('/')[2] + else: + model_config = attn_mlp_checkpoint_path.split('/')[-1].split('-m=')[-1].split('-')[0] + model_config_path = join(config_dir, 'model', f'{model_config}.yaml') + model_config = OmegaConf.load(model_config_path) + args = get_args_from_checkpoint(attn_mlp_checkpoint_path.split('/')[-1]) + model_config = update_model_config_from_args(model_config, args) + else: + if len(finetune_checkpoint_path.split('/')) == 4: + model_config = finetune_checkpoint_path.split('/')[2] + else: + model_config = finetune_checkpoint_path.split('/')[-1].split('-m=')[-1].split('-')[0] + model_config_path = join(config_dir, 'model', f'{model_config}.yaml') + model_config = OmegaConf.load(model_config_path) + + if profile_model: + model_config['attention']['attention_type'] += '_profile' + + if finetune_checkpoint_path is not None: + finetune_config = finetune_checkpoint_path.split('-f=')[-1].split('-')[0] + finetune_config_path = join(config_dir, 'experiment', f'{finetune_config}.yaml') + finetune_config = OmegaConf.load(finetune_config_path) + + if debug: + print_header('-- Model Config --') + print_config(model_config) + try: + print_header('-- Finetune Config --') + print_config(finetune_config) + except NameError: + pass + + # Get base model + model_loader = get_pretrained_loader(**model_config.model) + tokenizer = 
model_loader.load_tokenizer() + tokenizer.pad_token_id = tokenizer.eos_token_id + tokenizer.padding_side = 'left' + + if lm_eval_model and attn_mlp_checkpoint_path is not None: + lm = get_lm_eval_model(model_loader.loading_kwargs, path_to_lm_eval_harness, + hedgehog_model=True) + model = lm.model # Do this way because we call the larger object + elif lm_eval_model: # Instantiate as lm_eval.base.LM object + lm = get_lm_eval_model(model_loader.loading_kwargs, path_to_lm_eval_harness) + model = lm.model + elif attn_mlp_checkpoint_path is None: + model = model_loader.load() + else: + model = model_loader.load(model_type=model_config['attention']['attention_type']) + try: + model.state_chunk_len = model_config['attention']['state_chunk_len'] + except KeyError: + pass + + if attn_mlp_checkpoint_path is not None: + # Update and load attentions + model = load_and_convert_attns(model, model_config, + checkpoint_path=attn_mlp_checkpoint_path)[0] + if 'peft' in model_config['attention']: # Merge back q and k proj + model = model.merge_and_unload() + # Already removed in load_and_convert_attns + # model = remove_base_attention(model) # , model_config.attention) + model = toggle_attention(model, False) + if debug: + print_header('*** Model after attention converion ***') + print(model) + + if finetune_checkpoint_path is not None: + # Update architecture with LoRAs + if finetune_config.finetune.method == 'lora': + model, _ = create_peft_config(model, finetune_config.finetune) + else: + for p in model.parameters(): + p.requires_grad = True + + # Load weights + state_dict = torch.load(finetune_checkpoint_path)['model_state_dict'] + _keys = model.load_state_dict(state_dict, strict=False) + try: + assert len(_keys.unexpected_keys) == 0 + print_header('*** All expected keys matched successfully ***') + except AssertionError: + print_header('*** Error: unexpected keys in checkpoint ***') + print('Unexpected keys:') + for k in _keys.unexpected_keys: + print(k) + if debug: + print_header('Missing keys:') + for k in _keys.missing_keys: + print(k) + print_header('Unexpected keys:') + for k in _keys.unexpected_keys: + print(k) + + try: + # model = model.merge_and_unload() + print('-> Training attention:', model.model.layers[0].self_attn.train_attention) + except AttributeError as e: + print('Error at:', e) + _train_attn = model.model.model.layers[0].self_attn.train_attention + print(f"But it's ok, {type(model.model.model)} has attribute 'layers'") + print('-> Training attention:', _train_attn) + + + if print_model or debug: # Look at model + print_header('*** Model ***') + print(model) + print_header('*** Trainable Parameters ***') + for n, p in model.named_parameters(): + if p.requires_grad: + print(f'├── {n}.requires_grad: {p.requires_grad}') + model.eval() + if lm_eval_model: + lm.model = model + model = lm + return model, model_config, tokenizer diff --git a/src/model/modeling_llama.py b/src/model/modeling_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..03c4789dcfc920491e7d711905dd07add4e33d6e --- /dev/null +++ b/src/model/modeling_llama.py @@ -0,0 +1,303 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Thin wrappers and replacement classes for LlamaForCausalLM +""" +from typing import Optional, Tuple, List, Union + +import warnings +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.models.llama.modeling_llama import ( + LlamaModel, LlamaForCausalLM, LLAMA_INPUTS_DOCSTRING, +) +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.cache_utils import Cache, DynamicCache +from transformers.utils import ( + add_start_docstrings_to_model_forward, logging, +) + +from .convert_model import get_attention_cache + +logger = logging.get_logger(__name__) + +# Modified from transformers.models.llama.modeling_llama.LlamaModel (v4.43) +class LolcatsLlamaModel(LlamaModel): + """ + Wrapper for Llama or Mistral-like base model + + Modified from transformers.models.llama.modeling_llama.LlamaModel + -> Only difference is using KV state for past_key_values instead of cache + """ + def __init__(self, *args: any, **kwargs: any): + super().__init__(*args, **kwargs) + self.layerwise_cpu = False + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache: + if past_key_values is None or isinstance(past_key_values, DynamicCache): # Determine and setup our KV cache or state + attention_type = getattr(self.layers[0].self_attn, 'attention_type', None) + past_key_values = get_attention_cache(attention_type) + else: + past_key_values.get_usable_length(seq_length) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class LolcatsLlamaForCausalLM(LlamaForCausalLM): + """ + Wrapper for Llama-like autoregressive language model + """ + def __init__(self, config): + # Adapt config to LlamaConfig + if getattr(config, 'attention_bias', None) is None: + config.attention_bias = False + if getattr(config, 'rope_scaling', None) is None: + config.rope_scaling = None + if getattr(config, 'pretraining_tp', None) is None: + config.pretraining_tp = 1 + super().__init__(config) + self.model = LolcatsLlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def forward(self, *args: any, labels: Optional[torch.LongTensor] = None, **kwargs: any): + # decoder outputs 
consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model(*args, **kwargs) + hidden_states = outputs[0] + # if False: # getattr(self.model.layers[0].self_attn, 'train_attention', False): + # logits = None # MZ 8/25: Sorry, was trying stuff + # regular training + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) + for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + return CausalLMOutputWithPast( + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class LooooolcatsLlamaForCausalLM(LolcatsLlamaForCausalLM): + """ + Wrapper for Llama or Mistral-like autoregressive language model + -> Experimental / WIP; but goal is to combine chunked linear attention during training + to process long contexts with minimally-growing memory usage + """ + def chunk_forward(self, *args: any, **kwargs: any): + """Call this when training / processing one chunk""" + return super().forward(*args, **kwargs) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, # Ignored for now, new Transformers >4.36 + ) -> Union[Tuple, CausalLMOutputWithPast]: + """ + Forward pass where we chunk inputs + """ + self.generating = False + if use_cache is not True: + use_cache = True + + if attention_mask is not None and use_cache: + warnings.warn( + "Sorry padding currently not supported. Setting attention_mask to None (will still be causal)." 
+ ) + attention_mask = None + + if past_key_values is None: + # Determine and setup our KV cache or state + attention_type = getattr(self.model.layers[0].self_attn, 'attention_type', None) + past_key_values = get_attention_cache(attention_type) + + if input_ids.shape[-1] == 1: # Heuristic to detect generating + return super().forward(input_ids, attention_mask, position_ids, + past_key_values, inputs_embeds, labels, + use_cache, output_attentions, output_hidden_states, + return_dict) + else: + assert self.training is False # To train this way, use training loop to chunk + if self.generating: # Heuristic to detect new sample + self.generating = False + # Determine and setup our KV cache or state + attention_type = getattr(self.model.layers[0].self_attn, 'attention_type', None) + past_key_values = get_attention_cache(attention_type) + + # Split inputs into chunks, and do linear attention over each (passing the states) + input_ids = torch.split(input_ids, self.state_chunk_len, dim=-1) + if position_ids is not None: + position_ids = torch.split(position_ids, self.state_chunk_len, dim=-1) + + all_logits = [] # save these + for _idx, _input_ids in enumerate(input_ids): + if self.training: + print(f'Model processing _input_ids.shape:', _input_ids.shape) + outputs = super().forward(_input_ids, None, + position_ids[_idx] if position_ids is not None else None, + past_key_values, inputs_embeds, + labels=None, + use_cache=True, + output_attentions=False, + output_hidden_states=False, + return_dict=True,) + past_key_values = outputs.past_key_values + all_logits.append(outputs.logits) + + if _idx == len(input_ids) - 1: + self.generating = True # time to generate; if no generation will reset + + return CausalLMOutputWithPast( + # loss=loss, + logits=torch.cat(all_logits, dim=-2), # b, l, d + past_key_values=past_key_values, + ) diff --git a/src/model/modeling_llama_sharded.py b/src/model/modeling_llama_sharded.py new file mode 100644 index 0000000000000000000000000000000000000000..45fe04f89a5f0a5b150829cd37dbd9d59d6d726d --- /dev/null +++ b/src/model/modeling_llama_sharded.py @@ -0,0 +1,244 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
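# A minimal sketch (illustrative, not part of this diff) of the chunked-forward pattern
# used by LooooolcatsLlamaForCausalLM above: split a long input into fixed-size chunks
# and thread the recurrent attention state through them so memory stays bounded.
# `chunked_logits` and `chunk_len` are hypothetical names, not repo APIs.
import torch

def chunked_logits(model, input_ids: torch.Tensor, chunk_len: int) -> torch.Tensor:
    """Run a stateful causal LM over `input_ids` chunk by chunk, concatenating logits."""
    past_key_values = None          # the model builds its linear-attention state on the first call
    all_logits = []
    for chunk in torch.split(input_ids, chunk_len, dim=-1):
        outputs = model(chunk,
                        past_key_values=past_key_values,
                        use_cache=True,
                        return_dict=True)
        past_key_values = outputs.past_key_values   # carries kv_state / k_state forward
        all_logits.append(outputs.logits)
    return torch.cat(all_logits, dim=-2)            # (batch, total_seq_len, vocab_size)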
+""" +Thin wrappers and replacement classes for LlamaForCausalLM +- Simple sharding across multiple GPUs; will be slow but good for quality evals +- May need to update for Llama 405B +""" +from typing import Optional, Tuple, List, Union + +import warnings +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.models.llama.modeling_llama import ( + LlamaModel, LlamaForCausalLM, LLAMA_INPUTS_DOCSTRING, +) +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.cache_utils import Cache, DynamicCache +from transformers.utils import ( + add_start_docstrings_to_model_forward, logging, +) + +from .convert_model import get_attention_cache + +logger = logging.get_logger(__name__) + +# Modified from transformers.models.llama.modeling_llama.LlamaModel (v4.43) +class ShardedLolcatsLlamaModel(LlamaModel): + """ + Wrapper for Llama or Mistral-like base model + + Modified from transformers.models.llama.modeling_llama.LlamaModel + -> Only difference is using KV state for past_key_values instead of cache + """ + def __init__(self, *args: any, **kwargs: any): + super().__init__(*args, **kwargs) + self.layerwise_cpu = False + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache: + if past_key_values is None or isinstance(past_key_values, DynamicCache): # Determine and setup our KV cache or state + attention_type = getattr(self.layers[0].self_attn, 'attention_type', None) + past_key_values = get_attention_cache(attention_type, past_key_values) + else: + past_key_values.get_usable_length(seq_length) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + # - ignored for linearized models + position_embeddings = None + # position_embeddings = self.rotary_emb(hidden_states, position_ids.to(hidden_states.device)) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + # Move output to right device + device = decoder_layer.self_attn.q_proj.weight.device + hidden_states = hidden_states.to(device) + position_ids = position_ids.to(device) + if attention_mask is not None: + attention_mask = attention_mask.to(device) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + + if getattr(decoder_layer.self_attn, 'converted', False): + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + else: + with torch.no_grad(): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states.to(self.norm.weight.device)) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class 
ShardedLolcatsLlamaForCausalLM(LlamaForCausalLM): + """ + Wrapper for Llama-like autoregressive language model + """ + def __init__(self, config): + # Adapt config to LlamaConfig + if getattr(config, 'attention_bias', None) is None: + config.attention_bias = False + if getattr(config, 'rope_scaling', None) is None: + config.rope_scaling = None + if getattr(config, 'pretraining_tp', None) is None: + config.pretraining_tp = 1 + super().__init__(config) + self.model = ShardedLolcatsLlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def forward(self, *args: any, labels: Optional[torch.LongTensor] = None, **kwargs: any): + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model(*args, **kwargs) + hidden_states = outputs[0] + if getattr(self.model.layers[0].self_attn, 'train_attention', False): + logits = None + else: # regular training + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) + for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + return CausalLMOutputWithPast( + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/model/modeling_mistral.py b/src/model/modeling_mistral.py new file mode 100644 index 0000000000000000000000000000000000000000..3b5fd02a39ae267b0ed1a6c5c3476607bc961535 --- /dev/null +++ b/src/model/modeling_mistral.py @@ -0,0 +1,161 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
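# A minimal sketch (illustrative, not the repo's code) of the per-layer device dispatch
# that ShardedLolcatsLlamaModel above relies on: each decoder layer stays on the GPU it
# was sharded to, and activations are moved to that device just before the layer runs.
import torch
import torch.nn as nn

def run_sharded_layers(layers: nn.ModuleList, hidden_states: torch.Tensor) -> torch.Tensor:
    """Naive pipeline-style sharding: slow (one GPU active at a time) but memory-friendly."""
    for layer in layers:
        device = next(layer.parameters()).device     # wherever this layer's weights live
        hidden_states = layer(hidden_states.to(device))
    return hidden_states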
+""" +Thin wrappers and replacement classes for MistralForCausalLM +""" +from typing import Optional, Tuple, List, Union + +import warnings +import torch +import torch.nn as nn +from transformers import MistralModel, MistralForCausalLM +from transformers.modeling_outputs import CausalLMOutputWithPast + +from .modeling_llama import LolcatsLlamaModel +from .convert_model import get_attention_cache + + +# Modified from transformers.models.llama.modeling_llama.LlamaModel +class LolcatsMistralModel(LolcatsLlamaModel, MistralModel): + """ + Wrapper for Mistral-like autoregressive language model + """ + def forward(self, *args, **kwargs): + return super().forward(*args, **kwargs) + + +class LolcatsMistralForCausalLM(MistralForCausalLM): + """ + Wrapper for Llama or Mistral-like autoregressive language model + """ + def __init__(self, config): + # Adapt config to LlamaConfig + if getattr(config, 'attention_bias', None) is None: + config.attention_bias = False + if getattr(config, 'rope_scaling', None) is None: + config.rope_scaling = None + if getattr(config, 'pretraining_tp', None) is None: + config.pretraining_tp = 1 + if getattr(config, 'pretraining_tp', None) is None: + config.pretraining_tp = 1 + if getattr(config, 'mlp_bias', None) is None: + config.mlp_bias = False + super().__init__(config) + self.model = LolcatsMistralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + +class LooooolcatsMistralForCausalLM(LolcatsMistralForCausalLM): + """ + Wrapper for Llama or Mistral-like autoregressive language model + -> Experimental / WIP; but goal is to combine chunked linear attention during training + to process long contexts with minimally-growing memory usage + """ + def chunk_forward(self, *args: any, **kwargs: any): + """Call this when training / processing one chunk""" + return super().forward(*args, **kwargs) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, # Ignored for now, new Transformers >4.36 + ) -> Union[Tuple, CausalLMOutputWithPast]: + """ + Forward pass where we chunk inputs + """ + self.generating = False + if use_cache is not True: + use_cache = True + + if attention_mask is not None and use_cache: + warnings.warn( + f"Sorry padding currently not supported. Setting attention_mask to None (will still be causal)." 
+ ) + attention_mask = None + + if past_key_values is None: + # Determine and setup our KV cache or state + attention_type = getattr(self.model.layers[0].self_attn, 'attention_type', None) + past_key_values = get_attention_cache(attention_type) + # past_key_values = LinearAttentionState() + + if input_ids.shape[-1] == 1 and not self.training: # Heuristic to detect generating + return super().forward(input_ids, attention_mask, position_ids, + past_key_values, inputs_embeds, labels, + use_cache, output_attentions, output_hidden_states, + return_dict) + else: + if self.generating: # Heuristic to detect new sample + self.generating = False + # Determine and setup our KV cache or state + attention_type = getattr(self.model.layers[0].self_attn, 'attention_type', None) + past_key_values = get_attention_cache(attention_type) + print(f'-> attention_type:', attention_type) + + # Make it so we keep track of gradients in kv_state computation + for idx in range(len(self.model.layers)): + self.model.layers[idx].self_attn.state_grad_enabled = self.training + + # Split inputs into chunks, and do linear attention over each (passing the states) + input_ids = torch.split(input_ids, self.state_chunk_len, dim=-1) + if position_ids is not None: + position_ids = torch.split(position_ids, self.state_chunk_len, dim=-1) + + all_logits = [] # save these + for _idx, _input_ids in enumerate(input_ids): + outputs = super().forward(_input_ids, None, + position_ids[_idx] if position_ids is not None else None, + past_key_values, inputs_embeds, + labels=None, + use_cache=True, + output_attentions=False, + output_hidden_states=False, + return_dict=True,) + past_key_values = outputs.past_key_values + all_logits.append(outputs.logits) + + # Comment in / adjust to do gradient accumulation over chunks + # if self.training: + # loss = outputs.loss + # loss.backward() # accumulate gradients over chunks + # else: + # del outputs.loss + + if _idx == len(input_ids) - 1: + self.generating = True # time to generate; if no generation will reset + + return CausalLMOutputWithPast( + # loss=loss, + logits=torch.cat(all_logits, dim=-2), # b, l, d + past_key_values=past_key_values, + ) \ No newline at end of file diff --git a/src/model/peft.py b/src/model/peft.py new file mode 100644 index 0000000000000000000000000000000000000000..7022b4694cef08d0e28b58e334f3f5db95eab649 --- /dev/null +++ b/src/model/peft.py @@ -0,0 +1,96 @@ +""" +Helpers for parameter-efficient finetuning via low-rank adapters (LoRA) +-> Mainly follow PEFT / llama recipes + +Right now quantization not super tested +""" +import torch +from torch.nn import Module + + +# Modified from https://github.com/facebookresearch/llama-recipes/blob/main/examples/quickstart.ipynb +def create_peft_config(model: Module, + peft_config: dict, + target_dtype: str = 'bfloat16', + preserve_requires_grad: bool = False, + use_gradient_checkpointing: bool = None, + add_self_attn_prefix: bool = True): + """ + Create a parameter-efficient finetuning model (e.g., attaching LoRAs) + -> Assumes that all non-trainable weights have been frozen already. + If not, freeze them before calling this function. 
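    Example (illustrative values):
        peft_config = {'method': 'lora',
                       'kwargs': {'r': 8, 'lora_alpha': 16, 'lora_dropout': 0.0,
                                  'target_modules': ['q_proj', 'k_proj', 'v_proj', 'o_proj']}}
        model, lora_config = create_peft_config(model, peft_config, target_dtype='bfloat16')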
+ """ + if peft_config['method'] == 'lora': + from peft import ( + get_peft_model, + LoraConfig, + TaskType, + prepare_model_for_kbit_training, + ) + try: + target_modules = [] # hack to only do self_attn terms + for module_name in peft_config['kwargs']['target_modules']: + if ('_proj' in module_name and 'self_attn' not in module_name + and add_self_attn_prefix): + target_modules.append(f'self_attn.{module_name}') + elif '_proj' in module_name: + target_modules.append(module_name) + peft_config['kwargs']['target_modules'] = target_modules + except Exception as e: + print(e) + target_modules = [] + + if 'layers_to_ignore' in peft_config: + peft_config['kwargs']['layers_to_transform'] = [ + i for i in range(len(model.model.layers)) + if i not in peft_config['layers_to_ignore'] + ] + + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + **peft_config['kwargs'], + ) + # Save parameters that did not have frozen weights before to unfreeze later + trainable_weights = [ + n for n, p in model.named_parameters() if p.requires_grad + ] + # Prepare int-8 or int-4 model for training + loaded_in_kbit = (getattr(model, "is_loaded_in_8bit", False) or + getattr(model, "is_loaded_in_4bit", False)) + if loaded_in_kbit: # From https://huggingface.co/docs/peft/en/package_reference/peft_model: + # This method wraps the entire protocol for preparing a model before running a training. + # 1- Cast the layernorm in fp32 + # 2- making output embedding layer require grads + # 3- Add the upcasting of the lm head to fp32 + model.enable_input_require_grads() + ugc = (use_gradient_checkpointing + if use_gradient_checkpointing is not None else True) + print('-> use_gradient_checkpointing:', ugc) + # model.gradient_checkpointing_enable() + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=ugc, + gradient_checkpointing_kwargs={'use_reentrant': False}, + ) + + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + + for n, p in model.named_parameters(): + # Unfreeze weights frozen by get_peft_model() + if preserve_requires_grad: + if n[len('base_model.model.'):] in trainable_weights: + p.requires_grad = True + + # prepare_model_for_kbit_training will cast all non INT8 parameters to fp32 + # -> https://github.com/huggingface/peft/blob/7e84dec20b3106bdd0a90ba8e80187f0aec835b7/src/peft/utils/other.py#L103 + # So we'll cast these back to their prior dtype + if p.requires_grad and loaded_in_kbit: + p.data = p.data.to(getattr(torch, target_dtype)) + + if not loaded_in_kbit: + model.to(dtype=getattr(torch, target_dtype)) + + return model, peft_config + else: + raise NotImplementedError(f"Sorry PEFT method {peft_config['method']} not implemented yet.") diff --git a/src/model/pretrained.py b/src/model/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..dc0ebe3c36682bbb7aa79a0ff41dd455c36ddafc --- /dev/null +++ b/src/model/pretrained.py @@ -0,0 +1,205 @@ +""" +Classes for loading pretrained models +""" +from os.path import join +from omegaconf import OmegaConf + +import torch +import torch.nn as nn + +import transformers +from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer +# from transformers import BitsAndBytesConfig +from peft import prepare_model_for_kbit_training + + +def get_pretrained_loader(pretrained_model_name_or_path: str, + huggingface_token: str = None, + **model_kwargs: any): + """ + Return the appropriate loader for the pretrained model + """ + + if 'lama' in 
pretrained_model_name_or_path: # Llama or llama + return PretrainedLlamaLoader( + pretrained_model_name_or_path=pretrained_model_name_or_path, + huggingface_token=huggingface_token, + **model_kwargs, + ) + elif 'istral' in pretrained_model_name_or_path: # Mistral or mistral; + return PretrainedMistralLoader( + pretrained_model_name_or_path=pretrained_model_name_or_path, + huggingface_token=huggingface_token, + **model_kwargs, + ) + else: + print(f'-> {pretrained_model_name_or_path} using default pretrained model loader') + return PretrainedModelLoader( + pretrained_model_name_or_path=pretrained_model_name_or_path, + huggingface_token=huggingface_token, + **model_kwargs, + ) + + +class PretrainedModelLoader(): + """ + Class for loading a pretrained model. + Example: + model_loader = PretrainedModelLoader(**model_kwargs) + model = model_loader.load() + """ + def __init__(self, + pretrained_model_name_or_path: str, + cache_dir: str = None, + return_dict: bool = True, # False + device_map: str = 'auto', + low_cpu_mem_usage: bool = True, + torch_dtype: str = 'bfloat16', + rope_theta: float = 10000., + attn_implementation: str = 'sdpa', # eager + load_in_8bit: bool = False, + load_in_4bit: bool = False, + huggingface_token: str = None, + peft_id: str = None, + rope_scaling: dict = None, + **other_kwargs: any) -> None: + + print(f'-> Using {attn_implementation} attention') + + self.loading_kwargs = { + 'pretrained_model_name_or_path': pretrained_model_name_or_path, + 'cache_dir': cache_dir, + 'return_dict': return_dict, + 'load_in_8bit': load_in_8bit, + 'load_in_4bit': load_in_4bit, + 'device_map': device_map, + 'low_cpu_mem_usage': low_cpu_mem_usage, + 'torch_dtype': getattr(torch, torch_dtype), + 'rope_theta': rope_theta, + 'attn_implementation': attn_implementation, + } + if rope_scaling is not None: # Llama 3.1 patch + rope_scaling = OmegaConf.to_container(rope_scaling) + self.loading_kwargs['rope_scaling'] = rope_scaling + for k, v in other_kwargs.items(): + self.loading_kwargs[k] = v + + self.quantization = load_in_8bit or load_in_4bit + self.peft_id = peft_id + self.gradient_checkpointing = False + if huggingface_token is not None: # for gated models, e.g., Llama 3 + self.loading_kwargs['token'] = huggingface_token + + if self.quantization: + raise NotImplementedError + # bnb_config = BitsAndBytesConfig( + # load_in_8bit=load_in_8bit, + # load_in_4bit=load_in_4bit, + # bnb_4bit_compute_dtype=torch.bfloat16, + # bnb_4bit_use_double_quant=True, + # bnb_4bit_quant_type="nf4", + # ) + # del self.loading_kwargs['load_in_8bit'] + # del self.loading_kwargs['load_in_4bit'] + # self.loading_kwargs['quantization_config'] = bnb_config + + def load(self) -> nn.Module: + """ + Load pretrained model + """ + model = AutoModelForCausalLM.from_pretrained(**self.loading_kwargs) + if self.quantization: + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=self.gradient_checkpointing, + gradient_checkpointing_kwargs={'use_reentrant': False}, + ) + return model + + def load_tokenizer(self): + """ + Load pretrained tokenizer + """ + try: + return AutoTokenizer.from_pretrained(**self.loading_kwargs) + except Exception as e: + print("-> Error with `AutoTokenizer.from_pretrained(**self.loading_kwargs)`:", e) + print("-> Trying `LlamaTokenizer.from_pretrained(**self.loading_kwargs)`") + # MZ 6/1: Mistral-7B-Instruct-v0.3 in Transformers v4.36 doesn't work with the above + return LlamaTokenizer.from_pretrained(**self.loading_kwargs) + + +class PretrainedLlamaLoader(PretrainedModelLoader): + 
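+    """
+    Llama-family loader; `load(model_type=...)` selects between the stock LlamaForCausalLM,
+    the lolcats linear-attention classes in modeling_llama*.py, or AutoModelForCausalLM,
+    and optionally merges a PEFT checkpoint (`peft_id`) into the base weights before returning.
+    """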
def load(self, model_type: str = None, ): + llama3_1 = float('.'.join(transformers.__version__.split('.')[:2])) > 4.42 # 'Meta-Llama-3.1' in self.loading_kwargs['pretrained_model_name_or_path'] + if model_type is None: + from transformers import LlamaForCausalLM as model_class + + elif 'lolcats_llama_sharded' in model_type: + from .modeling_llama_sharded import ShardedLolcatsLlamaForCausalLM as model_class + + elif 'lolcats_long_llama' in model_type: + from .modeling_llama import LooooolcatsLlamaForCausalLM as model_class + + elif 'lolcats_llama' in model_type: + from .modeling_llama import LolcatsLlamaForCausalLM as model_class + + else: + if model_type == 'flash_attention_2': + self.loading_kwargs['attn_implementation'] = model_type + from transformers import AutoModelForCausalLM as model_class + print('-> Loading from AutoModelForCausalLM') + + model = model_class.from_pretrained(**self.loading_kwargs) + if self.peft_id is not None: + from peft import PeftModel + print('-> Loading PEFT checkpoint') + model = PeftModel.from_pretrained( + model, + self.peft_id, + torch_dtype=self.loading_kwargs['torch_dtype'], + device_map='auto', + cache_dir=self.loading_kwargs['cache_dir'] + ).merge_and_unload() + + if self.quantization: + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=self.gradient_checkpointing, + gradient_checkpointing_kwargs={'use_reentrant': False}, + ) + return model + + def load_tokenizer(self): + return AutoTokenizer.from_pretrained(**self.loading_kwargs) + + +class PretrainedMistralLoader(PretrainedModelLoader): + def load(self, model_type: str = None): + if model_type is None: + from transformers import MistralForCausalLM as model_class + elif 'lolcats_long_llama' in model_type: + from .modeling_mistral import LooooolcatsMistralForCausalLM as model_class + elif 'lolcats_llama' in model_type: + from .modeling_mistral import LolcatsMistralForCausalLM as model_class + else: + if model_type == 'flash_attention_2': + self.loading_kwargs['attn_implementation'] = model_type + from transformers import AutoModelForCausalLM as model_class + print('-> Loading from AutoModelForCausalLM') + + model = model_class.from_pretrained(**self.loading_kwargs) + if self.peft_id is not None: + from peft import PeftModel + model = PeftModel.from_pretrained( + model, + self.peft_id, + torch_dtype=self.loading_kwargs['torch_dtype'], + device_map='auto', + cache_dir=self.loading_kwargs['cache_dir'], + ).merge_and_unload() + + if self.quantization: + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=self.gradient_checkpointing, + gradient_checkpointing_kwargs={'use_reentrant': False}, + ) + return model diff --git a/src/model/rotary.py b/src/model/rotary.py new file mode 100644 index 0000000000000000000000000000000000000000..58b79044212ab4dc4e0419ab83098df7e269e136 --- /dev/null +++ b/src/model/rotary.py @@ -0,0 +1,172 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Rotary embeddings. Same as usual for Transformer models. + +Note these are modified from HF Transformers v4.36, from: +- transformers/models/llama/modeling_llama.py or transformers/models/mistral/modeling_mistral.py +- i.e., https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L123 +""" +import torch +import torch.nn as nn + + +def get_rotary_embeddings(rope_scaling_type: str = None, + head_dim: int = 128, + max_position_embeddings: int = 4096, + rope_theta: float = 10000.0, + rope_scaling_factor: float = 1.0, + device: torch.device = None, + ) -> nn.Module: + """Return rotary embedding object""" + if rope_scaling_type is None: + return RotaryEmbedding( + head_dim, + max_position_embeddings=max_position_embeddings, + base=rope_theta, + device=device, + ) + elif rope_scaling_type == "linear": + return LinearScalingRotaryEmbedding( + head_dim, + max_position_embeddings=max_position_embeddings, + scaling_factor=rope_scaling_factor, + base=rope_theta, + device=device, + ) + elif rope_scaling_type == "dynamic": + return DynamicNTKScalingRotaryEmbedding( + head_dim, + max_position_embeddings=max_position_embeddings, + scaling_factor=rope_scaling_factor, + base=rope_theta, + device=device, + ) + else: + raise NotImplementedError(f'Sorry rope_scaling_type == "{rope_scaling_type}" not implemented.') + + +# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36) +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors.""" + if position_ids is not None: + cos, sin = cos[position_ids], sin[position_ids] + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Modified from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36) +class RotaryEmbedding(nn.Module): + """Original Rotary Embeddings from RoFormer https://arxiv.org/abs/2104.09864""" + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
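+        # Cos/sin caches cover max_position_embeddings up front; forward() regrows them if a longer seq_len is seen.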
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + """ + Compute rotary embeddings + """ + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers/models/llama/modeling_llama.py at v4.36 +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers/models/llama/modeling_llama.py at v4.36 +class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + diff --git a/src/model/utils.py b/src/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cde449f3aa1baedd6aa0547862fe1bcd7a1ac262 --- /dev/null +++ b/src/model/utils.py @@ -0,0 +1,15 @@ +import numpy as np + + +def count_parameters(model, requires_grad: bool = True): + """ + Return total # of trainable parameters + """ + if requires_grad: + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + else: + model_parameters = model.parameters() + try: + return sum([np.prod(p.size()) for p in model_parameters]).item() + except: + return sum([np.prod(p.size()) for p in model_parameters]) diff --git a/src/trainer/__init__.py b/src/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3b5939570fa21db5ab46ead823f1dd7c5bf9ec88 --- /dev/null +++ b/src/trainer/__init__.py @@ -0,0 +1,15 @@ +import importlib +from .optim import get_optimizer, get_scheduler + + +def get_trainer(name: str): + """ + Return our trainer class + """ + try: + module = importlib.import_module(f'src.trainer.{name}') + except ModuleNotFoundError as e: + print(e) + print('-> Using default trainer') + module = importlib.import_module('src.trainer.default') + return getattr(module, 'OurTrainer') diff --git a/src/trainer/default_lm.py b/src/trainer/default_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..89ecb0cd6143599e6afdd372e9047cf1a6f09ec1 --- /dev/null +++ b/src/trainer/default_lm.py @@ -0,0 +1,406 @@ +""" +Default trainer class for training models +""" +from collections import OrderedDict +from os.path import join +from argparse import ArgumentParser +from tqdm import tqdm + +import pandas as pd + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler + +from .optim import get_optimizer, get_scheduler +from .utils import decode_samples + + +class OurTrainer(): + """ + Basic parent trainer class. Defaults to language modeling. 
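+    Child trainers (e.g., the attention-distillation trainers) override compute_loss().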
+ -> Replacement for Hugging Face Trainer + """ + def __init__(self, + model: nn.Module, + args: ArgumentParser, + train_loader: DataLoader, + eval_loader: DataLoader, + optimizer_and_scheduler: tuple[Optimizer, LRScheduler], + device: torch.device, + wandb, # WandB object + checkpoint_suffix: str = None, + save_checkpoints: bool = True, + save_results: bool = True, + # Custom arguments + optimizer_args: dict = None, + lr_scheduler_args: dict = None, + greater_is_better: bool = False, + metric_for_best_model: str = 'eval/loss', + num_train_epochs: int = 2, + gradient_accumulation_steps: int = 1, + evaluation_strategy: str = 'steps', + load_best_model_at_end: bool = True, + logging_steps: int = 100, + max_steps: int = -1, + eval_steps: int = 100, + max_eval_batches: int = -1, + print_samples: bool = False, + initial_eval: bool = True, + num_save_ckpt_steps: int = 1000, + **kwargs: any): + super().__init__() + self.model = model + self.step = 0 # Total steps taken + self.grad_step = 0 # Total gradient updates + self.compute_loss_backprop = False # Whether we backprop in self.compute_loss + + if optimizer_and_scheduler is None: + assert optimizer_args is not None and lr_scheduler_args is not None + self.optimizer = get_optimizer(model=self.model, **optimizer_args) + self.scheduler = get_scheduler(optimizer=self.optimizer, **lr_scheduler_args) + else: + self.optimizer, self.scheduler = optimizer_and_scheduler + try: + self.scheduler_step_after_epoch = 'plateau' in args.lr_scheduler['lr_scheduler_type'] + except KeyError: + self.scheduler_step_after_epoch = False + + # Dataloaders + self.train_loader = train_loader + self.eval_loader = eval_loader + + self.device = device + self.wandb = wandb + + # Custom arguments + self.metric_for_best_model = metric_for_best_model + self.num_train_epochs = num_train_epochs + self.gradient_accumulation_steps = gradient_accumulation_steps + self.evaluation_strategy = evaluation_strategy + self.greater_is_better = greater_is_better + self.is_better = (lambda x, y: x > y if greater_is_better else x < y) + self.load_best_model_at_end = load_best_model_at_end + self.logging_steps = logging_steps + self.max_steps = max_steps + self.eval_steps = eval_steps + self.max_eval_batches = max_eval_batches + self.print_samples = print_samples + self.initial_eval = initial_eval + self.num_save_ckpt_steps = num_save_ckpt_steps + + # Saving metrics + self.train_metrics = {'train/loss': None, + 'train/epoch': None, + 'train/step': None} + self.eval_metrics = {metric_for_best_model: None} + self.eval_metrics_by_step = {'eval_step': []} # save all eval metrics + self.criterion = nn.CrossEntropyLoss(reduction='mean') + try: + self.tokenizer = self.train_loader.dataset.tokenizer + except AttributeError: + self.tokenizer = None + + self.save_results = save_results + self.results_path = None + self.best_val_metric = 0 if greater_is_better else 1e10 + self.best_val_metric_epoch = 0 + self.best_val_metric_step = 0 + if save_checkpoints: # Also initializes best_val_metrics + self.init_checkpointing(args=args, checkpoint_suffix=checkpoint_suffix) + + def train(self) -> nn.Module: + """ + Entire training run + """ + model = self.model + pbar = tqdm(range(self.num_train_epochs), leave=False, colour='white', + desc='Training') + for ix, epoch in enumerate(pbar): + model, early_stopping = self.train_step(model, epoch) + if self.evaluation_strategy == 'epoch': + _eval_metrics = self.eval_step(model, step=self.grad_step) + print(f'Epoch {ix} metrics:', _eval_metrics) + if early_stopping: + 
break + + if self.load_best_model_at_end: # Return best checkpoint + try: + state_dict = torch.load(self.best_val_checkpoint_path)['model_state_dict'] + model.load_state_dict(state_dict, strict=False) + print(f'-> Loading best checkpoint from {self.best_val_checkpoint_path}') + except FileNotFoundError as e: + print(e) + print('-> Returning most recent model instead') + return model + + def train_step(self, model: nn.Module, epoch: int) -> nn.Module: + """ + Training loop over one epoch + """ + if self.gradient_accumulation_steps is None: + accum_iter = 1 + else: + accum_iter = self.gradient_accumulation_steps + + model.train() + model.zero_grad() + pbar = tqdm(self.train_loader, leave=False, colour='blue', + desc=f'-> Training (epoch {epoch} / {self.num_train_epochs})') + total_loss = 0 + eval_for_step = False + + # Initial eval + if self.initial_eval: + print('') + print('-> Initial eval') + self.compute_eval_metrics(model, step=self.grad_step) + + # model.to(self.device) + for ix, data in enumerate(pbar): + loss, train_metrics = self.compute_loss(model, data, + sample_idx=ix) + loss /= accum_iter + if not self.compute_loss_backprop: + # loss.backward() did not occur in compute_loss + try: + with torch.autograd.set_detect_anomaly(True): + loss.backward() + except Exception as e: + breakpoint() + if (self.step + 1) % accum_iter == 0: # and self.step != 0: + self.optimizer.step() + if not self.scheduler_step_after_epoch and self.scheduler is not None: + self.scheduler.step() + self.optimizer.zero_grad() + self.grad_step += 1 + if not self.compute_loss_backprop: + loss = loss.detach().cpu().item() + + self.step += 1 + if not isinstance(loss, float): + total_loss += loss.item() + else: + total_loss += loss + desc = f"Training epoch {epoch} | loss: {total_loss / (ix + 1):.3f} | lr: {self.optimizer.param_groups[0]['lr']:.5f}" + desc += f' | gradient step: {self.grad_step}' + for k, v in train_metrics.items(): + desc += f' | {k}: {v:.3f}' + pbar.set_description(desc) + + # Logging + if (self.grad_step) % (self.logging_steps): + self.train_metrics['train/loss'] = loss.item() if not isinstance(loss, float) else loss + self.train_metrics['train/epoch'] = epoch + self.train_metrics['train/step'] = self.grad_step + self.train_metrics['train/lr'] = self.optimizer.param_groups[0]['lr'] + for k, v in train_metrics.items(): + self.train_metrics[f'train/{k}'] = v + + if self.wandb is not None: + self.wandb.log(self.train_metrics, step=self.grad_step) + + if self.evaluation_strategy == 'steps': + if (self.grad_step % self.eval_steps == 0 and self.grad_step > 0 and not eval_for_step): + _eval_metrics = self.eval_step(model, step=self.grad_step) + print(f'Grad Step {self.grad_step} eval metrics:', _eval_metrics) + eval_for_step = True + model.train() # Need to set back to train mode + elif self.grad_step == 0 and self.num_save_ckpt_steps < 1000 and not eval_for_step: # hack for micros + _eval_metrics = self.eval_step(model, step=self.grad_step) + print(f'Grad Step {self.grad_step} eval metrics:', _eval_metrics) + eval_for_step = True + model.train() # Need to set back to train mode + + elif self.grad_step % self.eval_steps == 0 and self.grad_step > 0 and eval_for_step: + pass + else: + if self.grad_step > 0: + eval_for_step = False + if self.grad_step == self.max_steps: + early_stopping = True + return model, early_stopping + + early_stopping = False + return model, early_stopping + + def eval_step(self, model: nn.Module, step: int = None, + **kwargs: any) -> dict[any]: + """ + Evaluation loop over one 
epoch + """ + with torch.no_grad(): + self.eval_metrics = self.compute_eval_metrics(model, step=step, **kwargs) + val_metric = self.eval_metrics[self.metric_for_best_model] + + # Save results + if self.wandb is not None: # log to WandB + self.wandb.log(self.eval_metrics, step=self.grad_step) + + if self.results_path is not None: # log to local file + self.eval_metrics_by_step['eval_step'].append(step) + for k, v in self.eval_metrics.items(): + if k not in self.eval_metrics_by_step: + self.eval_metrics_by_step[k] = [v] + else: + self.eval_metrics_by_step[k].append(v) + # Inefficient, but log for experiments results + pd.DataFrame(self.eval_metrics_by_step).to_csv(self.results_path) + + # Save best metric and checkpoint + if self.grad_step % self.eval_steps == 0: + if self.is_better(val_metric, self.best_val_metric): + self.best_val_metric = val_metric + self.best_val_metric_step = self.grad_step + # model.cpu() + torch.save({ + 'model_state_dict': self.save_trainable_weights(model), + 'step': self.grad_step, + self.metric_for_best_model: val_metric + }, self.best_val_checkpoint_path) + print(f'\n-> Saved best model checkpoint to: {self.best_val_checkpoint_path}!') + + if self.grad_step % self.num_save_ckpt_steps == 0: + save_path = self.best_val_checkpoint_path.replace('.pt', f'_{self.grad_step}.pt') + torch.save({ + 'model_state_dict': self.save_trainable_weights(model), + 'step': self.grad_step, + self.metric_for_best_model: val_metric + }, save_path) + print(f'\n-> Saved best model checkpoint to: {save_path}!') + + if self.scheduler_step_after_epoch and self.scheduler is not None: + self.scheduler.step(val_metric) + return self.eval_metrics + + def compute_eval_metrics(self, + model: nn.Module, step: int, + max_batches: int = None, + dataloader: DataLoader = None, + **kwargs: any) -> dict[any]: + """ + One evaluation loop over a validation dataset + """ + max_batches = (self.max_eval_batches if max_batches is None else max_batches) + dataloader = self.eval_loader if dataloader is None else dataloader + pbar = tqdm(dataloader, leave=False, colour='green', + desc=f'Evaluating at step {step}') + + model.eval() + step_loss = 0 + step_eval_metrics = {} + with torch.no_grad(): + for ix, data in enumerate(pbar): + loss, eval_metrics = self.compute_loss(model, data) + if not self.compute_loss_backprop: + loss = loss.item() # otherwise already float + if ix == 0: + step_eval_metrics[self.metric_for_best_model] = [loss] + for k, v in eval_metrics.items(): + step_eval_metrics[f'eval/{k}'] = [v] + else: + step_eval_metrics[self.metric_for_best_model].append(loss) + for k, v in eval_metrics.items(): + step_eval_metrics[f'eval/{k}'].append(v) + + step_loss += loss + desc = f"Evaluating at step {step} | loss: {step_loss / (ix + 1):.3f}" + if self.optimizer is not None: + desc += f" | lr: {self.optimizer.param_groups[0]['lr']:.5f}" + pbar.set_description(desc) + if ix == max_batches: + break + + # Average over batches + for k, v in step_eval_metrics.items(): + step_eval_metrics[k] = sum(v) / len(v) + print(f'Eval step {step}:', step_eval_metrics) + del loss + torch.cuda.empty_cache() + return step_eval_metrics + + def compute_loss(self, model: nn.Module, data: torch.Tensor, + sample_idx: int = None, **kwargs: any, + ) -> tuple[torch.Tensor, dict[any]]: + """ + Main method to determine how models are trained. 
+ -> Defaults to next-token prediction / classification, + but override in child classes + + Args: + - model: nn.Module, HF model to train + - data: dict[torch.Tensor], HF datasets batch of data + - sample_idx: int, index of batch in dataset + """ + input_keys = {'input_ids', 'attention_mask'} + inputs = {k: v.to(model.device) + for k, v in data.items() if k in input_keys} + + outputs = model(**inputs, output_attentions=False, use_cache=False) + + outputs = outputs.get('logits')[..., :-1, :].contiguous() + targets = data.get('labels')[..., 1:].contiguous() + + # Look at model outputs + if self.print_samples and sample_idx is not None and (sample_idx + 1) % 100 == 0: + decode_samples(outputs, targets, self.tokenizer, sample_idx) + + # Flatten and compute cross-entropy loss + outputs = outputs.view(-1, outputs.shape[-1]) + targets = targets.view(-1).to(outputs.device) + try: + loss = self.criterion(outputs, targets) + except Exception as e: + print('outputs.shape', outputs.shape) + print('targets.shape', targets.shape) + raise e + + targets = targets.cpu() + outputs = outputs.cpu() + return loss, {'ppl': torch.exp(loss).item(), 'seq_len': targets.shape[-1] + 1} + + def save_trainable_weights(self, model: nn.Module): + """ + Save checkpoint with only weights actively being trained (e.g., for adapters). + Make sure to later load with model.load_state_dict(state_dict, strict=False) + """ + with torch.no_grad(): + state_dict = OrderedDict() + for n, p in model.named_parameters(): + if p.requires_grad: + state_dict[n] = p.cpu() # assurance + return state_dict + + def init_checkpointing(self, + args: ArgumentParser, + checkpoint_suffix: str) -> None: + """ + Initialize checkpointing attributes + + Inputs: + - args: Argparse or HuggingFace TrainingArguments object + - checkpoint_suffix: str to append to checkpoint name + """ + self.best_val_checkpoint_path = f'{join(args.checkpoint_dir, args.run_name)}.pt' + if checkpoint_suffix is not None: + self.best_val_checkpoint_path = self.best_val_checkpoint_path.replace( + '.pt', f'{checkpoint_suffix}.pt') + print(f'-> Saving best model checkpoint to {self.best_val_checkpoint_path}') + if self.save_results: + self.results_path = self.best_val_checkpoint_path.replace( + '.pt', '.csv').replace(args.checkpoint_dir, args.results_dir) + print(f'-> Saving results to {self.results_path}') + + # Best metric setup + self.best_val_metric = 0 if self.greater_is_better else 1e10 + self.best_val_metric_epoch = 0 + self.best_val_metric_step = 0 + self.best_train_metric = 0 if self.greater_is_better else 1e10 + self.best_train_metric_epoch = 0 + self.best_train_metric_step = 0 + self.metric_for_best_model = self.metric_for_best_model + if self.metric_for_best_model is not None: + if 'eval' not in self.metric_for_best_model: + self.metric_for_best_model = f'eval/{self.metric_for_best_model}' diff --git a/src/trainer/distill_attention_mse_linear.py b/src/trainer/distill_attention_mse_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..95110c95d577929532560fd84155f9f7e7bd2fdb --- /dev/null +++ b/src/trainer/distill_attention_mse_linear.py @@ -0,0 +1,99 @@ +""" +Custom trainer class for distilling attentions ("attention transfer") over long sequences with recurrent linear attention view. Can substitute for Hugging Face trainer. 
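+Each layer's student attention is re-run on the teacher's saved attention inputs with a recurrent attention state, and only the MSE loss on attention outputs is used (xent_factor is fixed to 0 here).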
+""" +import torch +import torch.nn as nn + +from tqdm import tqdm + +from src.model.modeling_llama import get_attention_cache +from src.model.convert_model import traverse_layers +from .default_lm import OurTrainer as DefaultTrainer + + +class OurTrainer(DefaultTrainer): + """ + Custom trainer class for distilling attentions. + - We compute and store the attention outputs and/or weights for each head and layer, + for both the "teacher" softmax attentions and "student" learnable subquadratic attentions + - We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights) + """ + def __init__(self, + model: nn.Module, + metric_for_best_model: str = 'distill/eval/loss', + mse_factor: float = 1e3, + **kwargs: any): + super().__init__(model=model, + metric_for_best_model=metric_for_best_model, + **kwargs) + self.criterion_mse = nn.MSELoss(reduction='mean') + self.mse_factor = mse_factor + self.xent_factor = 0 + self.compute_loss_backprop = False # Whether we backprop in self.compute_loss + + + def compute_loss(self, model: nn.Module, data: dict[torch.Tensor], + sample_idx: int = None, **kwargs: any,) -> tuple[torch.Tensor, dict[any]]: + """ + Attention distillation ("attention transfer") + - For each layer and head, get attentions and train to + minimize some combo of MSE and cross-entropy loss + """ + input_seq_len = data['input_ids'].shape[-1] + inputs = {'input_ids': data['input_ids'].to(model.device)} # assume all inputs good + + # Get softmax attention outputs + with torch.no_grad(): + # Set base_inference to True to use FlashAttention + for layer in traverse_layers(model): + layer.self_attn.base_inference = True + # Get hidden states + true_outputs = model(**inputs, output_attentions=True, + use_cache=False,) + # no_logit_float=True,) + # Hack were we save attention layer inputs and outputs in outputs.attentions + # -> see model/hedgehog_attention_tk_long.py + # attn_inputs = [a[0] for a in true_outputs.get('attentions')] + # attn_outputs = [a[1] for a in true_outputs.get('attentions')] + true_attn_io = true_outputs.get('attentions') # layer-wise attn inputs and outputs + true_outputs = true_outputs.get('logits').cpu() + for layer in traverse_layers(model): + layer.self_attn.base_inference = False + inputs = {k: v.cpu() for k, v in inputs.items()} + torch.cuda.empty_cache() + + # Get trainable subquadratic attention outputs + attention_type = getattr(layer.self_attn, 'attention_type', None) + past_key_values = get_attention_cache(attention_type) + + total_seq_len = 0 + position_ids = torch.arange(input_seq_len).view(1, -1) + + loss_mse = 0 + for layer_idx, layer in enumerate(tqdm(traverse_layers(model), desc='Processing layer', + leave=False)): + attn_input, attn_output = true_attn_io[layer_idx] + attn_preds = layer.self_attn(attn_input.to(model.device), + attention_mask=None, + position_ids=position_ids.to(model.device), + past_key_value=past_key_values)[1] + if self.mse_factor > 0: # MSE on layer outputs + loss_mse += self.criterion_mse(attn_preds, attn_output.to(model.device)) + del attn_input; del attn_output + loss_mse = loss_mse / (layer_idx + 1) * self.mse_factor + loss = loss_mse + torch.cuda.empty_cache() + + if 'position_ids' in data: + outputs = {'loss_mse': loss_mse.item(), + 'loss_xent': 0, + 'mse_factor': self.mse_factor, + 'xent_factor': self.xent_factor, + 'input_len': data['position_ids'].shape[1], + 'position_ids': data['position_ids'][0],} + else: + outputs = {'loss_mse': loss_mse.item(), + 'loss_xent': 0, + 'mse_factor': self.mse_factor, + 
'xent_factor': self.xent_factor,} + return loss, outputs \ No newline at end of file diff --git a/src/trainer/distill_attention_xent_mse.py b/src/trainer/distill_attention_xent_mse.py new file mode 100644 index 0000000000000000000000000000000000000000..68a9bd2150a454da7684ce4f2623748abafac55a --- /dev/null +++ b/src/trainer/distill_attention_xent_mse.py @@ -0,0 +1,87 @@ +""" +Custom trainer class for distilling attentions ("attention transfer"). Can substitute for Hugging Face trainer. + +In this implementation we support using either just the softmax attention outputs, or the softmax attention weights. +""" +import torch +import torch.nn as nn + +from .default_lm import OurTrainer as DefaultTrainer + + +class OurTrainer(DefaultTrainer): + """ + Custom trainer class for distilling attentions. + - We compute and store the attention outputs and/or weights for each head and layer, + for both the "teacher" softmax attentions and "student" learnable subquadratic attentions + - We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights) + """ + def __init__(self, + model: nn.Module, + metric_for_best_model: str = 'distill/eval/loss', + mse_factor: float = 1e3, + xent_factor: float = 0, + **kwargs: any): + super().__init__(model=model, + metric_for_best_model=metric_for_best_model, + **kwargs) + self.criterion_xent = nn.CrossEntropyLoss(reduction='mean') + self.criterion_mse = nn.MSELoss(reduction='mean') + self.mse_factor = mse_factor + self.xent_factor = xent_factor + self.compute_loss_backprop = False # Whether we backprop in self.compute_loss + + def compute_loss(self, model: nn.Module, data: dict[torch.Tensor], + sample_idx: int = None, **kwargs: any,) -> tuple[torch.Tensor, dict[any]]: + """ + Attention distillation ("attention transfer") + - For each layer and head, get attentions and train to + minimize some combo of MSE and cross-entropy loss + """ + inputs = {k: v.to(model.device) for k, v in data.items() if k != 'labels'} + outputs = model(**inputs, output_attentions=True, use_cache=False) + outputs = outputs.get('attentions') + + # Attentions are tuple[tuple[torch.Tensor, torch.Tensor]] + # n_layers x (predicted_attns, true_attns) + # predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len) + loss_mse = 0 + loss_xent = 0 + n_layers = 0 # Number of layers to distill + softmax_layers = [] + for layer_idx, attns in enumerate(outputs): + if attns is not None: + if len(attns) != 2: + attns = attns.cpu() + else: + if self.xent_factor > 0: + # Cross-entropy loss + a_pred, a_true = attns[0] + a_pred = a_pred.clamp(min=1e-12).log() # nn.CrossEntropy assumes unnormalized logits + k_len = a_true.shape[-1] # batch, n_heads, q_len, k_len + # Compute mean cross-entropy over all queries + a_pred = a_pred.contiguous().view(-1, k_len) + a_true = a_true.contiguous().view(-1, k_len) + loss_xent += self.criterion_xent(a_pred, a_true) + if self.mse_factor > 0: + loss_mse += self.criterion_mse(*attns[1]) + n_layers += 1 + else: + softmax_layers.append(layer_idx) + if n_layers > 0: + loss_xent = loss_xent / n_layers * self.xent_factor + loss_mse = loss_mse / n_layers * self.mse_factor + loss = loss_xent + loss_mse + if 'position_ids' in data: + outputs = {'loss_xent': loss_xent.item() if self.xent_factor > 0 else 0, + 'loss_mse': loss_mse.item() if self.mse_factor > 0 else 0, + 'input_len': data['position_ids'].shape[1], + 'position_ids': data['position_ids'][0].detach().cpu().numpy(), + 'mse_factor': self.mse_factor, + 'xent_factor': self.xent_factor,} + else: + 
outputs = {'loss_xent': loss_xent.item() if self.xent_factor > 0 else 0, + 'loss_mse': loss_mse.item() if self.mse_factor > 0 else 0, + 'mse_factor': self.mse_factor, + 'xent_factor': self.xent_factor} + return loss, outputs diff --git a/src/trainer/finetune_seq2seq.py b/src/trainer/finetune_seq2seq.py new file mode 100644 index 0000000000000000000000000000000000000000..a85cd5ae2f1a9159f6d4e8a65083f3abd84d84d7 --- /dev/null +++ b/src/trainer/finetune_seq2seq.py @@ -0,0 +1,140 @@ +""" +General seq2seq / input-output trainer +""" +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader + +from tqdm import tqdm + +from .default_lm import OurTrainer as DefaultTrainer +from .utils import replace_padding_tokens + + +def compute_scrolls_metrics(eval_preds, scrolls_metric, tokenizer): + """ + Function to compute metrics that are also in SCROLLS (ROUGE, F1, etc.) + """ + preds, labels = eval_preds + if isinstance(preds, tuple): + preds = preds[0] + # Replace -100s used for padding as we can't decode them + preds = replace_padding_tokens(preds, tokenizer.pad_token_id) + labels = replace_padding_tokens(labels, tokenizer.pad_token_id) + + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + + # Scrolls metric expects predictions to be [pred_1, pred_2, ...] + # and references to be [[ref_1], [ref_2], ... ] + decoded_labels = [[s] for s in decoded_labels] + + result = scrolls_metric.compute(predictions=decoded_preds, + references=decoded_labels) + print('----------------') + print('Model generation') + print(decoded_preds[:10]) + print('----------------') + print('True answer') + print(decoded_labels[:10]) + return result + + +class OurTrainer(DefaultTrainer): + """ + Evaluator for seq-to-seq / generation benchmarks + """ + def __init__(self, model, args, # max_eval_batches: Optional[int] = 100, + **kwargs: any): + super().__init__(model=model, args=args, **kwargs) + # Reset + determine metric for best automatically based on the dataset + self.metric_for_best = None + self.is_better = lambda x, y: x > y # Hardcode greater is better for now + self.print_steps = getattr(args, 'print_steps', 100) + print(f'self.print_steps:', self.print_steps) + # ablation sweep + self.max_eval_batches = 10 + + def init_criterion_(self): + pass + + def compute_loss(self): + pass + + def evaluate(self, *args: any, **kwargs: any): + return self.eval_step(*args, **kwargs) + + def eval_step(self, model: nn.Module, step: int, + dataloader: DataLoader = None, + max_batches: int = None, + prefix: str = None, + **kwargs: any): # -1): + """ + One evaluation step + """ + total = 0 + total_loss = 0 + metrics = {} + max_batches = self.max_eval_batches if max_batches is None else max_batches + max_batches = 10 # ablation sweep + + dataloader = (dataloader if dataloader is not None else self.eval_loader) + + scrolls_metric = dataloader.dataset.metric # Should be assigned in dataset + tokenizer = dataloader.dataset.tokenizer + + # Save decoded predictions and references here to compute average metrics + predictions, references = [], [] + + model.eval() + + pbar = tqdm(dataloader, leave=False, colour='green', + desc=f'Evaluating at step {step}') + + with torch.no_grad(): + for ix, data in enumerate(pbar): + inputs = {k: v.to(self.device) for k, v in data.items() + if k in ['input_ids', 'attention_mask']} + labels = data['labels'] + outputs = model.generate(**inputs, + 
max_new_tokens=1024, # hardcoded for now + pad_token_id=tokenizer.pad_token_id, + use_cache=True,).cpu() + # Only save newly generated tokens + pred_ids = outputs[:, data['input_ids'].shape[1]:] + predictions.append(pred_ids) + references.append(labels) + pbar.set_description(f"Evaluating at step {step} | input_len: {data['input_ids'].shape[1]} | output_len: {labels.shape[1]}") + + if ix == max_batches: + break + + if (ix + 1) % self.print_steps == 0: # 100 == 0: + print(f'Model input: \n', tokenizer.batch_decode(inputs['input_ids'].detach().cpu())[0]) + print(f'Model output:\n', tokenizer.batch_decode(pred_ids)[0]) + print(f'True output:\n', tokenizer.batch_decode(labels)[0]) + + # Compute and save metrics + try: + predictions = torch.cat(predictions, dim=0) + references = torch.cat(references, dim=0) + except: + pass + _metric = compute_scrolls_metrics((predictions, references), + scrolls_metric, tokenizer) + if self.metric_for_best is None: # Hard-coded for now + if 'f1' in _metric: + self.metric_for_best = f'eval/f1' + elif 'exact_match' in _metric: + self.metric_for_best = f'eval/exact_match' + elif 'rouge/geometric_mean' in _metric: + self.metric_for_best = f'eval/rouge/geometric_mean' + for k, v in _metric.items(): + if 'display' not in k: + _k = f'{prefix}/eval/{k}' if prefix is not None else f'eval/{k}' + metrics[_k] = v + + return metrics diff --git a/src/trainer/optim.py b/src/trainer/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..c9e79512990d01215b39d1efa19568d7f65339df --- /dev/null +++ b/src/trainer/optim.py @@ -0,0 +1,48 @@ +""" +Optimizer and schedulers +""" +import torch +import torch.nn as nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler + + +def get_optimizer(optim: str, model: nn.Module, **kwargs: any) -> Optimizer: + """ + Return training optimizer + """ + if optim == 'sgd': + return torch.optim.SGD(model.parameters(), **kwargs) + elif optim == 'adam': + return torch.optim.Adam(model.parameters(), **kwargs) + elif optim in ['adamw', 'adamw_torch']: + return torch.optim.AdamW(model.parameters(), **kwargs) + elif optim == 'adamw_torch_fused': + return torch.optim.AdamW(model.parameters(), **kwargs, fused=True) + elif optim == 'adafactor': + from transformers import Adafactor + kwargs['relative_step'] = False # for now + return Adafactor(model.parameters(), **kwargs) + else: + raise NotImplementedError(f"{optim} optimizer not implemented sorry.") + + +def get_scheduler(lr_scheduler_type: str, optimizer: Optimizer, + **kwargs: any) -> LRScheduler: + """ + Return learning rate scheduler + """ + if lr_scheduler_type in ['plateau', 'reduce_lr_on_plateau']: + from torch.optim.lr_scheduler import ReduceLROnPlateau + return ReduceLROnPlateau(optimizer=optimizer, **kwargs) + + elif lr_scheduler_type == 'cosine_warmup': + from transformers import get_cosine_schedule_with_warmup + return get_cosine_schedule_with_warmup(optimizer=optimizer, **kwargs) + + elif lr_scheduler_type in ['linear_warmup', 'linear']: + from transformers import get_linear_schedule_with_warmup + return get_linear_schedule_with_warmup(optimizer=optimizer, **kwargs) + + else: + return None \ No newline at end of file diff --git a/src/trainer/utils.py b/src/trainer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dcc84b257f92136e319bc71fe372c53e88df197b --- /dev/null +++ b/src/trainer/utils.py @@ -0,0 +1,42 @@ +""" +Training loop helpers +""" +import torch +import numpy as np + +from transformers.tokenization_utils 
import PreTrainedTokenizer + + +def replace_padding_tokens(token_ids: torch.Tensor, + pad_token_id: int, + ignore_token_id: int = -100) -> any: + """ + Replace ignore_token_id tokens with pad_token_id, + e.g., for printing inputs during training + """ + if isinstance(token_ids, list): + return [np.where(t != ignore_token_id, t, pad_token_id)[0] for t in token_ids] + else: + return np.where(token_ids != ignore_token_id, token_ids, pad_token_id) + + +def decode_samples(outputs: torch.Tensor, + targets: torch.Tensor, + tokenizer: PreTrainedTokenizer, + sample_idx: int = None) -> None: + """ + Print first element of samples for debugging + """ + print('=' * 20) + print(f'*** TARGETS (sample {sample_idx})***') + tokens = tokenizer.decode( + replace_padding_tokens(targets[0], tokenizer.pad_token_id) + ) + print(tokens) + print('-' * 20) + print(f'*** PREDICTIONS (sample {sample_idx}) ***') + pred_logits = outputs.argmax(dim=-1).cpu() + pred_tokens = tokenizer.decode( + replace_padding_tokens(pred_logits[0], tokenizer.pad_token_id) + ) + print(pred_tokens) diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/utils/logging.py b/src/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..5f5acff3f1352e4cbd92c27518c7b2fd72a5d8f3 --- /dev/null +++ b/src/utils/logging.py @@ -0,0 +1,117 @@ +""" +Logging utilities to make terminal slightly more delightful +""" +import rich.syntax +import rich.tree + +from omegaconf import OmegaConf, DictConfig, ListConfig + + +def _format_arg(arg_name: str, cutoff=2) -> str: + if arg_name is None: + return arg_name + arg_name = str(arg_name) + + # Hardcode to handle backslash + name_splits = arg_name.split('/') + if len(name_splits) > 1: + return name_splits[-1] + # Abbreviate based on underscore + name_splits = arg_name.split('_') + if len(name_splits) > 1: + return ''.join([s[0] for s in name_splits]) + else: + return arg_name[:cutoff] + + +def print_header(x: str) -> None: + """ + Print a header with a line above and below + """ + print('-' * len(x)) + print(x) + print('-' * len(x)) + + +def print_args(args, return_dict=False, verbose=True): + """ + Print the arguments passed to the script + """ + attributes = [a for a in dir(args) if a[0] != '_'] + arg_dict = {} # switched to ewr + if verbose: + print('ARGPARSE ARGS') + for ix, attr in enumerate(attributes): + fancy = '└─' if ix == len(attributes) - 1 else '├─' + if verbose: + print(f'{fancy} {attr}: {getattr(args, attr)}') + arg_dict[attr] = getattr(args, attr) + if return_dict: + return arg_dict + + +def update_description_metrics(description: str, metrics: dict): + """ + Set the numbers that show up on progress bars + """ + for split in metrics: + if split != 'test': # No look + for metric_name, metric in metrics[split].items(): + description += f' | {split}/{metric_name}: {metric:.3f}' + return description + + +# Control how tqdm progress bar looks +def type_of_script(): + try: + ipy_str = str(type(get_ipython())) + if 'zmqshell' in ipy_str: + return 'jupyter' + if 'terminal' in ipy_str: + return 'ipython' + except: + return 'terminal' + +# Progress bar +def update_pbar_display(metrics, batch_ix, pbar, prefix, batch_size, accum_iter=1): + description = f'└── {prefix} batch {int(batch_ix)}/{len(pbar)} [batch size: {batch_size} - grad. accum. 
over {accum_iter} batch(es)]' + for metric_name, metric in metrics.items(): + if metric_name == 'correct': + description += f' | {metric_name} (acc. %): {int(metric):>5d}/{int(metrics["total"])} = {metric / metrics["total"] * 100:.3f}%' + elif metric_name == 'acc': + description += f' | {metric_name}: {metric:.3f}' + elif metric_name in ['perplexity']: # , 'bpc']: + description += f' | {metric_name}: {Decimal(metric):.3E}' + elif metric_name != 'total': + description += f' | {metric_name}: {metric / metrics["total"]:.3f}' + pbar.set_description(description) + + +def print_config(config: DictConfig, + resolve: bool = True, + name: str = 'CONFIG') -> None: + """Prints content of DictConfig using Rich library and its tree structure. + Args: + config (DictConfig): Configuration composed by Hydra. + fields (Sequence[str], optional): Determines which main fields from config will + be printed and in what order. + resolve (bool, optional): Whether to resolve reference fields of DictConfig. + """ + + style = "bright" # "dim" + tree = rich.tree.Tree(name, style=style, guide_style=style) + + fields = config.keys() + for field in fields: + branch = tree.add(field, style=style, guide_style=style) + + config_section = config.get(field) + branch_content = str(config_section) + if isinstance(config_section, DictConfig): + branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) + elif isinstance(config_section, ListConfig): + branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + rich.print(tree) \ No newline at end of file diff --git a/src/utils/setup.py b/src/utils/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..b25cf9cfe8cfd6b6b76ef05dab9747050008e8d3 --- /dev/null +++ b/src/utils/setup.py @@ -0,0 +1,201 @@ +""" +General helper functions for setting up experiments +""" +import os +import random + +from argparse import ArgumentParser +from omegaconf import DictConfig + +import torch +import numpy as np + +from .logging import _format_arg + + +def init_wandb(args: ArgumentParser) -> any: + """Initialize WandB""" + if args.no_wandb: + wandb = None + else: + import wandb + wandb.init(config={}, + entity=args.wandb_entity, + name=args.run_name, + project=args.project_name) + return wandb + + +def seed_everything(seed: int) -> None: + """ + Seed everything + """ + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def get_run_name_from_checkpoint(checkpoint_path: str) -> str: + """ + Helper function to get a condensed run name from the checkpoint path + """ + name = [] + for s in checkpoint_path.split('/')[-1].split('-'): + if '.pt' in s: + name.append(f'_{s[:-3]}') + try: + s = s.split('=') + s = ''.join([c[0] for c in s[1].split('_')]) + name.append(s) + except IndexError: + pass + return ''.join(name) + + +def get_run_name_from_args(args) -> str: + """ + Prepare a heinous identifier for the run based on args + """ + if args.load_distill_checkpoint is not None and args.load_distill_checkpoint != 'default': + distill_name = get_run_name_from_checkpoint(args.load_distill_checkpoint) + else: + distill_name = args.distill_config + if args.load_finetune_checkpoint is not None and args.finetune_config is None: # args.load_finetune_checkpoint != 'default': + finetune_name = 
get_run_name_from_checkpoint(args.load_finetune_checkpoint) + else: + finetune_name = args.finetune_config + args.run_name = f'dl-d={distill_name}-m={args.model_config}-f={finetune_name}' + if args.no_peft_grad_ckpt is not None: + args.run_name += f'-npgc={args.no_peft_grad_ckpt}' + args.run_name += f'-s={args.seed}' + if args.debug: + args.run_name += f'-debug' + if args.no_attention_mask is not None: + args.run_name += f'-nam=1' + return args.run_name.replace('True', '1').replace('False', '0') # concise hacks + + +def flatten_config(config: dict, flattened: dict, key: str) -> dict: + """ + Recursive way to flatten config args for saving to WandB + """ + for k, v in config.items(): + if isinstance(v, dict): + flatten_config(v, flattened, f'{key}{k}_') + elif isinstance(v, list): + for ix, _config in enumerate(v): + if isinstance(_config, dict): + flatten_config(_config, flattened, f'{key}{k}_{ix}_') + else: + flattened[f'{key}{k}'] = v + return flattened + + +def update_config_from_args(config: DictConfig, + args: ArgumentParser, + ignore_args: list = None) -> DictConfig: + """ + Quick hacks to override default configs + """ + ignore_args = [] if ignore_args is None else ignore_args + + # Dataset + if getattr(args, 'dataset', None): + config.dataset.name = args.dataset + args.run_name += f'-ds={args.dataset}' + + # Optimizer + for arg in ['lr', 'weight_decay']: + if arg not in ignore_args: + argval = getattr(args, arg, None) + if argval is not None: + setattr(config.optimizer, arg, argval) + args.run_name += f'-{_format_arg(arg)}={argval}' + try: + if getattr(args, 'optim', None): + config.optimizer.optim = args.optim + args.run_name += f'-o={args.optim}' + except AttributeError: + pass + + # Scheduler + try: + if getattr(args, 'scheduler', None): + config.lr_scheduler.lr_scheduler_type = args.scheduler + args.run_name += f'-sc={args.scheduler}' + except AttributeError: + pass + + # Dataset + for arg in [a for a in dir(args) if 'dataset_' in a]: + argval = getattr(args, arg, None) + if argval is not None: + setattr(config.dataset.dataset_config, arg[len('dataset_'):], argval) + args.run_name += f'-{_format_arg(arg)}={argval}' + + # Dataloader + for arg in ['batch_size']: # , 'num_workers']: + argval = getattr(args, arg, None) + if argval is not None: + setattr(config.dataloader, arg, argval) + args.run_name += f'-{_format_arg(arg)}={argval}' + + # Trainer + for arg in ['gradient_accumulation_steps', 'num_train_epochs', + 'max_steps', 'max_finetune_steps', 'eval_steps', + 'seed', 'max_eval_batches']: + argval = getattr(args, arg, None) + if argval is not None: + setattr(config.trainer, arg, argval) + if arg in ['max_steps', 'max_finetune_steps', + 'gradient_accumulation_steps', 'num_train_epochs', 'seed']: + args.run_name += f'-{_format_arg(arg)}={argval}' + + # Misc + for arg in ['replicate']: + argval = getattr(args, arg, None) + if argval is not None: + args.run_name += f'-{_format_arg(arg)}={argval}' + + return config + + +def update_model_config_from_args(model_config: DictConfig, + args: ArgumentParser) -> DictConfig: + """ + Override default configs given argparse args + """ + # Overall attention + for arg in ['attention_type', 'learned_kernel', 'tie_qk_kernels', + 'train_qk', 'state_chunk_len', 'no_peft_grad_ckpt', + 'window_size']: + argval = getattr(args, arg, None) + if argval is not None: + setattr(model_config['attention'], arg, argval) + args.run_name += f'-{_format_arg(arg)}={argval}' + else: + try: + getattr(model_config['attention'], arg) + except AttributeError: + 
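+                # Not overridden via args and missing from the model config: default to None so later code can always read the attribute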
setattr(model_config['attention'], arg, None) + + # Learned kernel + for arg in ['lk_skip_connection', 'lk_zero_init', 'lk_normal_init']: + argval = getattr(args, arg, None) + if argval is not None: + setattr(model_config['attention']['learned_kernel_kwargs'], + arg[len('lk_'):], argval) + args.run_name += f'-{_format_arg(arg)}={argval}' + + # Pretrained model + if args.pretrained_model_name_or_path is not None: # if specified + pmnop = args.pretrained_model_name_or_path + model_config.model.pretrained_model_name_or_path = pmnop + args.run_name += f'-pmnop={pmnop.split("/")[-1]}' + + return model_config