nopperl winglian commited on
Commit
1e3d530
1 Parent(s): 1648279

Support user-defined prompt processing strategies for dpo (#1248)

Browse files

* support user-defined prompt processing strategies for dpo

* interpret dict dataset types as user-defined

* fix lint errors

* setup pydantic config for validation of User defined DPO

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

src/axolotl/prompt_strategies/dpo/__init__.py CHANGED
@@ -8,14 +8,13 @@ import logging
8
  LOG = logging.getLogger("axolotl")
9
 
10
 
11
- def load(strategy, cfg):
12
  try:
13
  load_fn = strategy.split(".")[-1]
14
  strategy = ".".join(strategy.split(".")[:-1])
15
  mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
16
  func = getattr(mod, load_fn)
17
- load_kwargs = {}
18
- return func(cfg, **load_kwargs)
19
  except Exception: # pylint: disable=broad-exception-caught
20
  LOG.warning(f"unable to load strategy {strategy}")
21
  return None
 
8
  LOG = logging.getLogger("axolotl")
9
 
10
 
11
+ def load(strategy, cfg, **kwargs):
12
  try:
13
  load_fn = strategy.split(".")[-1]
14
  strategy = ".".join(strategy.split(".")[:-1])
15
  mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
16
  func = getattr(mod, load_fn)
17
+ return func(cfg, **kwargs)
 
18
  except Exception: # pylint: disable=broad-exception-caught
19
  LOG.warning(f"unable to load strategy {strategy}")
20
  return None
src/axolotl/prompt_strategies/dpo/chatml.py CHANGED
@@ -5,6 +5,7 @@ DPO strategies for chatml
5
 
6
  def argilla(
7
  cfg,
 
8
  ): # pylint: disable=possibly-unused-variable,unused-argument
9
  def transform_fn(sample):
10
  if "system" in sample and sample["system"]:
@@ -25,6 +26,7 @@ def argilla(
25
 
26
  def icr(
27
  cfg,
 
28
  ): # pylint: disable=possibly-unused-variable,unused-argument
29
  """
30
  chatml transforms for datasets with system, input, chosen, rejected
@@ -48,7 +50,7 @@ def icr(
48
  return transform_fn
49
 
50
 
51
- def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
52
  """
53
  For Intel Orca DPO Pairs
54
  """
@@ -70,7 +72,9 @@ def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
70
  return transform_fn
71
 
72
 
73
- def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
 
 
74
  def transform_fn(sample):
75
  if "system" in sample and sample["system"]:
76
  sample["prompt"] = (
@@ -88,7 +92,7 @@ def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argume
88
  return transform_fn
89
 
90
 
91
- def ultra(cfg): # pylint: disable=possibly-unused-variable,unused-argument
92
  """
93
  for ultrafeedback binarized conversations
94
  """
 
5
 
6
  def argilla(
7
  cfg,
8
+ **kwargs,
9
  ): # pylint: disable=possibly-unused-variable,unused-argument
10
  def transform_fn(sample):
11
  if "system" in sample and sample["system"]:
 
26
 
27
  def icr(
28
  cfg,
29
+ **kwargs,
30
  ): # pylint: disable=possibly-unused-variable,unused-argument
31
  """
32
  chatml transforms for datasets with system, input, chosen, rejected
 
50
  return transform_fn
51
 
52
 
53
+ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
54
  """
55
  For Intel Orca DPO Pairs
56
  """
 
72
  return transform_fn
73
 
74
 
75
+ def prompt_pairs(
76
+ cfg, **kwargs
77
+ ): # pylint: disable=possibly-unused-variable,unused-argument
78
  def transform_fn(sample):
79
  if "system" in sample and sample["system"]:
80
  sample["prompt"] = (
 
92
  return transform_fn
93
 
94
 
95
+ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
96
  """
97
  for ultrafeedback binarized conversations
98
  """
src/axolotl/prompt_strategies/dpo/user_defined.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ User-defined DPO strategies
3
+ """
4
+
5
+
6
+ def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
7
+ ds_cfg = cfg["datasets"][dataset_idx]["type"]
8
+ if not isinstance(ds_cfg, dict):
9
+ raise ValueError(
10
+ f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
11
+ )
12
+ field_prompt = ds_cfg.get("field_prompt", "prompt")
13
+ field_system = ds_cfg.get("field_system", "system")
14
+ field_chosen = ds_cfg.get("field_chosen", "chosen")
15
+ field_rejected = ds_cfg.get("field_rejected", "rejected")
16
+ prompt_format = ds_cfg.get("prompt_format")
17
+ if not prompt_format:
18
+ prompt_format = "{" + field_prompt + "}"
19
+ chosen_format = ds_cfg.get("chosen_format")
20
+ if not chosen_format:
21
+ chosen_format = "{" + field_chosen + "}"
22
+ rejected_format = ds_cfg.get("rejected_format")
23
+ if not rejected_format:
24
+ rejected_format = "{" + field_rejected + "}"
25
+
26
+ def transform_fn(sample):
27
+ if (
28
+ "{" + field_system + "}" in prompt_format
29
+ and field_system in sample
30
+ and sample[field_system]
31
+ ):
32
+ sample["prompt"] = prompt_format.format(
33
+ system=sample[field_system], prompt=sample[field_prompt]
34
+ )
35
+ else:
36
+ sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
37
+ sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
38
+ sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
39
+ return sample
40
+
41
+ return transform_fn
src/axolotl/prompt_strategies/dpo/zephyr.py CHANGED
@@ -3,7 +3,7 @@ DPO strategies for zephyr
3
  """
4
 
5
 
6
- def nectar(cfg): # pylint: disable=possibly-unused-variable,unused-argument
7
  def transform_fn(sample):
8
  data = {}
9
  data["prompt"] = (
 
3
  """
4
 
5
 
6
+ def nectar(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
7
  def transform_fn(sample):
8
  data = {}
9
  data["prompt"] = (
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -85,12 +85,24 @@ class SFTDataset(BaseModel):
85
  field_model: Optional[str] = None
86
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  class DPODataset(BaseModel):
89
  """DPO configuration subset"""
90
 
91
  path: Optional[str] = None
92
  split: Optional[str] = None
93
- type: Optional[str] = None
94
  data_files: Optional[List[str]] = None
95
 
96
 
 
85
  field_model: Optional[str] = None
86
 
87
 
88
+ class UserDefinedDPOType(BaseModel):
89
+ """User defined typing for DPO"""
90
+
91
+ field_system: Optional[str] = None
92
+ field_prompt: Optional[str] = None
93
+ field_chosen: Optional[str] = None
94
+ field_rejected: Optional[str] = None
95
+ prompt_format: Optional[str] = None
96
+ chosen_format: Optional[str] = None
97
+ rejected_format: Optional[str] = None
98
+
99
+
100
  class DPODataset(BaseModel):
101
  """DPO configuration subset"""
102
 
103
  path: Optional[str] = None
104
  split: Optional[str] = None
105
+ type: Optional[Union[UserDefinedDPOType, str]] = None
106
  data_files: Optional[List[str]] = None
107
 
108
 
src/axolotl/utils/data.py CHANGED
@@ -937,7 +937,9 @@ def load_prepare_dpo_datasets(cfg):
937
  for i, data_set in enumerate(split_datasets):
938
  _type = dataset_cfgs[i]["type"]
939
  if _type:
940
- ds_transform_fn = load_dpo(_type, _cfg)
 
 
941
  split_datasets[i] = data_set.map(
942
  ds_transform_fn,
943
  desc="Mapping RL Dataset",
 
937
  for i, data_set in enumerate(split_datasets):
938
  _type = dataset_cfgs[i]["type"]
939
  if _type:
940
+ if isinstance(_type, DictDefault):
941
+ _type = "user_defined.default"
942
+ ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
943
  split_datasets[i] = data_set.map(
944
  ds_transform_fn,
945
  desc="Mapping RL Dataset",