Mixtral fixes 20240124 (#1192) [skip ci]
* mixtral nccl fixes
* make sure to patch for z3
- README.md +3 -3
- {deepspeed → deepspeed_configs}/zero1.json +0 -0
- {deepspeed → deepspeed_configs}/zero2.json +0 -0
- {deepspeed → deepspeed_configs}/zero3.json +0 -0
- {deepspeed → deepspeed_configs}/zero3_bf16.json +0 -0
- examples/llama-2/fft_optimized.yml +1 -1
- examples/mistral/Mistral-7b-example/code.ipynb +1 -1
- examples/mistral/Mistral-7b-example/config.yml +1 -1
- examples/mistral/README.md +1 -1
- examples/mistral/mixtral.yml +1 -1
- examples/phi/README.md +1 -1
- src/axolotl/monkeypatch/mixtral/__init__.py +50 -1
- src/axolotl/train.py +1 -1
- src/axolotl/utils/models.py +11 -2
README.md
CHANGED
@@ -861,7 +861,7 @@ tokens:
 fsdp:
 fsdp_config:
 
-# Deepspeed config path. e.g.,
+# Deepspeed config path. e.g., deepspeed_configs/zero3.json
 deepspeed:
 
 # Advanced DDP Arguments
@@ -982,11 +982,11 @@ for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usa
 We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
 
 ```yaml
-deepspeed:
+deepspeed: deepspeed_configs/zero1.json
 ```
 
 ```shell
-accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed
+accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed_configs/zero1.json
 ```
 
 ##### FSDP
{deepspeed → deepspeed_configs}/zero1.json
RENAMED
File without changes

{deepspeed → deepspeed_configs}/zero2.json
RENAMED
File without changes

{deepspeed → deepspeed_configs}/zero3.json
RENAMED
File without changes

{deepspeed → deepspeed_configs}/zero3_bf16.json
RENAMED
File without changes
examples/llama-2/fft_optimized.yml
CHANGED
@@ -62,7 +62,7 @@ evals_per_epoch: 4
 eval_table_size:
 saves_per_epoch: 1
 debug:
-deepspeed: #
+deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
 weight_decay: 0.1
 fsdp:
 fsdp_config:
examples/mistral/Mistral-7b-example/code.ipynb
CHANGED
@@ -942,7 +942,7 @@
 "not only optimizer states but also gradients and parameters across GPUs. The bf16 indicate mixed precision training using bfloat16.\n",
 "For more information read axolotl's readme\n",
 "\"\"\"\n",
-"!accelerate launch -m axolotl.cli.train /folder/config.yml --deepspeed
+"!accelerate launch -m axolotl.cli.train /folder/config.yml --deepspeed deepspeed_configs/zero3_bf16.json"
 ]
 }
 ],
examples/mistral/Mistral-7b-example/config.yml
CHANGED
@@ -65,7 +65,7 @@ eval_table_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 #default deepspeed, can use more aggresive if needed like zero2, zero3
-deepspeed:
+deepspeed: deepspeed_configs/zero1.json
 weight_decay: 0.0
 fsdp:
 fsdp_config:
examples/mistral/README.md
CHANGED
@@ -8,5 +8,5 @@ accelerate launch -m axolotl.cli.train examples/mistral/config.yml
 
 If you run into CUDA OOM, use deepspeed with config zero2.json:
 ```shell
-accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed
+accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed_configs/zero2.json
 ```
examples/mistral/mixtral.yml
CHANGED
@@ -84,7 +84,7 @@ eval_table_size:
 eval_table_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
-deepspeed:
+deepspeed: deepspeed_configs/zero2.json
 weight_decay: 0.0
 fsdp:
 fsdp_config:
examples/phi/README.md
CHANGED
@@ -3,7 +3,7 @@
 Due to some nuances with the phi code, please use deepspeed when training phi for full finetune.
 
 ```shell
-accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed
+accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed_configs/zero1.json
 
 # OR
 
src/axolotl/monkeypatch/mixtral/__init__.py
CHANGED
@@ -1,12 +1,61 @@
 """
 Patches to support multipack for mixtral
 """
+import torch
 import transformers
 
 from axolotl.monkeypatch.utils import get_unpad_data
 
 
-def replace_mixtral_attn_with_multipack_flash_attn():
+def patch_mixtral_moe_forward_zero3() -> None:
+    import torch.nn.functional as F
+
+    def mlp_forward(self, hidden_states):
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
+            hidden_states
+        )
+        current_hidden_states = self.w2(current_hidden_states)
+        return current_hidden_states
+
+    # Ref. https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
+    def moe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        topk_weight, topk_idx = torch.topk(
+            routing_weights, self.top_k, dim=-1, sorted=False
+        )
+        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        topk_weight = topk_weight.to(hidden_states.dtype)
+
+        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+        y = torch.empty_like(hidden_states)  # pylint: disable=invalid-name
+        flat_topk_idx = topk_idx.view(-1)
+        for i in range(self.num_experts):
+            expert = self.experts[i]
+            y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
+        y = (  # pylint: disable=invalid-name
+            y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)
+        ).sum(dim=1)
+        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
+    from transformers.models.mixtral.modeling_mixtral import (
+        MixtralBLockSparseTop2MLP,
+        MixtralSparseMoeBlock,
+    )
+
+    MixtralBLockSparseTop2MLP.forward = mlp_forward
+    MixtralSparseMoeBlock.forward = moe_forward
+
+
+def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
     transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
         get_unpad_data
     )
+    if for_zero3:
+        patch_mixtral_moe_forward_zero3()
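A minimal sketch of how the patched entry point above is meant to be invoked (function and module names are taken from this diff; gating on ZeRO-3 mirrors the call site in `src/axolotl/utils/models.py` further down):

```python
# Sketch only: apply the multipack flash-attention patch for Mixtral, and also
# swap in the ZeRO-3-safe MoE forward when DeepSpeed ZeRO-3 is active.
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

from axolotl.monkeypatch.mixtral import replace_mixtral_attn_with_multipack_flash_attn

# for_zero3=True routes through patch_mixtral_moe_forward_zero3(), whose moe_forward
# loops over every expert each step so parameter gathering stays consistent across ranks.
replace_mixtral_attn_with_multipack_flash_attn(for_zero3=is_deepspeed_zero3_enabled())
```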
src/axolotl/train.py
CHANGED
@@ -15,7 +15,7 @@ from optimum.bettertransformer import BetterTransformer
 from peft import PeftModel
 from pkg_resources import get_distribution  # type: ignore
 from transformers import PreTrainedModel, PreTrainedTokenizer
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
src/axolotl/utils/models.py
CHANGED
@@ -21,7 +21,7 @@ from transformers import ( # noqa: F401
     PreTrainedModel,
     PreTrainedTokenizerBase,
 )
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
@@ -333,7 +333,10 @@ def load_model(
         )
 
         LOG.info("patching mixtral with flash attention")
-        replace_mixtral_attn_with_multipack_flash_attn()
+        mixtral_patch_kwargs = {}
+        if is_deepspeed_zero3_enabled():
+            mixtral_patch_kwargs["for_zero3"] = True
+        replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
 
     if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
         from axolotl.monkeypatch.falcon import (
@@ -646,6 +649,12 @@ def load_model(
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
     skip_prepare_model_for_kbit_training = False
 
+    if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
+        from deepspeed.utils import set_z3_leaf_modules
+        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+
     if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
         # Qwen doesn't play nicely with LoRA if this is enabled
         skip_prepare_model_for_kbit_training = True
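The `set_z3_leaf_modules` call added above comes from DeepSpeed's ZeRO-3 utilities; roughly, it marks `MixtralSparseMoeBlock` as a leaf so ZeRO-3 gathers the block's expert parameters as a unit instead of hooking each expert separately, which otherwise can stall collectives when routing skips some experts on a rank (the "nccl fixes" in the title). A standalone sketch outside Axolotl, assuming a DeepSpeed release that provides `set_z3_leaf_modules` and a ZeRO-3 launch; the checkpoint id is illustrative only:

```python
# Sketch only: register Mixtral's sparse MoE block as a ZeRO-3 leaf module.
from deepspeed.utils import set_z3_leaf_modules
from transformers import AutoModelForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

# Stop ZeRO-3 from partition-hooking inside each sparse MoE block; gather the
# block's parameters at the block level instead.
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
```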