winglian committed
Commit: cb9797e
1 Parent(s): bde3c5a

improve llama pad token handling (#475)


* improve llama pad token handling

* tweak logic so an existing pad token isn't clobbered

examples/gptq-lora-7b/config.yml CHANGED

```diff
@@ -57,7 +57,7 @@ weight_decay: 0.0001
 fsdp:
 fsdp_config:
 tokens:
-  pad_token: "[PAD]"
+  pad_token: "<pad>"
   bos_token: "<s>"
   eos_token: "</s>"
   unk_token: "<unk>"
```
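For context, the `tokens:` block in this example config overrides the tokenizer's special tokens, and a custom pad token such as `"<pad>"` is not part of the base Llama vocabulary, so registering it implies growing the embedding matrix. The sketch below uses plain `transformers` calls (not axolotl's loader) to show that step; the model id is illustrative only.

```python
# Minimal sketch with plain transformers (not axolotl's loader): registering a
# pad token that is missing from the base Llama vocab adds a new id, so the
# model embeddings have to be resized to match.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "huggyllama/llama-7b"  # illustrative model id, not from the commit
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# add_special_tokens returns how many tokens were actually new to the vocab
num_added = tokenizer.add_special_tokens({"pad_token": "<pad>"})
if num_added > 0:
    # give the new pad id an embedding row (initialized randomly)
    model.resize_token_embeddings(len(tokenizer))
```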
src/axolotl/prompt_tokenizers.py CHANGED

```diff
@@ -13,7 +13,7 @@ from axolotl.prompters import IGNORE_TOKEN_ID
 LOG = logging.getLogger("axolotl")
 
 IGNORE_INDEX = -100
-LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"  # nosec
+LLAMA_DEFAULT_PAD_TOKEN = "<pad>"  # nosec
 LLAMA_DEFAULT_EOS_TOKEN = "</s>"  # nosec
 LLAMA_DEFAULT_BOS_TOKEN = "<s>"  # nosec
 LLAMA_DEFAULT_UNK_TOKEN = "<unk>"  # nosec
```
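Only the pad-token default changes here; the neighboring `IGNORE_INDEX = -100` constant in the hunk context is what keeps masked positions out of the loss. A minimal, self-contained sketch of the usual pattern (the token ids are made up for illustration):

```python
# Minimal sketch of how an IGNORE_INDEX of -100 is typically used when building
# labels: masked positions are skipped by CrossEntropyLoss(ignore_index=-100),
# so only the response tokens contribute to the loss. Ids are illustrative.
IGNORE_INDEX = -100

prompt_ids = [1, 319, 13563]         # e.g. BOS + prompt tokens (made up)
response_ids = [9038, 2501, 263, 2]  # e.g. response tokens + EOS (made up)

input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids

assert len(input_ids) == len(labels)
```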
src/axolotl/utils/data.py CHANGED

```diff
@@ -54,9 +54,10 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 
 def prepare_dataset(cfg, tokenizer):
     if not cfg.pretraining_dataset:
-        train_dataset, eval_dataset = load_prepare_datasets(
-            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-        )
+        with zero_first(is_main_process()):
+            train_dataset, eval_dataset = load_prepare_datasets(
+                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+            )
     else:
         train_dataset = load_pretraining_dataset(
             cfg.pretraining_dataset,
```
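Wrapping dataset preparation in `zero_first(is_main_process())` means that, in a multi-GPU run, the main process tokenizes and caches the dataset first while the other ranks wait, and the other ranks then reuse the prepared cache instead of redoing the work. A minimal sketch of what such a helper typically looks like (assumed shape, not necessarily axolotl's exact implementation):

```python
# Minimal sketch of a zero_first-style context manager (assumed shape, not
# necessarily axolotl's exact implementation): rank 0 runs the wrapped block
# first; the other ranks wait at a barrier and proceed once rank 0 is done.
from contextlib import contextmanager

import torch.distributed as dist


def is_main_process() -> bool:
    # treat single-process (non-distributed) runs as the main process
    return not dist.is_initialized() or dist.get_rank() == 0


@contextmanager
def zero_first(is_main: bool):
    if dist.is_initialized() and not is_main:
        dist.barrier()  # non-main ranks wait here until rank 0 finishes
    yield
    if dist.is_initialized() and is_main:
        dist.barrier()  # rank 0 releases the waiting ranks
```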
src/axolotl/utils/models.py CHANGED

```diff
@@ -22,7 +22,7 @@ from transformers import ( # noqa: F401
     PreTrainedTokenizerBase,
 )
 
-from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
+from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 
 LOG = logging.getLogger("axolotl")
@@ -58,8 +58,9 @@ def load_tokenizer(cfg):
     if tokenizer.__class__.__name__ in [
         "LlamaTokenizer",
         "LlamaTokenizerFast",
-    ]:
-        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
+    ] and not hasattr(tokenizer, "pad_token"):
+        # set a pad_token, but use eos_token so we don't add a new token
+        tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
 
     LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
     LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
```