Nanobit committed
Commit 669f1d0
1 parent: d4a88e4

Fix: Higher vram usage for mistral and sample_packing (#691)


* Fix: Higher vram usage for mistral and sample_packing

* chore: update comment

* chore: lint

examples/mistral/qlora.yml CHANGED
@@ -36,10 +36,10 @@ lora_target_modules:
   - k_proj
   - o_proj
 
-wandb_project:
-wandb_entity:
+wandb_project:
+wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_run_id:
 wandb_log_model:
 
 gradient_accumulation_steps: 4
@@ -76,4 +76,4 @@ fsdp_config:
 special_tokens:
   bos_token: "<s>"
   eos_token: "</s>"
-  unk_token: "<unk>"
+  unk_token: "<unk>"
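
As a side note, the wandb_* keys in this example are intentionally left blank. A minimal sketch (illustrative only; it assumes PyYAML and is not how axolotl necessarily loads its configs) showing that blank keys parse to None while the quoted special tokens come through verbatim:

# Illustrative only: blank YAML keys parse to None, quoted tokens stay strings.
import yaml  # PyYAML

example = """
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
"""

cfg = yaml.safe_load(example)
assert cfg["wandb_project"] is None          # unset, not an empty string
assert cfg["special_tokens"]["unk_token"] == "<unk>"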
src/axolotl/utils/models.py CHANGED
@@ -81,7 +81,8 @@ def load_tokenizer(cfg):
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-    if cfg.is_mistral_derived_model:
+    # Mistral's official FA implementation requires left padding
+    if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
         tokenizer.padding_side = "left"
 
     if cfg.special_tokens:
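
For readers following along, here is a minimal, self-contained sketch (not part of the commit) of the decision the new condition encodes; resolve_padding_side and the SimpleNamespace stand-in for axolotl's real cfg object are illustrative names only:

from types import SimpleNamespace

def resolve_padding_side(cfg):
    # Mirrors the updated check: force left padding only for Mistral-derived
    # models running flash attention WITHOUT sample packing; packed batches
    # keep the tokenizer's default padding, avoiding the extra VRAM the
    # commit title refers to.
    if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
        return "left"
    return "right"  # most Hugging Face tokenizers default to right padding

# Mistral + flash attention + sample packing: no left padding after this fix.
packed = SimpleNamespace(is_mistral_derived_model=True, flash_attention=True, sample_packing=True)
assert resolve_padding_side(packed) == "right"

# Mistral + flash attention without packing: still padded on the left.
unpacked = SimpleNamespace(is_mistral_derived_model=True, flash_attention=True, sample_packing=False)
assert resolve_padding_side(unpacked) == "left"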