File size: 5,433 Bytes
3a38271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b57ed7
3a38271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bb0b78
 
 
 
 
3a38271
3d4984b
 
 
3a38271
3d4984b
 
 
3a38271
 
 
 
 
3d4984b
 
 
 
 
 
2bb0b78
3d4984b
1b7e860
 
 
3d4984b
 
 
2bb0b78
 
 
 
 
 
 
3d4984b
 
78a1e1f
 
 
 
 
 
 
 
 
 
 
 
 
 
3a38271
924bbfd
 
 
 
 
 
 
 
 
 
 
 
 
3a38271
 
 
 
 
 
78a1e1f
 
 
 
3d4984b
78a1e1f
 
 
 
2bb0b78
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""
Prompt strategies loader for alpaca instruction datasets with system prompts
"""
from typing import Generator, Tuple, Union

from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle


class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for instruction prompts that carry a per-row system prompt.

    The prompt text comes from ``self.prompter.build_prompt_w_system`` and is
    tokenized separately from the response; the two token sequences are then
    joined into one example.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
        """
        Pull the prompt fields out of a dataset row.

        Returns an ``(instruction, input, output, system)`` tuple. ``input``
        falls back to an empty string when the row has no such key; the
        ``instruction``, ``output`` and ``system`` keys are read unconditionally.
        """
        instruction = prompt["instruction"]
        extra_input = prompt["input"] if "input" in prompt else ""
        output = prompt["output"]
        system = prompt["system"]
        return (instruction, extra_input, output, system)

    def tokenize_prompt(self, prompt):
        """
        Tokenize one dataset row into a single training example.

        The prompt part is tokenized without an EOS token, the response part
        without a BOS token and with an EOS token. The response's ``input_ids``
        and ``attention_mask`` are appended to the prompt's, and the response's
        ``input_ids`` are appended to ``labels``.

        When ``self.train_on_inputs`` is falsy, the prompt portion of ``labels``
        is replaced by ``-100`` for every prompt token (presumably so those
        positions are excluded from the loss -- confirm against the trainer).
        """
        # pylint: disable=duplicate-code
        instruction, context, response, system = self.parse_instruction_fields(prompt)

        # The prompter yields the prompt text; only the first item is used.
        prompt_gen = self.prompter.build_prompt_w_system(system, instruction, context)
        prompt_text = next(iter(prompt_gen))

        encoded = self._tokenize(prompt_text, add_eos_token=False)
        if not self.train_on_inputs:
            # TODO this could be sped up using numpy array slicing
            encoded["labels"] = [-100] * len(encoded["input_ids"])

        encoded_response = self._tokenize(
            response, strip_bos_token=True, add_eos_token=True
        )

        # (destination key, response key): labels are extended with the
        # response token ids, not with a separate labels list.
        for target_key, source_key in (
            ("input_ids", "input_ids"),
            ("attention_mask", "attention_mask"),
            ("labels", "input_ids"),
        ):
            encoded[target_key] += encoded_response[source_key]

        return encoded


class SystemDataPrompter(AlpacaPrompter):
    """
    Alpaca-style prompter whose system prompt is supplied by each dataset row
    """

    def build_prompt_w_system(
        self,
        system: str,
        instruction: str,
        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
        output: Union[None, str] = None,
    ) -> Generator[str, None, None]:
        """
        Yield one prompt string built from the system prompt, the instruction
        and the optional input; a truthy ``output`` is appended at the end.

        The system prefix is only emitted when both ``system`` and
        ``self.system_format`` are truthy. ``self.turn_format`` is used when
        ``input`` is truthy, ``self.turn_no_input_format`` otherwise.
        """
        if system and self.system_format:
            prefix = self.system_format.format(system=system)
        else:
            prefix = ""

        if input:
            turn = self.turn_format.format(instruction=instruction, input=input)
        else:
            turn = self.turn_no_input_format.format(instruction=instruction)

        text = prefix + turn
        if output:
            text = f"{text}{output}"
        yield text


class OpenOrcaSystemDataPrompter(SystemDataPrompter):
    """
    Dataset-driven system prompt prompter using the OpenOrca turn templates
    """

    def match_prompt_style(self):
        """
        Set ``turn_format``, ``turn_no_input_format`` and ``system_format``
        according to ``self.prompt_style``.

        The instruct, chat and chatml styles are handled; any other value
        leaves the three attributes untouched by this method.
        """
        # pylint: disable=duplicate-code
        style = self.prompt_style

        # Each style is checked independently rather than as an elif chain.
        if style == PromptStyle.INSTRUCT.value:
            (
                self.turn_format,
                self.turn_no_input_format,
                self.system_format,
            ) = (
                "### Human:\n{instruction}\n"
                "### Additional Context:\n{input}\n"
                "### Assistant:\n",
                "### Human:\n{instruction}\n### Assistant:\n",
                "### System:\n{system}\n",
            )
        if style == PromptStyle.CHAT.value:
            (
                self.turn_format,
                self.turn_no_input_format,
                self.system_format,
            ) = (
                "USER: {instruction}\n{input}\nASSISTANT:",
                "USER: {instruction}\nASSISTANT:",
                "SYSTEM: {system}\n",
            )
        if style == PromptStyle.CHATML.value:
            (
                self.turn_format,
                self.turn_no_input_format,
                self.system_format,
            ) = (
                "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n"
                "<|im_start|>assistant\n",
                "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n",
                "<|im_start|>system\n{system}<|im_end|>\n",
            )


class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
    """
    Tokenizing strategy for rows that use the OpenOrca column names
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
        """
        Map an OpenOrca row onto ``(instruction, input, output, system)``.

        ``question`` supplies the instruction, ``response`` the output and
        ``system_prompt`` the system text; the input slot is always empty.
        """
        question = prompt["question"]
        answer = prompt["response"]
        system_prompt = prompt["system_prompt"]
        return (question, "", answer, system_prompt)


def load(tokenizer, cfg):
    """
    Default loader for this module; delegates to ``load_chat``.
    """
    strategy = load_chat(tokenizer, cfg)
    return strategy


def load_instruct(tokenizer, cfg):
    """
    Build the system-prompt instruction strategy with an instruct-style prompter.

    ``cfg.train_on_inputs`` and ``cfg.sequence_len`` are passed through to the
    strategy constructor.
    """
    prompter = SystemDataPrompter(PromptStyle.INSTRUCT.value)
    return InstructionWSystemPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )


def load_chat(tokenizer, cfg):
    """
    Build the system-prompt instruction strategy with a chat-style prompter.

    ``cfg.train_on_inputs`` and ``cfg.sequence_len`` are passed through to the
    strategy constructor.
    """
    prompter = SystemDataPrompter(PromptStyle.CHAT.value)
    return InstructionWSystemPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )


def load_open_orca(tokenizer, cfg):
    """
    Build the OpenOrca strategy with an instruct-style OpenOrca prompter.

    ``cfg.train_on_inputs`` and ``cfg.sequence_len`` are passed through to the
    strategy constructor.
    """
    prompter = OpenOrcaSystemDataPrompter(PromptStyle.INSTRUCT.value)
    return OpenOrcaPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )


def load_open_orca_chatml(tokenizer, cfg):
    """
    Build the OpenOrca strategy with a chatml-style OpenOrca prompter.

    ``cfg.train_on_inputs`` and ``cfg.sequence_len`` are passed through to the
    strategy constructor.
    """
    prompter = OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value)
    return OpenOrcaPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )