Update to adapt to sharegpt datasets with "assistant" rather than "gp… (#774)
* Update to adapt to sharegpt datasets with "assistant" rather than "gpt" as the machine answers.
* use a strict option for handling incorrect turn data
* chore: lint
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
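
For context on what the new option controls: the updated load() below reads a "strict" key from the dataset config dict (ds_cfg). A minimal sketch of a non-strict entry, assuming only the keys that appear in the diff ("strict", "field_human", "field_model"); the values are illustrative, not taken from a real config:

# Illustrative ds_cfg dict; "strict", "field_human", and "field_model" are
# the keys read by load() in the diff below. The values are made up.
ds_cfg = {
    "field_human": "human",
    "field_model": "gpt",
    "strict": False,  # False: remap "assistant" turns to "gpt" instead of passing them through unchanged
}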
src/axolotl/prompt_strategies/sharegpt.py
CHANGED
@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     )
     field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
     field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
-    return SimpleShareGPTPromptTokenizingStrategy(
+    strategy = SimpleShareGPTPromptTokenizingStrategy(
         ShareGPTPrompterV2(
             conversation=conversation,
             role_key_model=field_model,
@@ -34,6 +34,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+    if ds_cfg and "strict" in ds_cfg:
+        strategy.strict = ds_cfg["strict"]
+    return strategy


 def load_role(tokenizer, cfg):
@@ -59,8 +62,26 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     basic sharegpt strategy to grab conversations from the sample row
     """

+    _strict = True
+
+    @property
+    def strict(self):
+        return self._strict
+
+    @strict.setter
+    def strict(self, strict):
+        self._strict = strict
+
     def get_conversation_thread(self, prompt):
-        return prompt["conversations"]
+        conversations = prompt["conversations"]
+        if self.strict:
+            return conversations
+        # remap roles - allow for assistant turn
+        role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
+        turns = [
+            {"from": role_map[t["from"]], "value": t["value"]} for t in conversations
+        ]
+        return turns


 class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
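
As a quick sanity check of the non-strict path, the remapping can be replicated outside the class. The sketch below copies the role_map from the diff and applies it to a hypothetical ShareGPT-style row that uses "assistant" as the model role; with strict left at its default of True, get_conversation_thread would instead return the row unchanged.

# Standalone sketch of the non-strict remapping added above.
# role_map mirrors the diff; sample_row is hypothetical illustration data.
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}

sample_row = {
    "conversations": [
        {"from": "human", "value": "Name a prime number greater than 10."},
        {"from": "assistant", "value": "11 is a prime number greater than 10."},
    ]
}

turns = [
    {"from": role_map[t["from"]], "value": t["value"]}
    for t in sample_row["conversations"]
]
print(turns)
# The "assistant" turn comes back with "from": "gpt", matching what the
# ShareGPT prompter expects.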