refactor conversation plucking in sharegpt
src/axolotl/prompt_tokenizers.py
CHANGED
```diff
@@ -268,6 +268,9 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):


 class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
+    def get_conversation_thread(self, prompt):
+        return prompt["conversations"]
+
     def tokenize_prompt(self, prompt):
         result = {
             "input_ids": [],
@@ -279,7 +282,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
         assistant_token = self._get_assistant_token()
         try:
             for i, part in enumerate(
-                self.prompter.build_prompt(prompt["conversations"])
+                self.prompter.build_prompt(self.get_conversation_thread(prompt))
             ):
                 if isinstance(part, tuple):
                     if part[0] == "USER:":
```
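Pulling the dataset-field lookup out into `get_conversation_thread` lets a subclass change where the conversation list comes from without duplicating `tokenize_prompt`. A minimal sketch, assuming a hypothetical dataset whose turns are stored under a `"messages"` key instead of `"conversations"`:

```python
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy


class MessagesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
    """Hypothetical strategy for datasets that store turns under "messages"."""

    def get_conversation_thread(self, prompt):
        # Only the plucking logic changes; tokenize_prompt is inherited as-is.
        return prompt["messages"]
```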