fixes for alpaca w chatml, and don't include attention_mask w mistral for flash attention (#728)
src/axolotl/prompt_strategies/alpaca_chat.py
CHANGED
@@ -1,6 +1,6 @@
-"""Module
+"""Module for Alpaca prompt strategy classes"""
 
-from typing import Tuple
+from typing import Any, Dict, Optional, Tuple
 
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
@@ -9,9 +9,13 @@ from axolotl.prompt_tokenizers import (
 from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
 
 
-def load(tokenizer, cfg):
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    prompt_style = PromptStyle.CHAT.value
+    if ds_cfg and "conversation" in ds_cfg:
+        prompt_style = ds_cfg["conversation"]
+
     return AlpacaPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.CHAT.value),
+        AlpacaPrompter(prompt_style),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
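The new optional ds_cfg argument lets a dataset entry choose the conversation/prompt style (for example chatml, per the PR title) instead of always using PromptStyle.CHAT.value, and the resolved style is then passed straight to AlpacaPrompter(prompt_style). Below is a minimal, self-contained sketch of that resolution logic; resolve_prompt_style, the "chat" default string, and the {"conversation": "chatml"} example are illustrative stand-ins, not code from this PR:

from typing import Any, Dict, Optional


def resolve_prompt_style(ds_cfg: Optional[Dict[str, Any]] = None) -> str:
    # Default mirrors PromptStyle.CHAT.value in the diff; "chat" is assumed here.
    prompt_style = "chat"
    if ds_cfg and "conversation" in ds_cfg:
        prompt_style = ds_cfg["conversation"]
    return prompt_style


print(resolve_prompt_style())                            # -> "chat"
print(resolve_prompt_style({"conversation": "chatml"}))  # -> "chatml"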
src/axolotl/utils/trainer.py
CHANGED
@@ -423,7 +423,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     )
 
     # Phi doesn't want the attention_mask feature when training
-    if "CodeGenTokenizer" in tokenizer.__class__.__name__:
+    if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
+        cfg.is_mistral_derived_model and cfg.flash_attention
+    ):
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
             eval_dataset = eval_dataset.remove_columns("attention_mask")
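The widened condition now also drops the attention_mask column when both cfg.is_mistral_derived_model and cfg.flash_attention are set, presumably because the flash-attention path for Mistral does not need an explicit all-ones mask. Below is a minimal sketch of the same rule applied to a toy datasets.Dataset; StubCfg, should_drop_attention_mask, and the tokenizer class name are hypothetical stand-ins for axolotl's config and tokenizer objects:

from dataclasses import dataclass

from datasets import Dataset


@dataclass
class StubCfg:
    # Hypothetical stand-in for the two cfg flags referenced in the diff.
    is_mistral_derived_model: bool = False
    flash_attention: bool = False


def should_drop_attention_mask(tokenizer_cls_name: str, cfg: StubCfg) -> bool:
    # Same boolean expression as the new `if` in process_datasets_for_packing.
    return "CodeGenTokenizer" in tokenizer_cls_name or (
        cfg.is_mistral_derived_model and cfg.flash_attention
    )


ds = Dataset.from_dict(
    {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]], "labels": [[1, 2, 3]]}
)
cfg = StubCfg(is_mistral_derived_model=True, flash_attention=True)
if should_drop_attention_mask("LlamaTokenizerFast", cfg):
    ds = ds.remove_columns("attention_mask")
print(ds.column_names)  # -> ['input_ids', 'labels']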