winglian committed
Commit 40a6362
1 Parent(s): d339beb

support for mamba (#915)


* support for mamba
* more mamba fixes
* use fork for mamba kwargs fix
* grad checkpointing doesn't work
* fix extras for mamba
* mamba loss fix
* use fp32 and remove verbose logging
* mamba fixes
* fix collator for mamba
* set model_type on training_args
* don't save safetensors for mamba
* update mamba config to disable safetensor checkpoints, install for tests
* no evals for mamba tests
* handle save_pretrained
* handle unused safetensors arg

.github/workflows/tests.yml CHANGED
@@ -73,7 +73,7 @@ jobs:
         run: |
           pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
           pip3 uninstall -y transformers accelerate
-          pip3 install -U -e .[flash-attn]
+          pip3 install -U -e .[flash-attn,mamba-ssm]
           pip3 install -r requirements-tests.txt

       - name: Run e2e tests
examples/mamba/config.yml ADDED
@@ -0,0 +1,61 @@
+base_model: state-spaces/mamba-2.8b
+model_type: MambaLMHeadModel
+tokenizer_type: AutoTokenizer
+tokenizer_config: EleutherAI/gpt-neox-20b
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path:
+val_set_size: 0.0
+output_dir: ./out
+
+sequence_len: 2048
+sample_packing: false
+pad_to_sequence_len: false
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 2
+optimizer: paged_adamw_8bit
+lr_scheduler: cosine
+learning_rate: 5e-5
+
+train_on_inputs: false
+group_by_length: true
+
+bf16: true
+fp16: false
+tf32: true
+
+gradient_checkpointing: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention:
+
+warmup_steps: 10
+eval_steps:
+eval_table_size:
+eval_table_max_new_tokens: 128
+save_steps: 0.25
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+tokens:
+save_safetensors: False
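
The example config above is consumed by the usual axolotl config loader. Below is a minimal sketch of reading it outside the CLI, assuming the same DictDefault + normalize_config helpers the e2e test at the bottom of this commit uses, plus a plain yaml load; the CLI entry point itself (for example accelerate launch -m axolotl.cli.train examples/mamba/config.yml) may differ by version.

# Sketch only: load and normalize the example Mamba config.
import yaml

from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

with open("examples/mamba/config.yml", encoding="utf-8") as fin:
    cfg = DictDefault(yaml.safe_load(fin))

normalize_config(cfg)
print(cfg.base_model, cfg.model_type)  # state-spaces/mamba-2.8b MambaLMHeadModel
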
setup.py CHANGED
@@ -51,5 +51,8 @@ setup(
         "deepspeed": [
             "deepspeed",
         ],
+        "mamba-ssm": [
+            "mamba-ssm==1.0.1",
+        ],
     },
 )
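
With this extra declared, the mamba kernels install alongside axolotl via the extras syntax, e.g. pip3 install -e .[mamba-ssm], which is what the updated CI workflow above does in combination with flash-attn.
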
src/axolotl/core/trainer_builder.py CHANGED
@@ -31,7 +31,10 @@ from axolotl.utils.callbacks import (
     bench_eval_callback_factory,
     log_prediction_callback_factory,
 )
-from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.collators import (
+    BatchSamplerDataCollatorForSeq2Seq,
+    MambaDataCollator,
+)
 from axolotl.utils.samplers import MultipackBatchSampler
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

@@ -49,6 +52,9 @@ class AxolotlTrainingArguments(TrainingArguments):
     Extend the base TrainingArguments for axolotl helpers
     """

+    model_type: Optional[str] = field(
+        default=None, metadata={"help": "HF model configuration model_type."}
+    )
     lr_quadratic_warmup: bool = field(
         default=False,
         metadata={"help": "Use quadratic warmup for cosine scheduling."},
@@ -285,6 +291,32 @@ class AxolotlTrainer(Trainer):
         return super().compute_loss(model, inputs, return_outputs=return_outputs)


+class AxolotlMambaTrainer(AxolotlTrainer):
+    """
+    Mamba specific trainer to handle loss calculation
+    """
+
+    def compute_loss(
+        self,
+        model,
+        inputs,
+        return_outputs=False,  # pylint: disable=unused-argument
+    ):
+        input_ids = inputs.pop("input_ids")
+        lm_logits = model(input_ids).logits
+
+        labels = input_ids.to(lm_logits.device)
+        shift_logits = lm_logits[:, :-1, :].contiguous()
+        labels = labels[:, 1:].contiguous()
+
+        loss_fct = torch.nn.CrossEntropyLoss()
+        lm_loss = loss_fct(
+            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
+        )
+
+        return lm_loss
+
+
 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """
     Trainer subclass that uses the OneCycleLR scheduler
@@ -462,6 +494,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             return OneCycleLRSchedulerTrainer
         if self.cfg.relora_steps:
             return ReLoRATrainer
+        if self.cfg.model_config_type == "mamba":
+            return AxolotlMambaTrainer
         return AxolotlTrainer

     def build(self, total_num_steps):
@@ -529,7 +563,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.hub_strategy:
             training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy

-        if self.cfg.save_safetensors:
+        if self.cfg.save_safetensors is not None:
             training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors

         if self.cfg.sample_packing_eff_est:
@@ -677,6 +711,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
+        training_arguments_kwargs["model_type"] = self.cfg.model_config_type
         training_args = (
             AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
                 **training_arguments_kwargs,
@@ -731,11 +766,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             args=training_args,
-            data_collator=BatchSamplerDataCollatorForSeq2Seq(
-                self.tokenizer,
-                return_tensors="pt",
-                **data_collator_kwargs,
-            ),
+            data_collator=self.build_collator(**data_collator_kwargs),
             bench_data_collator=transformers.DataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
@@ -755,3 +786,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ] = self.cfg.micro_batch_size

         return trainer
+
+    def build_collator(self, **kwargs):
+        if self.cfg.model_config_type == "mamba":
+            return MambaDataCollator(tokenizer=self.tokenizer)
+
+        return BatchSamplerDataCollatorForSeq2Seq(
+            self.tokenizer,
+            return_tensors="pt",
+            **kwargs,
+        )
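
AxolotlMambaTrainer bypasses the model's own loss and computes a next-token cross entropy directly from the logits. A standalone sketch of that shift-by-one computation, using random stand-in tensors (illustrative values only):

# Sketch: the label shifting done in AxolotlMambaTrainer.compute_loss.
import torch

batch, seq_len, vocab = 2, 8, 50280
input_ids = torch.randint(0, vocab, (batch, seq_len))
lm_logits = torch.randn(batch, seq_len, vocab)  # stand-in for model(input_ids).logits

labels = input_ids.clone()
shift_logits = lm_logits[:, :-1, :].contiguous()  # positions 0..n-2 ...
shift_labels = labels[:, 1:].contiguous()         # ... predict tokens 1..n-1

loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
print(lm_loss.item())
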
src/axolotl/models/mamba/__init__.py ADDED
@@ -0,0 +1,12 @@
+"""
+Modeling module for Mamba models
+"""
+
+
+def fix_mamba_attn_for_loss():
+    from mamba_ssm.models import mixer_seq_simple
+
+    from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
+
+    mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
+    return mixer_seq_simple.MambaLMHeadModel  # pylint: disable=invalid-name
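
The helper monkeypatches mamba_ssm so the rest of axolotl can treat the returned class like an HF-style causal LM. A short sketch of how it is consumed, mirroring load_model in src/axolotl/utils/models.py below (requires the mamba-ssm extra):

# Sketch: obtain the patched model class before loading weights.
from axolotl.models.mamba import fix_mamba_attn_for_loss

MambaLMHeadModel = fix_mamba_attn_for_loss()  # swaps in the axolotl replacement class
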
src/axolotl/models/mamba/configuration_mamba.py ADDED
@@ -0,0 +1,42 @@
+"""
+HF Transformers MambaConfig
+"""
+from transformers import PretrainedConfig
+
+
+class MambaConfig(PretrainedConfig):
+    """
+    modeling configuration for state space model/mamba
+    """
+
+    model_type = "mamba"
+
+    def __init__(
+        self,
+        vocab_size=50280,
+        d_model=2560,
+        n_layer=64,
+        rms_norm=True,
+        residual_in_fp32=True,
+        fused_add_norm=True,
+        pad_vocab_size_multiple=8,
+        pad_token_id=50277,
+        bos_token_id=0,
+        eos_token_id=0,
+        tie_word_embeddings=False,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.d_model = d_model
+        self.n_layer = n_layer
+        self.rms_norm = rms_norm
+        self.residual_in_fp32 = residual_in_fp32
+        self.fused_add_norm = fused_add_norm
+        self.pad_vocab_size_multiple = pad_vocab_size_multiple
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
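
Because MambaConfig subclasses PretrainedConfig, it serializes like any other HF config; the defaults above correspond to mamba-2.8b. A brief sketch (output path is illustrative):

# Sketch: instantiate and serialize the config.
from axolotl.models.mamba.configuration_mamba import MambaConfig

config = MambaConfig(d_model=2560, n_layer=64, vocab_size=50280)
print(config.model_type)         # "mamba"
config.save_pretrained("./out")  # writes ./out/config.json
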
src/axolotl/models/mamba/modeling_mamba.py ADDED
@@ -0,0 +1,121 @@
+# pylint: skip-file
+import os
+from collections import namedtuple
+from functools import partial
+from typing import Optional, Union
+
+import torch
+from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
+from mamba_ssm.utils.generation import GenerationMixin
+from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from axolotl.models.mamba.configuration_mamba import MambaConfig
+
+
+class MambaLMHeadModel(nn.Module, GenerationMixin):
+    def __init__(
+        self,
+        d_model: int,
+        n_layer: int,
+        vocab_size: int,
+        initializer_cfg=None,
+        pad_vocab_size_multiple: int = 1,
+        device=None,
+        dtype=None,
+        **backbone_kwargs,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        if vocab_size % pad_vocab_size_multiple != 0:
+            vocab_size += pad_vocab_size_multiple - (
+                vocab_size % pad_vocab_size_multiple
+            )
+        self.config = MambaConfig(
+            vocab_size=vocab_size,
+            d_model=d_model,
+            n_layer=n_layer,
+            pad_vocab_size_multiple=pad_vocab_size_multiple,
+        )
+        self.backbone = MixerModel(
+            d_model=d_model,
+            n_layer=n_layer,
+            vocab_size=vocab_size,
+            initializer_cfg=initializer_cfg,
+            **backbone_kwargs,
+            **factory_kwargs,
+        )
+        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
+
+        # Initialize weights and apply final processing
+        self.apply(
+            partial(
+                _init_weights,
+                n_layer=n_layer,
+                **(initializer_cfg if initializer_cfg is not None else {}),
+            )
+        )
+        self.tie_weights()
+
+    def tie_weights(self):
+        self.lm_head.weight = self.backbone.embedding.weight
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.backbone.allocate_inference_cache(
+            batch_size, max_seqlen, dtype=dtype, **kwargs
+        )
+
+    def forward(
+        self,
+        input_ids,
+        position_ids=None,
+        inference_params=None,
+        num_last_tokens=0,
+        labels=None,
+        **kwargs,
+    ):
+        """
+        "position_ids" is just to be compatible with Transformer generation. We don't use it.
+        num_last_tokens: if > 0, only return the logits for the last n tokens
+        """
+        hidden_states = self.backbone(input_ids, inference_params=inference_params)
+        if num_last_tokens > 0:
+            hidden_states = hidden_states[:, -num_last_tokens:]
+        lm_logits = self.lm_head(hidden_states)
+
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+            CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
+            return CausalLMOutput(logits=lm_logits, loss=loss)
+
+        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
+        return CausalLMOutput(logits=lm_logits)
+
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        state_dict: Optional[dict] = None,
+        safe_serialization: Optional[bool] = None,  # pylint: disable=unused-argument
+    ):
+        if state_dict is None:
+            state_dict = self.state_dict()
+        torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
+        config = load_config_hf(pretrained_model_name)
+        model = cls(**config, device=device, dtype=dtype, **kwargs)
+        model.load_state_dict(
+            load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
+        )
+        return model
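
save_pretrained here deliberately writes a plain pytorch_model.bin via torch.save and ignores safe_serialization, which is why the example config and the e2e test both set save_safetensors: False. A hedged sketch of the round trip, assuming a CUDA device and the mamba-ssm kernels (model name and output path are illustrative):

# Sketch: load a pretrained Mamba checkpoint and save it the way the Trainer will.
import os

import torch

from axolotl.models.mamba.modeling_mamba import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m",
    device=torch.cuda.current_device(),
    dtype=torch.bfloat16,
)
os.makedirs("./out", exist_ok=True)
model.save_pretrained("./out")  # -> ./out/pytorch_model.bin, never safetensors
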
src/axolotl/train.py CHANGED
@@ -82,7 +82,8 @@ def train(
         cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
     )

-    model.config.use_cache = False
+    if hasattr(model, "config"):
+        model.config.use_cache = False

     # go ahead and presave, so we have the adapter config available to inspect
     if peft_config:
@@ -92,7 +93,8 @@
         if not Path(cfg.output_dir).is_dir():
             os.makedirs(cfg.output_dir, exist_ok=True)
         tokenizer.save_pretrained(str(Path(cfg.output_dir)))
-        model.config.save_pretrained(str(Path(cfg.output_dir)))
+        if hasattr(model, "config"):
+            model.config.save_pretrained(str(Path(cfg.output_dir)))

     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
src/axolotl/utils/collators.py CHANGED
@@ -2,12 +2,16 @@
 DataCollator for axolotl to pad labels and position_ids for packed sequences
 """
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Dict, Optional, Sequence, Union

 import numpy as np
+import torch
+import transformers
 from transformers import PreTrainedTokenizerBase
 from transformers.utils import PaddingStrategy

+IGNORE_INDEX = -100
+

 @dataclass
 class DataCollatorForSeq2Seq:
@@ -146,3 +150,31 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             chunked_data[feature] = np.concatenate(arrays)
         features = [chunked_data]
         return super().__call__(features, return_tensors=return_tensors)
+
+
+@dataclass
+class MambaDataCollator:
+    """
+    Collator for State Space Models (Mamba)
+    """
+
+    tokenizer: transformers.PreTrainedTokenizer
+
+    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+        input_ids, labels = tuple(
+            [torch.LongTensor(instance[key]) for instance in instances]
+            for key in ("input_ids", "labels")
+        )
+        input_ids = torch.nn.utils.rnn.pad_sequence(
+            input_ids,
+            batch_first=True,
+            padding_value=self.tokenizer.pad_token_id,
+        )
+        labels = torch.nn.utils.rnn.pad_sequence(
+            labels, batch_first=True, padding_value=IGNORE_INDEX
+        )
+
+        return {
+            "input_ids": input_ids,
+            "labels": labels,
+        }
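
A minimal sketch of the new collator's padding behavior with toy, already-tokenized examples; the tokenizer stand-in only needs a pad_token_id attribute (0 here is illustrative):

# Sketch: input_ids are padded with pad_token_id, labels with IGNORE_INDEX (-100).
from types import SimpleNamespace

from axolotl.utils.collators import MambaDataCollator

collator = MambaDataCollator(tokenizer=SimpleNamespace(pad_token_id=0))
batch = collator(
    [
        {"input_ids": [5, 6, 7], "labels": [5, 6, 7]},
        {"input_ids": [8, 9], "labels": [8, 9]},
    ]
)
print(batch["input_ids"])  # second example right-padded with pad_token_id (0)
print(batch["labels"])     # second example right-padded with IGNORE_INDEX (-100)
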
src/axolotl/utils/models.py CHANGED
@@ -4,6 +4,7 @@ import math
 import os
 from typing import Optional, Tuple  # noqa: F401

+import addict
 import bitsandbytes as bnb
 import torch
 import transformers
@@ -21,6 +22,7 @@ from transformers import ( # noqa: F401
     PreTrainedTokenizerBase,
 )

+from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.dict import DictDefault
@@ -52,9 +54,19 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
 def load_model_config(cfg):
     model_config_name = cfg.base_model_config or cfg.base_model
     trust_remote_code = cfg.trust_remote_code is True
-    model_config = AutoConfig.from_pretrained(
-        model_config_name, trust_remote_code=trust_remote_code
-    )
+    try:
+        model_config = AutoConfig.from_pretrained(
+            model_config_name, trust_remote_code=trust_remote_code
+        )
+    except ValueError as err:
+        if "mamba" in model_config_name:
+            return addict.Dict(
+                {
+                    "model_type": "mamba",
+                }
+            )
+        raise err
+
     if cfg.model_config:
         for key, val in cfg.model_config.items():
             setattr(model_config, key, val)
@@ -351,6 +363,20 @@ def load_model(
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             **model_kwargs,
         )
+    elif model_type == "MambaLMHeadModel":
+        # FIXME this is janky at best and hacked together to make it work
+        MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
+
+        model_kwargs["dtype"] = model_kwargs["torch_dtype"]
+        model_kwargs["device"] = torch.cuda.current_device()
+        del model_kwargs["torch_dtype"]
+        del model_kwargs["device_map"]
+        del model_kwargs["max_memory"]
+
+        model = MambaLMHeadModel.from_pretrained(
+            base_model,
+            **model_kwargs,
+        )
     elif model_type and not cfg.trust_remote_code:
         if cfg.gptq:
             model = AutoModelForCausalLM.from_pretrained(
@@ -410,13 +436,17 @@ def load_model(
         if cfg.resize_token_embeddings_to_32x
         else len(tokenizer)
     )
-    if model.get_input_embeddings().num_embeddings < embeddings_len:
+    if (
+        hasattr(model, "get_input_embeddings")
+        and model.get_input_embeddings().num_embeddings < embeddings_len
+    ):
         model.resize_token_embeddings(embeddings_len)
     else:
         model.tie_weights()

     if (
-        hasattr(model.config, "max_position_embeddings")
+        hasattr(model, "config")
+        and hasattr(model.config, "max_position_embeddings")
         and model.config.max_position_embeddings
         and cfg.sequence_len > model.config.max_position_embeddings
     ):
@@ -426,20 +456,22 @@ def load_model(
         model.config.max_position_embeddings = cfg.sequence_len

     if (
-        hasattr(model.config, "bos_token_id")
+        hasattr(model, "config")
+        and hasattr(model.config, "bos_token_id")
         and model.config.bos_token_id
         and model.config.bos_token_id != tokenizer.bos_token_id
     ):
         model.config.bos_token_id = tokenizer.bos_token_id

     if (
-        hasattr(model.config, "eos_token_id")
+        hasattr(model, "config")
+        and hasattr(model.config, "eos_token_id")
         and model.config.eos_token_id
         and model.config.eos_token_id != tokenizer.eos_token_id
     ):
         model.config.eos_token_id = tokenizer.eos_token_id

-    if model.device.type == "cuda":
+    if hasattr(model, "device") and model.device.type == "cuda":
         log_gpu_memory_usage(LOG, "after model load", model.device)

     # make sure these are fp32 per Ramesh et al. (2021)
@@ -498,7 +530,8 @@ def load_model(
         requires_grad.append(f"{name}: {param.requires_grad}")
     if len(requires_grad) == 0:
         LOG.warning("there are no parameters that require gradient updates")
-    model.config.use_cache = False
+    if hasattr(model, "config"):
+        model.config.use_cache = False

     if cfg.flash_optimum:
         model = BetterTransformer.transform(model)
src/axolotl/utils/trainer.py CHANGED
@@ -131,8 +131,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     )

     # Phi doesn't want the attention_mask feature when training
-    if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
-        cfg.is_mistral_derived_model and cfg.flash_attention
+    if (
+        "CodeGenTokenizer" in tokenizer.__class__.__name__
+        or (cfg.is_mistral_derived_model and cfg.flash_attention)
+        or cfg.model_config_type == "mamba"
     ):
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
@@ -153,7 +155,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
         if update:
             cfg.total_num_tokens = total_num_tokens

-    if not cfg.total_supervised_tokens:
+    skip_estimates = cfg.model_config_type == "mamba"
+
+    if not skip_estimates and not cfg.total_supervised_tokens:
         total_supervised_tokens = (
             train_dataset.data.column("labels")
             .to_pandas()
@@ -167,7 +171,7 @@
         if update:
             cfg.total_supervised_tokens = total_supervised_tokens

-    if cfg.sample_packing:
+    if not skip_estimates and cfg.sample_packing:
         # we have to drop anything longer then sequence len otherwise
         # flash attention with position ids fails

tests/e2e/test_mamba.py ADDED
@@ -0,0 +1,65 @@
+"""
+E2E tests for Mamba
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestMamba(unittest.TestCase):
+    """
+    Test case for Mamba models
+    """
+
+    @with_temp_dir
+    def test_fft(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "state-spaces/mamba-130m",
+                "model_type": "MambaLMHeadModel",
+                "tokenizer_type": "AutoTokenizer",
+                "tokenizer_config": "EleutherAI/gpt-neox-20b",
+                "flash_attention": False,
+                "sequence_len": 1024,
+                "load_in_8bit": False,
+                "val_set_size": 0.0,
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "gradient_checkpointing": False,
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": None,
+                "save_safetensors": False,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()
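
With the mamba-ssm extra installed on a CUDA machine, the new test can be run on its own with pytest tests/e2e/test_mamba.py; the final assertion checks that save_pretrained produced pytorch_model.bin rather than a safetensors checkpoint.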