improve llama pad token handling (#475)
Browse files* improve llama pad token handling
* tweak logic to not clobber
examples/gptq-lora-7b/config.yml
CHANGED
@@ -57,7 +57,7 @@ weight_decay: 0.0001
|
|
57 |
fsdp:
|
58 |
fsdp_config:
|
59 |
tokens:
|
60 |
-
pad_token: "
|
61 |
bos_token: "<s>"
|
62 |
eos_token: "</s>"
|
63 |
unk_token: "<unk>"
|
|
|
57 |
fsdp:
|
58 |
fsdp_config:
|
59 |
tokens:
|
60 |
+
pad_token: "<pad>"
|
61 |
bos_token: "<s>"
|
62 |
eos_token: "</s>"
|
63 |
unk_token: "<unk>"
|
src/axolotl/prompt_tokenizers.py
CHANGED
@@ -13,7 +13,7 @@ from axolotl.prompters import IGNORE_TOKEN_ID
|
|
13 |
LOG = logging.getLogger("axolotl")
|
14 |
|
15 |
IGNORE_INDEX = -100
|
16 |
-
LLAMA_DEFAULT_PAD_TOKEN = "
|
17 |
LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
|
18 |
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
|
19 |
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
|
|
|
13 |
LOG = logging.getLogger("axolotl")
|
14 |
|
15 |
IGNORE_INDEX = -100
|
16 |
+
LLAMA_DEFAULT_PAD_TOKEN = "<pad>" # nosec
|
17 |
LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
|
18 |
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
|
19 |
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
|
src/axolotl/utils/data.py
CHANGED
@@ -54,9 +54,10 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
|
54 |
|
55 |
def prepare_dataset(cfg, tokenizer):
|
56 |
if not cfg.pretraining_dataset:
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
60 |
else:
|
61 |
train_dataset = load_pretraining_dataset(
|
62 |
cfg.pretraining_dataset,
|
|
|
54 |
|
55 |
def prepare_dataset(cfg, tokenizer):
|
56 |
if not cfg.pretraining_dataset:
|
57 |
+
with zero_first(is_main_process()):
|
58 |
+
train_dataset, eval_dataset = load_prepare_datasets(
|
59 |
+
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
60 |
+
)
|
61 |
else:
|
62 |
train_dataset = load_pretraining_dataset(
|
63 |
cfg.pretraining_dataset,
|
src/axolotl/utils/models.py
CHANGED
@@ -22,7 +22,7 @@ from transformers import ( # noqa: F401
|
|
22 |
PreTrainedTokenizerBase,
|
23 |
)
|
24 |
|
25 |
-
from axolotl.prompt_tokenizers import
|
26 |
from axolotl.utils.bench import log_gpu_memory_usage
|
27 |
|
28 |
LOG = logging.getLogger("axolotl")
|
@@ -58,8 +58,9 @@ def load_tokenizer(cfg):
|
|
58 |
if tokenizer.__class__.__name__ in [
|
59 |
"LlamaTokenizer",
|
60 |
"LlamaTokenizerFast",
|
61 |
-
]:
|
62 |
-
|
|
|
63 |
|
64 |
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
65 |
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
|
|
22 |
PreTrainedTokenizerBase,
|
23 |
)
|
24 |
|
25 |
+
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
26 |
from axolotl.utils.bench import log_gpu_memory_usage
|
27 |
|
28 |
LOG = logging.getLogger("axolotl")
|
|
|
58 |
if tokenizer.__class__.__name__ in [
|
59 |
"LlamaTokenizer",
|
60 |
"LlamaTokenizerFast",
|
61 |
+
] and not hasattr(tokenizer, "pad_token"):
|
62 |
+
# set a pad_token, but use eos_token so we don't add a new token
|
63 |
+
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
64 |
|
65 |
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
66 |
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|