"""Module containing the InstructShareGPTPromptTokenizingStrategy class""" | |
from typing import Any, Dict, Optional | |
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy | |
from axolotl.prompters import ShareGPTPrompterV2 | |


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """Build the tokenizing strategy for instruction/output sample rows.

    An optional ``conversation`` key in the dataset config selects the
    conversation template passed to ShareGPTPrompterV2.
    """
    conversation = (
        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
    )
    strategy = InstructShareGPTPromptTokenizingStrategy(
        # pylint: disable=duplicate-code
        ShareGPTPrompterV2(
            conversation=conversation,
        ),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    return strategy
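

# Usage sketch (assumed wiring; the entry point name reflects my reading of
# axolotl.prompt_tokenizers.PromptTokenizingStrategy and is not guaranteed for
# every axolotl version): dataset loading code calls the returned strategy
# once per sample row, roughly
#     strategy = load(tokenizer, cfg, ds_cfg)
#     tokenized = strategy.tokenize_prompt(row)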


class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
    """
    Basic ShareGPT strategy that builds a two-turn conversation thread
    (human instruction -> gpt output) from the sample row.
    """

    def get_conversation_thread(self, prompt):
        return [
            {"from": "human", "value": prompt["instruction"]},
            {"from": "gpt", "value": prompt["output"]},
        ]
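

# Illustrative note (hypothetical sample values, not taken from a real
# dataset): a row such as
#     {"instruction": "Summarize the article.", "output": "The article says ..."}
# is mapped by get_conversation_thread to the ShareGPT-style thread
#     [
#         {"from": "human", "value": "Summarize the article."},
#         {"from": "gpt", "value": "The article says ..."},
#     ]
# which the base ShareGPTPromptTokenizingStrategy then renders and tokenizes.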