Fix: Higher vram usage for mistral and sample_packing (#691)
* Fix: Higher vram usage for mistral and sample_packing
* chore: update comment
* chore: lint
examples/mistral/qlora.yml CHANGED
@@ -36,10 +36,10 @@ lora_target_modules:
   - k_proj
   - o_proj

-wandb_project:
-wandb_entity:
+wandb_project:
+wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_run_id:
 wandb_log_model:

 gradient_accumulation_steps: 4
@@ -76,4 +76,4 @@ fsdp_config:
 special_tokens:
   bos_token: "<s>"
   eos_token: "</s>"
-  unk_token: "<unk>"
+  unk_token: "<unk>"
src/axolotl/utils/models.py CHANGED
@@ -81,7 +81,8 @@ def load_tokenizer(cfg):
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         os.environ["TOKENIZERS_PARALLELISM"] = "false"

-
+    # Mistral's official FA implementation requires left padding
+    if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
         tokenizer.padding_side = "left"

     if cfg.special_tokens: