Add debug option for RL dataset preprocessing (#1404)
* adding debug option for RL dataset preprocessing
* Refine formatting of debugging code in RL dataset preprocessing
* Update __init__.py
* chore: fix lint
---------
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
- src/axolotl/cli/__init__.py +17 -0
- src/axolotl/utils/tokenization.py +58 -3
src/axolotl/cli/__init__.py
CHANGED
@@ -433,6 +433,23 @@ def load_rl_datasets(
|
|
433 |
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
434 |
)
|
435 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
436 |
return TrainDatasetMeta(
|
437 |
train_dataset=train_dataset,
|
438 |
eval_dataset=eval_dataset,
|
|
|
433 |
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
434 |
)
|
435 |
|
436 |
+
if cli_args.debug or cfg.debug:
|
437 |
+
LOG.info("check_dataset_labels...")
|
438 |
+
|
439 |
+
tokenizer = load_tokenizer(cfg)
|
440 |
+
check_dataset_labels(
|
441 |
+
train_dataset.select(
|
442 |
+
[
|
443 |
+
random.randrange(0, len(train_dataset) - 1) # nosec
|
444 |
+
for _ in range(cli_args.debug_num_examples)
|
445 |
+
]
|
446 |
+
),
|
447 |
+
tokenizer,
|
448 |
+
num_examples=cli_args.debug_num_examples,
|
449 |
+
text_only=cli_args.debug_text_only,
|
450 |
+
rl_mode=True,
|
451 |
+
)
|
452 |
+
|
453 |
return TrainDatasetMeta(
|
454 |
train_dataset=train_dataset,
|
455 |
eval_dataset=eval_dataset,
|
src/axolotl/utils/tokenization.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
"""Module for tokenization utilities"""
|
2 |
|
3 |
-
|
4 |
import logging
|
5 |
import re
|
6 |
from typing import Dict, List
|
@@ -10,10 +9,19 @@ from termcolor import colored
|
|
10 |
LOG = logging.getLogger("axolotl")
|
11 |
|
12 |
|
13 |
-
def check_dataset_labels(
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
# the dataset is already shuffled, so let's just check the first 5 elements
|
15 |
for idx in range(num_examples):
|
16 |
-
|
|
|
|
|
|
|
17 |
|
18 |
|
19 |
def check_example_labels(example, tokenizer, text_only=False):
|
@@ -40,6 +48,53 @@ def check_example_labels(example, tokenizer, text_only=False):
|
|
40 |
return " ".join(colored_tokens)
|
41 |
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
|
44 |
GLAIVE_TO_SHAREGPT_ROLE = {
|
45 |
"SYSTEM": "system",
|
|
|
1 |
"""Module for tokenization utilities"""
|
2 |
|
|
|
3 |
import logging
|
4 |
import re
|
5 |
from typing import Dict, List
|
|
|
9 |
LOG = logging.getLogger("axolotl")
|
10 |
|
11 |
|
12 |
+
def check_dataset_labels(
    dataset,
    tokenizer,
    num_examples=5,
    text_only=False,
    rl_mode=False,
):
    """Log a sample of tokenized dataset examples for visual debugging.

    Args:
        dataset: Indexable dataset (assumed already shuffled upstream, so
            the leading entries are a fair sample).
        tokenizer: Tokenizer used to decode token ids for display.
        num_examples: How many leading examples to inspect.
        text_only: If True, show decoded text without numeric token ids.
        rl_mode: If True, treat examples as RL (prompt/chosen/rejected)
            records instead of standard supervised examples.
    """
    # Pick the checker once instead of branching inside the loop.
    checker = check_rl_example_labels if rl_mode else check_example_labels
    # Clamp to the dataset size so short datasets don't raise IndexError.
    for idx in range(min(num_examples, len(dataset))):
        checker(dataset[idx], tokenizer, text_only=text_only)
|
25 |
|
26 |
|
27 |
def check_example_labels(example, tokenizer, text_only=False):
|
|
|
48 |
return " ".join(colored_tokens)
|
49 |
|
50 |
|
51 |
+
def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
    """Render one token for debug output, colorized with termcolor.

    Returns just the colored decoded text when ``text_only`` is set;
    otherwise the colored text followed by the token id in white.
    """
    colored_text = colored(decoded_token, color)
    if text_only:
        return colored_text
    token_id_text = colored(f"({encoded_token})", "white")
    return f"{colored_text}{token_id_text}"
|
59 |
+
|
60 |
+
|
61 |
+
def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
    """Encode ``tokens`` and return a list of colorized per-token strings."""
    colored_tokens = []
    # Round-trip each token id through decode so the display shows text.
    for token_id in tokenizer.encode(tokens):
        decoded = tokenizer.decode(token_id)
        colored_tokens.append(
            color_token_for_rl_debug(decoded, token_id, color, text_only)
        )
    return colored_tokens
|
68 |
+
|
69 |
+
|
70 |
+
def check_rl_example_labels(
    example,
    tokenizer,
    text_only=False,
    field_prompt="prompt",
    field_chosen="chosen",
    field_rejected="rejected",
):
    """Log a colorized view of one RL (prompt/chosen/rejected) example.

    The prompt, chosen response, and rejected response are each tokenized
    and colored yellow, green, and red respectively, then logged.

    Args:
        example: Mapping containing the prompt/chosen/rejected text fields.
        tokenizer: Tokenizer used for the encode/decode round-trip.
        text_only: If True, omit numeric token ids from the output.
        field_prompt: Key under which the prompt text is stored.
        field_chosen: Key under which the chosen response is stored.
        field_rejected: Key under which the rejected response is stored.

    Returns:
        The colorized prompt string (joined tokens).
    """
    input_tokens = example[field_prompt]
    labels_chosen = example[field_chosen]
    labels_rejected = example[field_rejected]

    # Process and color each type of token
    colored_tokens = process_tokens_for_rl_debug(
        input_tokens, "yellow", tokenizer, text_only
    )
    colored_chosens = process_tokens_for_rl_debug(
        labels_chosen, "green", tokenizer, text_only
    )
    colored_rejecteds = process_tokens_for_rl_debug(
        labels_rejected, "red", tokenizer, text_only
    )

    # Token ids already carry their own parentheses, so separate with a
    # space only when ids are shown; pure text is joined seamlessly.
    delimiter = "" if text_only else " "

    # Lazy %-style args avoid string formatting when INFO is disabled.
    LOG.info("INPUT PROMPT: %s\n\n", delimiter.join(colored_tokens))
    LOG.info("CHOSEN RESPONSE: %s\n\n", delimiter.join(colored_chosens))
    LOG.info("REJECTED RESPONSE: %s\n\n\n", delimiter.join(colored_rejecteds))

    return delimiter.join(colored_tokens)
|
96 |
+
|
97 |
+
|
98 |
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
|
99 |
GLAIVE_TO_SHAREGPT_ROLE = {
|
100 |
"SYSTEM": "system",
|