support for mamba (#915)
Browse files* support for mamba
* more mamba fixes
* use fork for mamba kwargs fix
* grad checkpointing doesn't work
* fix extras for mamaba
* mamba loss fix
* use fp32 and remove verbose logging
* mamba fixes
* fix collator for mamba
* set model_type on training_args
* don't save safetensors for mamba
* update mamba config to disable safetensor checkpooints, install for tests
* no evals for mamba tests
* handle save_pretrained
* handle unused safetensors arg
- .github/workflows/tests.yml +1 -1
- examples/mamba/config.yml +61 -0
- setup.py +3 -0
- src/axolotl/core/trainer_builder.py +48 -7
- src/axolotl/models/mamba/__init__.py +12 -0
- src/axolotl/models/mamba/configuration_mamba.py +42 -0
- src/axolotl/models/mamba/modeling_mamba.py +128 -0
- src/axolotl/train.py +4 -2
- src/axolotl/utils/collators.py +33 -1
- src/axolotl/utils/models.py +42 -9
- src/axolotl/utils/trainer.py +8 -4
- tests/e2e/test_mamba.py +65 -0
.github/workflows/tests.yml
CHANGED
@@ -73,7 +73,7 @@ jobs:
|
|
73 |
run: |
|
74 |
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
|
75 |
pip3 uninstall -y transformers accelerate
|
76 |
-
pip3 install -U -e .[flash-attn]
|
77 |
pip3 install -r requirements-tests.txt
|
78 |
|
79 |
- name: Run e2e tests
|
|
|
73 |
run: |
|
74 |
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
|
75 |
pip3 uninstall -y transformers accelerate
|
76 |
+
pip3 install -U -e .[flash-attn,mamba-ssm]
|
77 |
pip3 install -r requirements-tests.txt
|
78 |
|
79 |
- name: Run e2e tests
|
examples/mamba/config.yml
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_model: state-spaces/mamba-2.8b
|
2 |
+
model_type: MambaLMHeadModel
|
3 |
+
tokenizer_type: AutoTokenizer
|
4 |
+
tokenizer_config: EleutherAI/gpt-neox-20b
|
5 |
+
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: false
|
8 |
+
strict: false
|
9 |
+
|
10 |
+
datasets:
|
11 |
+
- path: mhenrichsen/alpaca_2k_test
|
12 |
+
type: alpaca
|
13 |
+
dataset_prepared_path:
|
14 |
+
val_set_size: 0.0
|
15 |
+
output_dir: ./out
|
16 |
+
|
17 |
+
sequence_len: 2048
|
18 |
+
sample_packing: false
|
19 |
+
pad_to_sequence_len: false
|
20 |
+
|
21 |
+
wandb_project:
|
22 |
+
wandb_entity:
|
23 |
+
wandb_watch:
|
24 |
+
wandb_name:
|
25 |
+
wandb_log_model:
|
26 |
+
|
27 |
+
gradient_accumulation_steps: 4
|
28 |
+
micro_batch_size: 1
|
29 |
+
num_epochs: 2
|
30 |
+
optimizer: paged_adamw_8bit
|
31 |
+
lr_scheduler: cosine
|
32 |
+
learning_rate: 5e-5
|
33 |
+
|
34 |
+
train_on_inputs: false
|
35 |
+
group_by_length: true
|
36 |
+
|
37 |
+
bf16: true
|
38 |
+
fp16: false
|
39 |
+
tf32: true
|
40 |
+
|
41 |
+
gradient_checkpointing: false
|
42 |
+
early_stopping_patience:
|
43 |
+
resume_from_checkpoint:
|
44 |
+
local_rank:
|
45 |
+
logging_steps: 1
|
46 |
+
xformers_attention:
|
47 |
+
flash_attention:
|
48 |
+
|
49 |
+
warmup_steps: 10
|
50 |
+
eval_steps:
|
51 |
+
eval_table_size:
|
52 |
+
eval_table_max_new_tokens: 128
|
53 |
+
save_steps: 0.25
|
54 |
+
debug:
|
55 |
+
deepspeed:
|
56 |
+
weight_decay: 0.0
|
57 |
+
fsdp:
|
58 |
+
fsdp_config:
|
59 |
+
special_tokens:
|
60 |
+
tokens:
|
61 |
+
save_safetensors: False
|
setup.py
CHANGED
@@ -51,5 +51,8 @@ setup(
|
|
51 |
"deepspeed": [
|
52 |
"deepspeed",
|
53 |
],
|
|
|
|
|
|
|
54 |
},
|
55 |
)
|
|
|
51 |
"deepspeed": [
|
52 |
"deepspeed",
|
53 |
],
|
54 |
+
"mamba-ssm": [
|
55 |
+
"mamba-ssm==1.0.1",
|
56 |
+
],
|
57 |
},
|
58 |
)
|
src/axolotl/core/trainer_builder.py
CHANGED
@@ -31,7 +31,10 @@ from axolotl.utils.callbacks import (
|
|
31 |
bench_eval_callback_factory,
|
32 |
log_prediction_callback_factory,
|
33 |
)
|
34 |
-
from axolotl.utils.collators import
|
|
|
|
|
|
|
35 |
from axolotl.utils.samplers import MultipackBatchSampler
|
36 |
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
37 |
|
@@ -49,6 +52,9 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|
49 |
Extend the base TrainingArguments for axolotl helpers
|
50 |
"""
|
51 |
|
|
|
|
|
|
|
52 |
lr_quadratic_warmup: bool = field(
|
53 |
default=False,
|
54 |
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
@@ -285,6 +291,32 @@ class AxolotlTrainer(Trainer):
|
|
285 |
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
286 |
|
287 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
289 |
"""
|
290 |
Trainer subclass that uses the OneCycleLR scheduler
|
@@ -462,6 +494,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
462 |
return OneCycleLRSchedulerTrainer
|
463 |
if self.cfg.relora_steps:
|
464 |
return ReLoRATrainer
|
|
|
|
|
465 |
return AxolotlTrainer
|
466 |
|
467 |
def build(self, total_num_steps):
|
@@ -529,7 +563,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
529 |
if self.cfg.hub_strategy:
|
530 |
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
531 |
|
532 |
-
if self.cfg.save_safetensors:
|
533 |
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
534 |
|
535 |
if self.cfg.sample_packing_eff_est:
|
@@ -677,6 +711,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
677 |
training_arguments_kwargs = self.hook_pre_create_training_args(
|
678 |
training_arguments_kwargs
|
679 |
)
|
|
|
680 |
training_args = (
|
681 |
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
682 |
**training_arguments_kwargs,
|
@@ -731,11 +766,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
731 |
train_dataset=self.train_dataset,
|
732 |
eval_dataset=self.eval_dataset,
|
733 |
args=training_args,
|
734 |
-
data_collator=
|
735 |
-
self.tokenizer,
|
736 |
-
return_tensors="pt",
|
737 |
-
**data_collator_kwargs,
|
738 |
-
),
|
739 |
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
740 |
self.tokenizer,
|
741 |
return_tensors="pt",
|
@@ -755,3 +786,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
755 |
] = self.cfg.micro_batch_size
|
756 |
|
757 |
return trainer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
bench_eval_callback_factory,
|
32 |
log_prediction_callback_factory,
|
33 |
)
|
34 |
+
from axolotl.utils.collators import (
|
35 |
+
BatchSamplerDataCollatorForSeq2Seq,
|
36 |
+
MambaDataCollator,
|
37 |
+
)
|
38 |
from axolotl.utils.samplers import MultipackBatchSampler
|
39 |
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
40 |
|
|
|
52 |
Extend the base TrainingArguments for axolotl helpers
|
53 |
"""
|
54 |
|
55 |
+
model_type: Optional[str] = field(
|
56 |
+
default=None, metadata={"help": "HF model configuration model_type."}
|
57 |
+
)
|
58 |
lr_quadratic_warmup: bool = field(
|
59 |
default=False,
|
60 |
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
|
|
291 |
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
292 |
|
293 |
|
294 |
+
class AxolotlMambaTrainer(AxolotlTrainer):
|
295 |
+
"""
|
296 |
+
Mamba specific trainer to handle loss calculation
|
297 |
+
"""
|
298 |
+
|
299 |
+
def compute_loss(
|
300 |
+
self,
|
301 |
+
model,
|
302 |
+
inputs,
|
303 |
+
return_outputs=False, # pylint: disable=unused-argument
|
304 |
+
):
|
305 |
+
input_ids = inputs.pop("input_ids")
|
306 |
+
lm_logits = model(input_ids).logits
|
307 |
+
|
308 |
+
labels = input_ids.to(lm_logits.device)
|
309 |
+
shift_logits = lm_logits[:, :-1, :].contiguous()
|
310 |
+
labels = labels[:, 1:].contiguous()
|
311 |
+
|
312 |
+
loss_fct = torch.nn.CrossEntropyLoss()
|
313 |
+
lm_loss = loss_fct(
|
314 |
+
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
315 |
+
)
|
316 |
+
|
317 |
+
return lm_loss
|
318 |
+
|
319 |
+
|
320 |
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
321 |
"""
|
322 |
Trainer subclass that uses the OneCycleLR scheduler
|
|
|
494 |
return OneCycleLRSchedulerTrainer
|
495 |
if self.cfg.relora_steps:
|
496 |
return ReLoRATrainer
|
497 |
+
if self.cfg.model_config_type == "mamba":
|
498 |
+
return AxolotlMambaTrainer
|
499 |
return AxolotlTrainer
|
500 |
|
501 |
def build(self, total_num_steps):
|
|
|
563 |
if self.cfg.hub_strategy:
|
564 |
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
565 |
|
566 |
+
if self.cfg.save_safetensors is not None:
|
567 |
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
568 |
|
569 |
if self.cfg.sample_packing_eff_est:
|
|
|
711 |
training_arguments_kwargs = self.hook_pre_create_training_args(
|
712 |
training_arguments_kwargs
|
713 |
)
|
714 |
+
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
715 |
training_args = (
|
716 |
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
717 |
**training_arguments_kwargs,
|
|
|
766 |
train_dataset=self.train_dataset,
|
767 |
eval_dataset=self.eval_dataset,
|
768 |
args=training_args,
|
769 |
+
data_collator=self.build_collator(**data_collator_kwargs),
|
|
|
|
|
|
|
|
|
770 |
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
771 |
self.tokenizer,
|
772 |
return_tensors="pt",
|
|
|
786 |
] = self.cfg.micro_batch_size
|
787 |
|
788 |
return trainer
|
789 |
+
|
790 |
+
def build_collator(self, **kwargs):
|
791 |
+
if self.cfg.model_config_type == "mamba":
|
792 |
+
return MambaDataCollator(tokenizer=self.tokenizer)
|
793 |
+
|
794 |
+
return BatchSamplerDataCollatorForSeq2Seq(
|
795 |
+
self.tokenizer,
|
796 |
+
return_tensors="pt",
|
797 |
+
**kwargs,
|
798 |
+
)
|
src/axolotl/models/mamba/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Modeling module for Mamba models
|
3 |
+
"""
|
4 |
+
|
5 |
+
|
6 |
+
def fix_mamba_attn_for_loss():
|
7 |
+
from mamba_ssm.models import mixer_seq_simple
|
8 |
+
|
9 |
+
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
|
10 |
+
|
11 |
+
mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
|
12 |
+
return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name
|
src/axolotl/models/mamba/configuration_mamba.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
HF Transformers MambaConfig
|
3 |
+
"""
|
4 |
+
from transformers import PretrainedConfig
|
5 |
+
|
6 |
+
|
7 |
+
class MambaConfig(PretrainedConfig):
|
8 |
+
"""
|
9 |
+
modeling configuration for state space model/mamba
|
10 |
+
"""
|
11 |
+
|
12 |
+
model_type = "mamba"
|
13 |
+
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
vocab_size=50280,
|
17 |
+
d_model=2560,
|
18 |
+
n_layer=64,
|
19 |
+
rms_norm=True,
|
20 |
+
residual_in_fp32=True,
|
21 |
+
fused_add_norm=True,
|
22 |
+
pad_vocab_size_multiple=8,
|
23 |
+
pad_token_id=50277,
|
24 |
+
bos_token_id=0,
|
25 |
+
eos_token_id=0,
|
26 |
+
tie_word_embeddings=False,
|
27 |
+
**kwargs,
|
28 |
+
):
|
29 |
+
self.vocab_size = vocab_size
|
30 |
+
self.d_model = d_model
|
31 |
+
self.n_layer = n_layer
|
32 |
+
self.rms_norm = rms_norm
|
33 |
+
self.residual_in_fp32 = residual_in_fp32
|
34 |
+
self.fused_add_norm = fused_add_norm
|
35 |
+
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
36 |
+
super().__init__(
|
37 |
+
pad_token_id=pad_token_id,
|
38 |
+
bos_token_id=bos_token_id,
|
39 |
+
eos_token_id=eos_token_id,
|
40 |
+
tie_word_embeddings=tie_word_embeddings,
|
41 |
+
**kwargs,
|
42 |
+
)
|
src/axolotl/models/mamba/modeling_mamba.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pylint: skip-file
|
2 |
+
import os
|
3 |
+
from collections import namedtuple
|
4 |
+
from functools import partial
|
5 |
+
from typing import Optional, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
|
9 |
+
from mamba_ssm.utils.generation import GenerationMixin
|
10 |
+
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
11 |
+
from torch import nn
|
12 |
+
from torch.nn import CrossEntropyLoss
|
13 |
+
|
14 |
+
from axolotl.models.mamba.configuration_mamba import MambaConfig
|
15 |
+
|
16 |
+
|
17 |
+
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
d_model: int,
|
21 |
+
n_layer: int,
|
22 |
+
vocab_size: int,
|
23 |
+
initializer_cfg=None,
|
24 |
+
pad_vocab_size_multiple: int = 1,
|
25 |
+
device=None,
|
26 |
+
dtype=None,
|
27 |
+
**backbone_kwargs,
|
28 |
+
) -> None:
|
29 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
30 |
+
super().__init__()
|
31 |
+
if vocab_size % pad_vocab_size_multiple != 0:
|
32 |
+
vocab_size += pad_vocab_size_multiple - (
|
33 |
+
vocab_size % pad_vocab_size_multiple
|
34 |
+
)
|
35 |
+
self.config = MambaConfig(
|
36 |
+
vocab_size=vocab_size,
|
37 |
+
d_model=d_model,
|
38 |
+
n_layer=n_layer,
|
39 |
+
pad_vocab_size_multiple=pad_vocab_size_multiple,
|
40 |
+
)
|
41 |
+
self.backbone = MixerModel(
|
42 |
+
d_model=d_model,
|
43 |
+
n_layer=n_layer,
|
44 |
+
vocab_size=vocab_size,
|
45 |
+
initializer_cfg=initializer_cfg,
|
46 |
+
**backbone_kwargs,
|
47 |
+
**factory_kwargs,
|
48 |
+
)
|
49 |
+
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
50 |
+
|
51 |
+
# Initialize weights and apply final processing
|
52 |
+
self.apply(
|
53 |
+
partial(
|
54 |
+
_init_weights,
|
55 |
+
n_layer=n_layer,
|
56 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
57 |
+
)
|
58 |
+
)
|
59 |
+
self.tie_weights()
|
60 |
+
|
61 |
+
def tie_weights(self):
|
62 |
+
self.lm_head.weight = self.backbone.embedding.weight
|
63 |
+
|
64 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
65 |
+
return self.backbone.allocate_inference_cache(
|
66 |
+
batch_size, max_seqlen, dtype=dtype, **kwargs
|
67 |
+
)
|
68 |
+
|
69 |
+
def forward(
|
70 |
+
self,
|
71 |
+
input_ids,
|
72 |
+
position_ids=None,
|
73 |
+
inference_params=None,
|
74 |
+
num_last_tokens=0,
|
75 |
+
labels=None,
|
76 |
+
**kwargs,
|
77 |
+
):
|
78 |
+
"""
|
79 |
+
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
80 |
+
num_last_tokens: if > 0, only return the logits for the last n tokens
|
81 |
+
"""
|
82 |
+
hidden_states = self.backbone(input_ids, inference_params=inference_params)
|
83 |
+
if num_last_tokens > 0:
|
84 |
+
hidden_states = hidden_states[:, -num_last_tokens:]
|
85 |
+
lm_logits = self.lm_head(hidden_states)
|
86 |
+
|
87 |
+
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
88 |
+
return CausalLMOutput(logits=lm_logits)
|
89 |
+
|
90 |
+
loss = None
|
91 |
+
if labels is not None:
|
92 |
+
logits = lm_logits
|
93 |
+
# Shift so that tokens < n predict n
|
94 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
95 |
+
shift_labels = labels[..., 1:].contiguous()
|
96 |
+
# Flatten the tokens
|
97 |
+
loss_fct = CrossEntropyLoss()
|
98 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
99 |
+
shift_labels = shift_labels.view(-1)
|
100 |
+
# Enable model parallelism
|
101 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
102 |
+
loss = loss_fct(shift_logits, shift_labels)
|
103 |
+
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
|
104 |
+
print(loss)
|
105 |
+
return CausalLMOutput(logits=lm_logits, loss=loss)
|
106 |
+
|
107 |
+
else:
|
108 |
+
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
109 |
+
return CausalLMOutput(logits=lm_logits)
|
110 |
+
|
111 |
+
def save_pretrained(
|
112 |
+
self,
|
113 |
+
save_directory: Union[str, os.PathLike],
|
114 |
+
state_dict: Optional[dict] = None,
|
115 |
+
safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument
|
116 |
+
):
|
117 |
+
if state_dict is None:
|
118 |
+
state_dict = self.state_dict()
|
119 |
+
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
|
120 |
+
|
121 |
+
@classmethod
|
122 |
+
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
123 |
+
config = load_config_hf(pretrained_model_name)
|
124 |
+
model = cls(**config, device=device, dtype=dtype, **kwargs)
|
125 |
+
model.load_state_dict(
|
126 |
+
load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
|
127 |
+
)
|
128 |
+
return model
|
src/axolotl/train.py
CHANGED
@@ -82,7 +82,8 @@ def train(
|
|
82 |
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
83 |
)
|
84 |
|
85 |
-
model
|
|
|
86 |
|
87 |
# go ahead and presave, so we have the adapter config available to inspect
|
88 |
if peft_config:
|
@@ -92,7 +93,8 @@ def train(
|
|
92 |
if not Path(cfg.output_dir).is_dir():
|
93 |
os.makedirs(cfg.output_dir, exist_ok=True)
|
94 |
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
|
95 |
-
model
|
|
|
96 |
|
97 |
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
98 |
if cfg.local_rank == 0:
|
|
|
82 |
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
83 |
)
|
84 |
|
85 |
+
if hasattr(model, "config"):
|
86 |
+
model.config.use_cache = False
|
87 |
|
88 |
# go ahead and presave, so we have the adapter config available to inspect
|
89 |
if peft_config:
|
|
|
93 |
if not Path(cfg.output_dir).is_dir():
|
94 |
os.makedirs(cfg.output_dir, exist_ok=True)
|
95 |
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
|
96 |
+
if hasattr(model, "config"):
|
97 |
+
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
98 |
|
99 |
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
100 |
if cfg.local_rank == 0:
|
src/axolotl/utils/collators.py
CHANGED
@@ -2,12 +2,16 @@
|
|
2 |
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
3 |
"""
|
4 |
from dataclasses import dataclass
|
5 |
-
from typing import Any, Optional, Union
|
6 |
|
7 |
import numpy as np
|
|
|
|
|
8 |
from transformers import PreTrainedTokenizerBase
|
9 |
from transformers.utils import PaddingStrategy
|
10 |
|
|
|
|
|
11 |
|
12 |
@dataclass
|
13 |
class DataCollatorForSeq2Seq:
|
@@ -146,3 +150,31 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|
146 |
chunked_data[feature] = np.concatenate(arrays)
|
147 |
features = [chunked_data]
|
148 |
return super().__call__(features, return_tensors=return_tensors)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
3 |
"""
|
4 |
from dataclasses import dataclass
|
5 |
+
from typing import Any, Dict, Optional, Sequence, Union
|
6 |
|
7 |
import numpy as np
|
8 |
+
import torch
|
9 |
+
import transformers
|
10 |
from transformers import PreTrainedTokenizerBase
|
11 |
from transformers.utils import PaddingStrategy
|
12 |
|
13 |
+
IGNORE_INDEX = -100
|
14 |
+
|
15 |
|
16 |
@dataclass
|
17 |
class DataCollatorForSeq2Seq:
|
|
|
150 |
chunked_data[feature] = np.concatenate(arrays)
|
151 |
features = [chunked_data]
|
152 |
return super().__call__(features, return_tensors=return_tensors)
|
153 |
+
|
154 |
+
|
155 |
+
@dataclass
|
156 |
+
class MambaDataCollator:
|
157 |
+
"""
|
158 |
+
Collator for State Space Models (Mamba)
|
159 |
+
"""
|
160 |
+
|
161 |
+
tokenizer: transformers.PreTrainedTokenizer
|
162 |
+
|
163 |
+
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
164 |
+
input_ids, labels = tuple(
|
165 |
+
[torch.LongTensor(instance[key]) for instance in instances]
|
166 |
+
for key in ("input_ids", "labels")
|
167 |
+
)
|
168 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(
|
169 |
+
input_ids,
|
170 |
+
batch_first=True,
|
171 |
+
padding_value=self.tokenizer.pad_token_id,
|
172 |
+
)
|
173 |
+
labels = torch.nn.utils.rnn.pad_sequence(
|
174 |
+
labels, batch_first=True, padding_value=IGNORE_INDEX
|
175 |
+
)
|
176 |
+
|
177 |
+
return {
|
178 |
+
"input_ids": input_ids,
|
179 |
+
"labels": labels,
|
180 |
+
}
|
src/axolotl/utils/models.py
CHANGED
@@ -4,6 +4,7 @@ import math
|
|
4 |
import os
|
5 |
from typing import Optional, Tuple # noqa: F401
|
6 |
|
|
|
7 |
import bitsandbytes as bnb
|
8 |
import torch
|
9 |
import transformers
|
@@ -21,6 +22,7 @@ from transformers import ( # noqa: F401
|
|
21 |
PreTrainedTokenizerBase,
|
22 |
)
|
23 |
|
|
|
24 |
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
25 |
from axolotl.utils.bench import log_gpu_memory_usage
|
26 |
from axolotl.utils.dict import DictDefault
|
@@ -52,9 +54,19 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
|
|
52 |
def load_model_config(cfg):
|
53 |
model_config_name = cfg.base_model_config or cfg.base_model
|
54 |
trust_remote_code = cfg.trust_remote_code is True
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
if cfg.model_config:
|
59 |
for key, val in cfg.model_config.items():
|
60 |
setattr(model_config, key, val)
|
@@ -351,6 +363,20 @@ def load_model(
|
|
351 |
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
352 |
**model_kwargs,
|
353 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
354 |
elif model_type and not cfg.trust_remote_code:
|
355 |
if cfg.gptq:
|
356 |
model = AutoModelForCausalLM.from_pretrained(
|
@@ -410,13 +436,17 @@ def load_model(
|
|
410 |
if cfg.resize_token_embeddings_to_32x
|
411 |
else len(tokenizer)
|
412 |
)
|
413 |
-
if
|
|
|
|
|
|
|
414 |
model.resize_token_embeddings(embeddings_len)
|
415 |
else:
|
416 |
model.tie_weights()
|
417 |
|
418 |
if (
|
419 |
-
hasattr(model
|
|
|
420 |
and model.config.max_position_embeddings
|
421 |
and cfg.sequence_len > model.config.max_position_embeddings
|
422 |
):
|
@@ -426,20 +456,22 @@ def load_model(
|
|
426 |
model.config.max_position_embeddings = cfg.sequence_len
|
427 |
|
428 |
if (
|
429 |
-
hasattr(model
|
|
|
430 |
and model.config.bos_token_id
|
431 |
and model.config.bos_token_id != tokenizer.bos_token_id
|
432 |
):
|
433 |
model.config.bos_token_id = tokenizer.bos_token_id
|
434 |
|
435 |
if (
|
436 |
-
hasattr(model
|
|
|
437 |
and model.config.eos_token_id
|
438 |
and model.config.eos_token_id != tokenizer.eos_token_id
|
439 |
):
|
440 |
model.config.eos_token_id = tokenizer.eos_token_id
|
441 |
|
442 |
-
if model.device.type == "cuda":
|
443 |
log_gpu_memory_usage(LOG, "after model load", model.device)
|
444 |
|
445 |
# make sure these are fp32 per Ramesh et al. (2021)
|
@@ -498,7 +530,8 @@ def load_model(
|
|
498 |
requires_grad.append(f"{name}: {param.requires_grad}")
|
499 |
if len(requires_grad) == 0:
|
500 |
LOG.warning("there are no parameters that require gradient updates")
|
501 |
-
model
|
|
|
502 |
|
503 |
if cfg.flash_optimum:
|
504 |
model = BetterTransformer.transform(model)
|
|
|
4 |
import os
|
5 |
from typing import Optional, Tuple # noqa: F401
|
6 |
|
7 |
+
import addict
|
8 |
import bitsandbytes as bnb
|
9 |
import torch
|
10 |
import transformers
|
|
|
22 |
PreTrainedTokenizerBase,
|
23 |
)
|
24 |
|
25 |
+
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
26 |
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
27 |
from axolotl.utils.bench import log_gpu_memory_usage
|
28 |
from axolotl.utils.dict import DictDefault
|
|
|
54 |
def load_model_config(cfg):
|
55 |
model_config_name = cfg.base_model_config or cfg.base_model
|
56 |
trust_remote_code = cfg.trust_remote_code is True
|
57 |
+
try:
|
58 |
+
model_config = AutoConfig.from_pretrained(
|
59 |
+
model_config_name, trust_remote_code=trust_remote_code
|
60 |
+
)
|
61 |
+
except ValueError as err:
|
62 |
+
if "mamba" in model_config_name:
|
63 |
+
return addict.Dict(
|
64 |
+
{
|
65 |
+
"model_type": "mamba",
|
66 |
+
}
|
67 |
+
)
|
68 |
+
raise err
|
69 |
+
|
70 |
if cfg.model_config:
|
71 |
for key, val in cfg.model_config.items():
|
72 |
setattr(model_config, key, val)
|
|
|
363 |
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
364 |
**model_kwargs,
|
365 |
)
|
366 |
+
elif model_type == "MambaLMHeadModel":
|
367 |
+
# FIXME this is janky at best and hacked together to make it work
|
368 |
+
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
369 |
+
|
370 |
+
model_kwargs["dtype"] = model_kwargs["torch_dtype"]
|
371 |
+
model_kwargs["device"] = torch.cuda.current_device()
|
372 |
+
del model_kwargs["torch_dtype"]
|
373 |
+
del model_kwargs["device_map"]
|
374 |
+
del model_kwargs["max_memory"]
|
375 |
+
|
376 |
+
model = MambaLMHeadModel.from_pretrained(
|
377 |
+
base_model,
|
378 |
+
**model_kwargs,
|
379 |
+
)
|
380 |
elif model_type and not cfg.trust_remote_code:
|
381 |
if cfg.gptq:
|
382 |
model = AutoModelForCausalLM.from_pretrained(
|
|
|
436 |
if cfg.resize_token_embeddings_to_32x
|
437 |
else len(tokenizer)
|
438 |
)
|
439 |
+
if (
|
440 |
+
hasattr(model, "get_input_embeddings")
|
441 |
+
and model.get_input_embeddings().num_embeddings < embeddings_len
|
442 |
+
):
|
443 |
model.resize_token_embeddings(embeddings_len)
|
444 |
else:
|
445 |
model.tie_weights()
|
446 |
|
447 |
if (
|
448 |
+
hasattr(model, "config")
|
449 |
+
and hasattr(model.config, "max_position_embeddings")
|
450 |
and model.config.max_position_embeddings
|
451 |
and cfg.sequence_len > model.config.max_position_embeddings
|
452 |
):
|
|
|
456 |
model.config.max_position_embeddings = cfg.sequence_len
|
457 |
|
458 |
if (
|
459 |
+
hasattr(model, "config")
|
460 |
+
and hasattr(model.config, "bos_token_id")
|
461 |
and model.config.bos_token_id
|
462 |
and model.config.bos_token_id != tokenizer.bos_token_id
|
463 |
):
|
464 |
model.config.bos_token_id = tokenizer.bos_token_id
|
465 |
|
466 |
if (
|
467 |
+
hasattr(model, "config")
|
468 |
+
and hasattr(model.config, "eos_token_id")
|
469 |
and model.config.eos_token_id
|
470 |
and model.config.eos_token_id != tokenizer.eos_token_id
|
471 |
):
|
472 |
model.config.eos_token_id = tokenizer.eos_token_id
|
473 |
|
474 |
+
if hasattr(model, "device") and model.device.type == "cuda":
|
475 |
log_gpu_memory_usage(LOG, "after model load", model.device)
|
476 |
|
477 |
# make sure these are fp32 per Ramesh et al. (2021)
|
|
|
530 |
requires_grad.append(f"{name}: {param.requires_grad}")
|
531 |
if len(requires_grad) == 0:
|
532 |
LOG.warning("there are no parameters that require gradient updates")
|
533 |
+
if hasattr(model, "config"):
|
534 |
+
model.config.use_cache = False
|
535 |
|
536 |
if cfg.flash_optimum:
|
537 |
model = BetterTransformer.transform(model)
|
src/axolotl/utils/trainer.py
CHANGED
@@ -131,8 +131,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
|
131 |
)
|
132 |
|
133 |
# Phi doesn't want the attention_mask feature when training
|
134 |
-
if
|
135 |
-
|
|
|
|
|
136 |
):
|
137 |
train_dataset = train_dataset.remove_columns("attention_mask")
|
138 |
if eval_dataset:
|
@@ -153,7 +155,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|
153 |
if update:
|
154 |
cfg.total_num_tokens = total_num_tokens
|
155 |
|
156 |
-
|
|
|
|
|
157 |
total_supervised_tokens = (
|
158 |
train_dataset.data.column("labels")
|
159 |
.to_pandas()
|
@@ -167,7 +171,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|
167 |
if update:
|
168 |
cfg.total_supervised_tokens = total_supervised_tokens
|
169 |
|
170 |
-
if cfg.sample_packing:
|
171 |
# we have to drop anything longer then sequence len otherwise
|
172 |
# flash attention with position ids fails
|
173 |
|
|
|
131 |
)
|
132 |
|
133 |
# Phi doesn't want the attention_mask feature when training
|
134 |
+
if (
|
135 |
+
"CodeGenTokenizer" in tokenizer.__class__.__name__
|
136 |
+
or (cfg.is_mistral_derived_model and cfg.flash_attention)
|
137 |
+
or cfg.model_config_type == "mamba"
|
138 |
):
|
139 |
train_dataset = train_dataset.remove_columns("attention_mask")
|
140 |
if eval_dataset:
|
|
|
155 |
if update:
|
156 |
cfg.total_num_tokens = total_num_tokens
|
157 |
|
158 |
+
skip_estimates = cfg.model_config_type == "mamba"
|
159 |
+
|
160 |
+
if not skip_estimates and not cfg.total_supervised_tokens:
|
161 |
total_supervised_tokens = (
|
162 |
train_dataset.data.column("labels")
|
163 |
.to_pandas()
|
|
|
171 |
if update:
|
172 |
cfg.total_supervised_tokens = total_supervised_tokens
|
173 |
|
174 |
+
if not skip_estimates and cfg.sample_packing:
|
175 |
# we have to drop anything longer then sequence len otherwise
|
176 |
# flash attention with position ids fails
|
177 |
|
tests/e2e/test_mamba.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
E2E tests for lora llama
|
3 |
+
"""
|
4 |
+
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import unittest
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
from axolotl.cli import load_datasets
|
11 |
+
from axolotl.common.cli import TrainerCliArgs
|
12 |
+
from axolotl.train import train
|
13 |
+
from axolotl.utils.config import normalize_config
|
14 |
+
from axolotl.utils.dict import DictDefault
|
15 |
+
|
16 |
+
from .utils import with_temp_dir
|
17 |
+
|
18 |
+
LOG = logging.getLogger("axolotl.tests.e2e")
|
19 |
+
os.environ["WANDB_DISABLED"] = "true"
|
20 |
+
|
21 |
+
|
22 |
+
class TestMistral(unittest.TestCase):
|
23 |
+
"""
|
24 |
+
Test case for Llama models using LoRA
|
25 |
+
"""
|
26 |
+
|
27 |
+
@with_temp_dir
|
28 |
+
def test_fft(self, temp_dir):
|
29 |
+
# pylint: disable=duplicate-code
|
30 |
+
cfg = DictDefault(
|
31 |
+
{
|
32 |
+
"base_model": "state-spaces/mamba-130m",
|
33 |
+
"model_type": "MambaLMHeadModel",
|
34 |
+
"tokenizer_type": "AutoTokenizer",
|
35 |
+
"tokenizer_config": "EleutherAI/gpt-neox-20b",
|
36 |
+
"flash_attention": False,
|
37 |
+
"sequence_len": 1024,
|
38 |
+
"load_in_8bit": False,
|
39 |
+
"val_set_size": 0.0,
|
40 |
+
"datasets": [
|
41 |
+
{
|
42 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
43 |
+
"type": "alpaca",
|
44 |
+
},
|
45 |
+
],
|
46 |
+
"gradient_checkpointing": False,
|
47 |
+
"num_epochs": 2,
|
48 |
+
"micro_batch_size": 2,
|
49 |
+
"gradient_accumulation_steps": 1,
|
50 |
+
"output_dir": temp_dir,
|
51 |
+
"learning_rate": 0.00001,
|
52 |
+
"optimizer": "adamw_torch",
|
53 |
+
"lr_scheduler": "cosine",
|
54 |
+
"max_steps": 20,
|
55 |
+
"save_steps": 10,
|
56 |
+
"eval_steps": None,
|
57 |
+
"save_safetensors": False,
|
58 |
+
}
|
59 |
+
)
|
60 |
+
normalize_config(cfg)
|
61 |
+
cli_args = TrainerCliArgs()
|
62 |
+
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
63 |
+
|
64 |
+
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
65 |
+
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|