winglian commited on
Commit
6b9b229
·
unverified ·
1 Parent(s): 131afdb

btlm and falcon monkey patches for flash attn (#566)

Browse files
examples/cerebras/btlm-ft.yml ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_model: cerebras/btlm-3b-8k-base
2
+ base_model_config: cerebras/btlm-3b-8k-base
3
+ model_type: AutoModelForCausalLM
4
+ tokenizer_type: GPT2Tokenizer
5
+ trust_remote_code: true
6
+ tokenizer_use_fast: true
7
+ tokenizer_legacy: true
8
+
9
+ load_in_8bit: false
10
+ load_in_4bit: false
11
+ strict: false
12
+ push_dataset_to_hub:
13
+ hf_use_auth_token: true
14
+ datasets:
15
+ - path: mhenrichsen/alpaca_2k_test
16
+ type: alpaca
17
+ dataset_prepared_path: last_prepared_run
18
+ val_set_size: 0.01
19
+
20
+ adapter:
21
+ lora_model_dir:
22
+ sequence_len: 2048
23
+ max_packed_sequence_len:
24
+ sample_packing: false
25
+ sample_packing_eff_est:
26
+ sample_packing_seq_len_multiplier:
27
+ total_num_tokens:
28
+
29
+ lora_r:
30
+ lora_alpha:
31
+ lora_dropout:
32
+ lora_target_modules:
33
+ lora_target_linear:
34
+ lora_fan_in_fan_out:
35
+
36
+ wandb_project:
37
+ wandb_entity:
38
+ wandb_watch:
39
+ wandb_run_id:
40
+ wandb_log_model:
41
+
42
+ output_dir: btlm-out
43
+ gradient_accumulation_steps: 1
44
+ micro_batch_size: 1
45
+ num_epochs: 1
46
+ optimizer: adamw_torch
47
+ adam_beta2: 0.95
48
+ adam_eps: 0.000000001
49
+ max_grad_norm: 1.0
50
+
51
+ torchdistx_path:
52
+ lr_scheduler: cosine
53
+ lr_quadratic_warmup: true
54
+ learning_rate: 0.000085
55
+ train_on_inputs: true
56
+ group_by_length: false
57
+ bf16: true
58
+ fp16: false
59
+ tf32: true
60
+
61
+ gradient_checkpointing: false
62
+ early_stopping_patience:
63
+ resume_from_checkpoint:
64
+ local_rank:
65
+ logging_steps: 1
66
+
67
+ xformers_attention:
68
+ flash_attention: true
69
+ sdp_attention:
70
+ flash_optimum:
71
+
72
+ gptq_groupsize:
73
+ gptq_model_v1:
74
+
75
+ warmup_steps: 32
76
+ eval_steps:
77
+ save_steps:
78
+ save_total_limit:
79
+
80
+ debug:
81
+ deepspeed:
82
+ weight_decay: 0.1
83
+ special_tokens:
84
+ pad_token: "<|endoftext|>"
85
+ fsdp:
86
+ # - full_shard
87
+ # - auto_wrap
88
+ fsdp_config:
89
+ # fsdp_state_dict_type: FULL_STATE_DICT
90
+ # fsdp_transformer_layer_cls_to_wrap: BTLMBlock
src/axolotl/monkeypatch/btlm_attn_hijack_flash.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Flash attention monkey patch for cerebras btlm model
3
+ """
4
+
5
+ import importlib
6
+ import logging
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ from flash_attn.flash_attn_interface import flash_attn_func
11
+ from transformers import AutoConfig, AutoModelForCausalLM
12
+
13
+ LOG = logging.getLogger("axolotl")
14
+
15
+
16
+ def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
17
+ # this is a wonky hack to get the remotely loaded module
18
+ model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
19
+ # we need to load the model here in order for modeling_btlm to be available
20
+ AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
21
+ module_name = model_config.__class__.__module__.replace(
22
+ ".configuration_btlm", ".modeling_btlm"
23
+ )
24
+ modeling_btlm = importlib.import_module(module_name)
25
+ modeling_btlm.BTLMAttention._attn = ( # pylint: disable=protected-access
26
+ flashattn_attn
27
+ )
28
+
29
+
30
+ def flashattn_attn(
31
+ self,
32
+ query: torch.Tensor,
33
+ key: Optional[torch.Tensor] = None,
34
+ value: Optional[torch.Tensor] = None,
35
+ attention_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
36
+ head_mask: Optional[torch.Tensor] = None,
37
+ position_bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
38
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
39
+ softmax_scale = (
40
+ 1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None
41
+ )
42
+
43
+ query = query.permute(0, 2, 1, 3)
44
+ key = key.permute(0, 2, 1, 3)
45
+ value = value.permute(0, 2, 1, 3)
46
+
47
+ # Perform Flash attention
48
+ attn_output = flash_attn_func(
49
+ query,
50
+ key,
51
+ value,
52
+ dropout_p=0.0, # Assuming you have this attribute
53
+ softmax_scale=softmax_scale, # Set this if you have specific scaling in mind
54
+ causal=not self.is_cross_attention, # Assuming you have this attribute
55
+ return_attn_probs=False, # Set this based on your needs
56
+ )
57
+
58
+ # Optional: Apply head mask if it's not None
59
+ if head_mask is not None:
60
+ attn_output *= head_mask
61
+
62
+ attn_output = attn_output.permute(0, 2, 1, 3)
63
+
64
+ return attn_output, None # We don't have explicit attn_weights in Flash attention
src/axolotl/monkeypatch/falcon_attn_hijack_flash.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Flash Attention monkey patch for Falcon
3
+
4
+ copied from https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/falcon_flash_attn_monkey_patch.py
5
+ """
6
+
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ import transformers
11
+ from flash_attn import flash_attn_func
12
+
13
+
14
+ def forward(
15
+ self,
16
+ hidden_states: torch.Tensor,
17
+ alibi: Optional[torch.Tensor],
18
+ attention_mask: torch.Tensor, # pylint: disable=unused-argument
19
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
20
+ head_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
21
+ use_cache: bool = False,
22
+ output_attentions: bool = False, # pylint: disable=unused-argument
23
+ ):
24
+ fused_qkv = self.query_key_value(
25
+ hidden_states
26
+ ) # [batch_size, seq_length, 3 x hidden_size]
27
+ num_kv_heads = (
28
+ self.num_heads if self.new_decoder_architecture else self.num_kv_heads
29
+ )
30
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
31
+ (
32
+ query_layer,
33
+ key_layer,
34
+ value_layer,
35
+ ) = self._split_heads( # pylint: disable=protected-access
36
+ fused_qkv
37
+ )
38
+
39
+ batch_size, query_length, _, _ = query_layer.shape
40
+
41
+ query_layer = query_layer.transpose(1, 2).reshape(
42
+ batch_size * self.num_heads, query_length, self.head_dim
43
+ )
44
+ key_layer = key_layer.transpose(1, 2).reshape(
45
+ batch_size * num_kv_heads,
46
+ query_length,
47
+ self.head_dim,
48
+ )
49
+ value_layer = value_layer.transpose(1, 2).reshape(
50
+ batch_size * num_kv_heads, query_length, self.head_dim
51
+ )
52
+
53
+ past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
54
+ query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
55
+
56
+ if layer_past is not None:
57
+ past_key, past_value = layer_past
58
+ # concatenate along seq_length dimension:
59
+ # - key: [batch_size * self.num_heads, kv_length, head_dim]
60
+ # - value: [batch_size * self.num_heads, kv_length, head_dim]
61
+ key_layer = torch.cat((past_key, key_layer), dim=1)
62
+ value_layer = torch.cat((past_value, value_layer), dim=1)
63
+
64
+ # unused
65
+ # _, kv_length, _ = key_layer.shape
66
+ if use_cache:
67
+ present = (key_layer, value_layer)
68
+ else:
69
+ present = None
70
+ # unused
71
+ # attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
72
+ query_layer_ = (
73
+ query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
74
+ .transpose(1, 2)
75
+ .to(torch.bfloat16)
76
+ )
77
+ key_layer_ = (
78
+ key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
79
+ .transpose(1, 2)
80
+ .to(torch.bfloat16)
81
+ )
82
+ value_layer_ = (
83
+ value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
84
+ .transpose(1, 2)
85
+ .to(torch.bfloat16)
86
+ )
87
+
88
+ if alibi is not None:
89
+ raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
90
+
91
+ # below output will have shape (batch_size, seqlen, nheads, headdim)
92
+ attn_output = flash_attn_func(query_layer_, key_layer_, value_layer_, causal=True)
93
+ attn_output = attn_output.reshape(
94
+ batch_size, query_length, self.num_heads * self.head_dim
95
+ )
96
+ output_tensor = self.dense(attn_output)
97
+ return output_tensor, present
98
+
99
+
100
+ def replace_falcon_attn_with_flash_attn():
101
+ transformers.models.falcon.modeling_falcon.FalconAttention.forward = forward
src/axolotl/utils/models.py CHANGED
@@ -100,10 +100,31 @@ def load_model(
100
  base_model = cfg.base_model
101
  base_model_config = cfg.base_model_config
102
  model_type = cfg.model_type
 
103
 
104
  # TODO refactor as a kwarg
105
  load_in_8bit = cfg.load_in_8bit
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  if cfg.is_llama_derived_model and cfg.flash_attention:
108
  if cfg.device not in ["mps", "cpu"] and not inference:
109
  from axolotl.monkeypatch.llama_attn_hijack_flash import (
@@ -338,6 +359,9 @@ def load_model(
338
  for name, module in model.named_modules():
339
  if "norm" in name:
340
  module.to(torch.float32)
 
 
 
341
  if "lm_head" in name or "embed_tokens" in name:
342
  if hasattr(module, "weight"):
343
  module.to(torch.float32)
 
100
  base_model = cfg.base_model
101
  base_model_config = cfg.base_model_config
102
  model_type = cfg.model_type
103
+ model_config = load_model_config(cfg)
104
 
105
  # TODO refactor as a kwarg
106
  load_in_8bit = cfg.load_in_8bit
107
 
108
+ if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
109
+ if cfg.flash_attention:
110
+ from axolotl.monkeypatch.btlm_attn_hijack_flash import (
111
+ replace_btlm_attn_with_flash_attn,
112
+ )
113
+
114
+ replace_btlm_attn_with_flash_attn(cfg.base_model)
115
+
116
+ if hasattr(model_config, "model_type") and model_config.model_type in [
117
+ "falcon",
118
+ "RefinedWebModel",
119
+ "RefinedWeb",
120
+ ]:
121
+ if cfg.flash_attention:
122
+ from axolotl.monkeypatch.falcon_attn_hijack_flash import (
123
+ replace_falcon_attn_with_flash_attn,
124
+ )
125
+
126
+ replace_falcon_attn_with_flash_attn()
127
+
128
  if cfg.is_llama_derived_model and cfg.flash_attention:
129
  if cfg.device not in ["mps", "cpu"] and not inference:
130
  from axolotl.monkeypatch.llama_attn_hijack_flash import (
 
359
  for name, module in model.named_modules():
360
  if "norm" in name:
361
  module.to(torch.float32)
362
+ if model_config.model_type == "btlm":
363
+ # don't upcast lm_head for btlm
364
+ continue
365
  if "lm_head" in name or "embed_tokens" in name:
366
  if hasattr(module, "weight"):
367
  module.to(torch.float32)