abhinand Nanobit committed on
Commit
cc5d31e
1 Parent(s): 1aeece6

Add debug option for RL dataset preprocessing (#1404)

Browse files

* adding debug option for RL dataset preprocessing

* Refine formatting of debugging code in RL dataset preprocessing

* Update __init__.py

* chore: fix lint

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

src/axolotl/cli/__init__.py CHANGED
@@ -433,6 +433,23 @@ def load_rl_datasets(
433
  math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
434
  )
435
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  return TrainDatasetMeta(
437
  train_dataset=train_dataset,
438
  eval_dataset=eval_dataset,
 
433
  math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
434
  )
435
 
436
+ if cli_args.debug or cfg.debug:
437
+ LOG.info("check_dataset_labels...")
438
+
439
+ tokenizer = load_tokenizer(cfg)
440
+ check_dataset_labels(
441
+ train_dataset.select(
442
+ [
443
+ random.randrange(0, len(train_dataset) - 1) # nosec
444
+ for _ in range(cli_args.debug_num_examples)
445
+ ]
446
+ ),
447
+ tokenizer,
448
+ num_examples=cli_args.debug_num_examples,
449
+ text_only=cli_args.debug_text_only,
450
+ rl_mode=True,
451
+ )
452
+
453
  return TrainDatasetMeta(
454
  train_dataset=train_dataset,
455
  eval_dataset=eval_dataset,
src/axolotl/utils/tokenization.py CHANGED
@@ -1,6 +1,5 @@
1
  """Module for tokenization utilities"""
2
 
3
-
4
  import logging
5
  import re
6
  from typing import Dict, List
@@ -10,10 +9,19 @@ from termcolor import colored
10
  LOG = logging.getLogger("axolotl")
11
 
12
 
13
- def check_dataset_labels(dataset, tokenizer, num_examples=5, text_only=False):
 
 
 
 
 
 
14
  # the dataset is already shuffled, so let's just check the first 5 elements
15
  for idx in range(num_examples):
16
- check_example_labels(dataset[idx], tokenizer, text_only=text_only)
 
 
 
17
 
18
 
19
  def check_example_labels(example, tokenizer, text_only=False):
@@ -40,6 +48,53 @@ def check_example_labels(example, tokenizer, text_only=False):
40
  return " ".join(colored_tokens)
41
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
44
  GLAIVE_TO_SHAREGPT_ROLE = {
45
  "SYSTEM": "system",
 
1
  """Module for tokenization utilities"""
2
 
 
3
  import logging
4
  import re
5
  from typing import Dict, List
 
9
  LOG = logging.getLogger("axolotl")
10
 
11
 
12
+ def check_dataset_labels(
13
+ dataset,
14
+ tokenizer,
15
+ num_examples=5,
16
+ text_only=False,
17
+ rl_mode=False,
18
+ ):
19
  # the dataset is already shuffled, so let's just check the first 5 elements
20
  for idx in range(num_examples):
21
+ if not rl_mode:
22
+ check_example_labels(dataset[idx], tokenizer, text_only=text_only)
23
+ else:
24
+ check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only)
25
 
26
 
27
  def check_example_labels(example, tokenizer, text_only=False):
 
48
  return " ".join(colored_tokens)
49
 
50
 
51
+ def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only):
52
+ """Helper function to color tokens based on their type."""
53
+ colored_text = colored(decoded_token, color)
54
+ return (
55
+ colored_text
56
+ if text_only
57
+ else f"{colored_text}{colored(f'({encoded_token})', 'white')}"
58
+ )
59
+
60
+
61
+ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
62
+ """Helper function to process and color tokens."""
63
+ colored_tokens = [
64
+ color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
65
+ for token in tokenizer.encode(tokens)
66
+ ]
67
+ return colored_tokens
68
+
69
+
70
+ def check_rl_example_labels(example, tokenizer, text_only=False):
71
+ field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
72
+
73
+ input_tokens = example[field_prompt]
74
+ labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
75
+
76
+ # Process and color each type of token
77
+ colored_tokens = process_tokens_for_rl_debug(
78
+ input_tokens, "yellow", tokenizer, text_only
79
+ )
80
+ colored_chosens = process_tokens_for_rl_debug(
81
+ labels_chosen, "green", tokenizer, text_only
82
+ )
83
+ colored_rejecteds = process_tokens_for_rl_debug(
84
+ labels_rejected, "red", tokenizer, text_only
85
+ )
86
+
87
+ # Create a delimiter based on text_only flag
88
+ delimiter = "" if text_only else " "
89
+
90
+ # Logging information
91
+ LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
92
+ LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
93
+ LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
94
+
95
+ return delimiter.join(colored_tokens)
96
+
97
+
98
  GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
99
  GLAIVE_TO_SHAREGPT_ROLE = {
100
  "SYSTEM": "system",