fix(config): Set eos/bos to tokenizer if different (#801)
Browse files* fix(config): Set eos/bos to tokenizer if different
* chore: fix lint
- src/axolotl/utils/models.py +14 -0
src/axolotl/utils/models.py
CHANGED
@@ -386,6 +386,20 @@ def load_model(
|
|
386 |
)
|
387 |
model.config.max_position_embeddings = cfg.sequence_len
|
388 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
389 |
if model.device.type == "cuda":
|
390 |
log_gpu_memory_usage(LOG, "after model load", model.device)
|
391 |
|
|
|
386 |
)
|
387 |
model.config.max_position_embeddings = cfg.sequence_len
|
388 |
|
389 |
+
if (
|
390 |
+
hasattr(model.config, "bos_token_id")
|
391 |
+
and model.config.bos_token_id
|
392 |
+
and model.config.bos_token_id != tokenizer.bos_token_id
|
393 |
+
):
|
394 |
+
model.config.bos_token_id = tokenizer.bos_token_id
|
395 |
+
|
396 |
+
if (
|
397 |
+
hasattr(model.config, "eos_token_id")
|
398 |
+
and model.config.eos_token_id
|
399 |
+
and model.config.eos_token_id != tokenizer.eos_token_id
|
400 |
+
):
|
401 |
+
model.config.eos_token_id = tokenizer.eos_token_id
|
402 |
+
|
403 |
if model.device.type == "cuda":
|
404 |
log_gpu_memory_usage(LOG, "after model load", model.device)
|
405 |
|