MilesQLi winglian commited on
Commit
0800885
1 Parent(s): d3193be

Update to adapt to sharegpt datasets with "assistant" rather than "gp… (#774)

Browse files

* Update to adapt to sharegpt datasets with "assistant" rather than "gpt" as the machine answers.

* use a strict option for hanedling incorrect turn data

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

src/axolotl/prompt_strategies/sharegpt.py CHANGED
@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
24
  )
25
  field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
26
  field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
27
- return SimpleShareGPTPromptTokenizingStrategy(
28
  ShareGPTPrompterV2(
29
  conversation=conversation,
30
  role_key_model=field_model,
@@ -34,6 +34,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
34
  cfg.train_on_inputs,
35
  cfg.sequence_len,
36
  )
 
 
 
37
 
38
 
39
  def load_role(tokenizer, cfg):
@@ -59,8 +62,26 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
59
  basic sharegpt strategy to grab conversations from the sample row
60
  """
61
 
 
 
 
 
 
 
 
 
 
 
62
  def get_conversation_thread(self, prompt):
63
- return prompt["conversations"]
 
 
 
 
 
 
 
 
64
 
65
 
66
  class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
 
24
  )
25
  field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
26
  field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
27
+ strategy = SimpleShareGPTPromptTokenizingStrategy(
28
  ShareGPTPrompterV2(
29
  conversation=conversation,
30
  role_key_model=field_model,
 
34
  cfg.train_on_inputs,
35
  cfg.sequence_len,
36
  )
37
+ if ds_cfg and "strict" in ds_cfg:
38
+ strategy.strict = ds_cfg["strict"]
39
+ return strategy
40
 
41
 
42
  def load_role(tokenizer, cfg):
 
62
  basic sharegpt strategy to grab conversations from the sample row
63
  """
64
 
65
+ _strict = True
66
+
67
+ @property
68
+ def strict(self):
69
+ return self._strict
70
+
71
+ @strict.setter
72
+ def strict(self, strict):
73
+ self._strict = strict
74
+
75
  def get_conversation_thread(self, prompt):
76
+ conversations = prompt["conversations"]
77
+ if self.strict:
78
+ return conversations
79
+ # remap roles - allow for assistant turn
80
+ role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
81
+ turns = [
82
+ {"from": role_map[t["from"]], "value": t["value"]} for t in conversations
83
+ ]
84
+ return turns
85
 
86
 
87
  class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):