Mixtral fixes 20240124 (#1192) [skip ci]
* mixtral nccl fixes
* make sure to patch for z3
- README.md +3 -3
- {deepspeed → deepspeed_configs}/zero1.json +0 -0
- {deepspeed → deepspeed_configs}/zero2.json +0 -0
- {deepspeed → deepspeed_configs}/zero3.json +0 -0
- {deepspeed → deepspeed_configs}/zero3_bf16.json +0 -0
- examples/llama-2/fft_optimized.yml +1 -1
- examples/mistral/Mistral-7b-example/code.ipynb +1 -1
- examples/mistral/Mistral-7b-example/config.yml +1 -1
- examples/mistral/README.md +1 -1
- examples/mistral/mixtral.yml +1 -1
- examples/phi/README.md +1 -1
- src/axolotl/monkeypatch/mixtral/__init__.py +50 -1
- src/axolotl/train.py +1 -1
- src/axolotl/utils/models.py +11 -2
README.md
CHANGED

@@ -861,7 +861,7 @@ tokens:
 fsdp:
 fsdp_config:
 
-# Deepspeed config path. e.g.,
+# Deepspeed config path. e.g., deepspeed_configs/zero3.json
 deepspeed:
 
 # Advanced DDP Arguments

@@ -982,11 +982,11 @@ for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usa
 We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
 
 ```yaml
-deepspeed:
+deepspeed: deepspeed_configs/zero1.json
 ```
 
 ```shell
-accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed
+accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed_configs/zero1.json
 ```
 
 ##### FSDP
{deepspeed → deepspeed_configs}/zero1.json
RENAMED
File without changes

{deepspeed → deepspeed_configs}/zero2.json
RENAMED
File without changes

{deepspeed → deepspeed_configs}/zero3.json
RENAMED
File without changes

{deepspeed → deepspeed_configs}/zero3_bf16.json
RENAMED
File without changes
examples/llama-2/fft_optimized.yml
CHANGED

@@ -62,7 +62,7 @@ evals_per_epoch: 4
 eval_table_size:
 saves_per_epoch: 1
 debug:
-deepspeed: #
+deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
 weight_decay: 0.1
 fsdp:
 fsdp_config:
examples/mistral/Mistral-7b-example/code.ipynb
CHANGED

@@ -942,7 +942,7 @@
 "not only optimizer states but also gradients and parameters across GPUs. The bf16 indicate mixed precision training using bfloat16.\n",
 "For more information read axolotl's readme\n",
 "\"\"\"\n",
-"!accelerate launch -m axolotl.cli.train /folder/config.yml --deepspeed
+"!accelerate launch -m axolotl.cli.train /folder/config.yml --deepspeed deepspeed_configs/zero3_bf16.json"
 ]
 }
 ],
examples/mistral/Mistral-7b-example/config.yml
CHANGED

@@ -65,7 +65,7 @@ eval_table_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 #default deepspeed, can use more aggresive if needed like zero2, zero3
-deepspeed:
+deepspeed: deepspeed_configs/zero1.json
 weight_decay: 0.0
 fsdp:
 fsdp_config:
examples/mistral/README.md
CHANGED

@@ -8,5 +8,5 @@ accelerate launch -m axolotl.cli.train examples/mistral/config.yml
 
 If you run into CUDA OOM, use deepspeed with config zero2.json:
 ```shell
-accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed
+accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed_configs/zero2.json
 ```
examples/mistral/mixtral.yml
CHANGED

@@ -84,7 +84,7 @@ eval_table_size:
 eval_table_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
-deepspeed:
+deepspeed: deepspeed_configs/zero2.json
 weight_decay: 0.0
 fsdp:
 fsdp_config:
examples/phi/README.md
CHANGED

@@ -3,7 +3,7 @@
 Due to some nuances with the phi code, please use deepspeed when training phi for full finetune.
 
 ```shell
-accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed
+accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed_configs/zero1.json
 
 # OR
 
src/axolotl/monkeypatch/mixtral/__init__.py
CHANGED

@@ -1,12 +1,61 @@
 """
 Patches to support multipack for mixtral
 """
+import torch
 import transformers
 
 from axolotl.monkeypatch.utils import get_unpad_data
 
 
-def replace_mixtral_attn_with_multipack_flash_attn():
+def patch_mixtral_moe_forward_zero3() -> None:
+    import torch.nn.functional as F
+
+    def mlp_forward(self, hidden_states):
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
+            hidden_states
+        )
+        current_hidden_states = self.w2(current_hidden_states)
+        return current_hidden_states
+
+    # Ref. https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
+    def moe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        topk_weight, topk_idx = torch.topk(
+            routing_weights, self.top_k, dim=-1, sorted=False
+        )
+        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        topk_weight = topk_weight.to(hidden_states.dtype)
+
+        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+        y = torch.empty_like(hidden_states)  # pylint: disable=invalid-name
+        flat_topk_idx = topk_idx.view(-1)
+        for i in range(self.num_experts):
+            expert = self.experts[i]
+            y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
+        y = (  # pylint: disable=invalid-name
+            y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)
+        ).sum(dim=1)
+        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
+    from transformers.models.mixtral.modeling_mixtral import (
+        MixtralBLockSparseTop2MLP,
+        MixtralSparseMoeBlock,
+    )
+
+    MixtralBLockSparseTop2MLP.forward = mlp_forward
+    MixtralSparseMoeBlock.forward = moe_forward
+
+
+def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
     transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
         get_unpad_data
     )
+    if for_zero3:
+        patch_mixtral_moe_forward_zero3()
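Note on the patch above: the dense loop over `self.experts` in `moe_forward` runs every expert module on every rank each step, which keeps ZeRO-3's per-module gather hooks in sync even when the router sends no tokens to a given expert. A minimal, illustrative sketch of how the patch is meant to be driven (module path taken from this diff; the wiring itself is an assumption, not part of the PR):

```python
# Illustrative only: apply the multipack/flash-attention patch before loading Mixtral,
# enabling the ZeRO-3-friendly dense MoE forward when ZeRO-3 is active.
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

from axolotl.monkeypatch.mixtral import replace_mixtral_attn_with_multipack_flash_attn

# Always install the unpad-data patch for multipack; under ZeRO-3 also swap in the
# dense MoE forward so every expert participates in each forward pass.
replace_mixtral_attn_with_multipack_flash_attn(for_zero3=is_deepspeed_zero3_enabled())
```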
src/axolotl/train.py
CHANGED

@@ -15,7 +15,7 @@ from optimum.bettertransformer import BetterTransformer
 from peft import PeftModel
 from pkg_resources import get_distribution  # type: ignore
 from transformers import PreTrainedModel, PreTrainedTokenizer
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
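On the import move above: recent transformers releases expose `is_deepspeed_zero3_enabled` under `transformers.integrations.deepspeed`, while older releases kept it under `transformers.deepspeed`. If both had to be supported, a hedged compatibility sketch (not part of this PR) would look like:

```python
# Sketch only: tolerate both old and new transformers module layouts.
try:
    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
except ImportError:
    # Older transformers releases exposed this helper under transformers.deepspeed.
    from transformers.deepspeed import is_deepspeed_zero3_enabled
```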
src/axolotl/utils/models.py
CHANGED

@@ -21,7 +21,7 @@ from transformers import ( # noqa: F401
     PreTrainedModel,
     PreTrainedTokenizerBase,
 )
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN

@@ -333,7 +333,10 @@ def load_model(
         )
 
         LOG.info("patching mixtral with flash attention")
-        replace_mixtral_attn_with_multipack_flash_attn()
+        mixtral_patch_kwargs = {}
+        if is_deepspeed_zero3_enabled():
+            mixtral_patch_kwargs["for_zero3"] = True
+        replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
 
     if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
         from axolotl.monkeypatch.falcon import (

@@ -646,6 +649,12 @@ def load_model(
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
     skip_prepare_model_for_kbit_training = False
 
+    if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
+        from deepspeed.utils import set_z3_leaf_modules
+        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+
     if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
         # Qwen doesn't play nicely with LoRA if this is enabled
        skip_prepare_model_for_kbit_training = True
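The `set_z3_leaf_modules` call is what the NCCL fixes hinge on: ZeRO-3 gathers parameters per submodule via forward hooks, and because Mixtral's router activates only a subset of experts, ranks can end up waiting on gathers for experts that were never invoked. Marking `MixtralSparseMoeBlock` as a leaf makes DeepSpeed gather the whole block's parameters at once. A standalone sketch (the checkpoint name is illustrative; `set_z3_leaf_modules` is the real DeepSpeed utility used in this diff):

```python
# Hedged, standalone sketch of the ZeRO-3 leaf-module hint applied after model load.
from deepspeed.utils import set_z3_leaf_modules
from transformers import AutoModelForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

# Treat each sparse-MoE block as a single unit for ZeRO-3 parameter gathering, so
# routing that skips individual experts cannot leave a rank blocked on a collective.
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
```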