import json
import logging
from typing import Dict, Optional

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def load_json(fn: str) -> Dict:
    with open(fn, "r") as fp:
        return json.load(fp)


class DataHandler:
    """Helper class to handle prompt generation and data tokenization.

    Args:
        tokenizer: The tokenizer to use for tokenization.
        prompt_template (str, optional):
            The path to the JSON file containing the prompt template.
            Defaults to "prompt_templates/medalpaca.json".
        model_max_length (int, optional):
            The maximum length of the tokenized sequence.
            Should not exceed 2048, the context length LLaMA was trained with.
            Defaults to 256.
        train_on_inputs (bool, optional):
            If False, the input portion of the prompt is masked out in the
            labels (set to -100), so the loss is computed on the output only.
            Defaults to True.

    Methods:
        tokenize(prompt: str, add_eos_token: bool = True) -> Dict:
            Tokenizes the given prompt and optionally adds an end-of-sequence (EOS) token.

        generate_and_tokenize_prompt(data_point: Dict) -> Dict:
            Generates a prompt based on the given data point and tokenizes it.

    """

    def __init__(
        self,
        tokenizer,
        prompt_template: str = "prompt_templates/medalpaca.json",
        model_max_length: int = 256,
        train_on_inputs: bool = True,
    ) -> None:
        if model_max_length > 2048:
            logger.warning(
                f"model_max_length of {model_max_length} exceeds the 2048-token "
                "context length LLaMA was trained with."
            )
        self.prompt_template = load_json(prompt_template)
        self.model_max_length = model_max_length
        self.train_on_inputs = train_on_inputs
        self.tokenizer = tokenizer

    def tokenize(
        self,
        prompt: str,
        add_eos_token: bool = True,
        return_tensors: Optional[str] = None,
        truncation: bool = True,
    ) -> Dict[str, list]:
        """
        Tokenize the given prompt and optionally add an end-of-sequence (EOS) token.

        This function tokenizes the input prompt without adding special tokens by default.
        If the `add_eos_token` parameter is True and the tokenized sequence doesn't already
        end with an EOS token, an EOS token will be added to the end of the sequence.

        Args:
            prompt (str): The text to be tokenized.
            add_eos_token (bool, optional): Whether to add an EOS token at the end of
                the tokenized sequence. Defaults to True.
            return_tensors (str, optional): Whether tensors should be returned
                (and of what type, e.g. "pt" for PyTorch). Defaults to None.
            truncation (bool, optional): Whether to truncate the input to
                model_max_length. Defaults to True.

        Returns:
            Dict: A dictionary containing the tokenized data:
                - input_ids: The tokenized input IDs of the prompt.
                - attention_mask: The attention mask for the tokenized input IDs.
                - labels: The labels for the tokenized input IDs (identical to input_ids).
        """
        # TODO: optimize (roll back changes from debugging)
        result: Dict = self.tokenizer(
            prompt,
            truncation=truncation,
            max_length=self.model_max_length,
            padding=False,
            return_tensors=return_tensors,
            add_special_tokens=False,
        )
        # Append the EOS token only when there is still room after truncation.
        # Note: the appends below assume list output (return_tensors=None).
        if (
            add_eos_token
            and 0 < len(result["input_ids"]) < self.model_max_length
            and result["input_ids"][-1] != self.tokenizer.eos_token_id
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result
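
    # Illustration (hypothetical token ids): for a tokenizer with
    # eos_token_id=2, tokenize("Hello world") could return
    #     {"input_ids": [15043, 3186, 2],
    #      "attention_mask": [1, 1, 1],
    #      "labels": [15043, 3186, 2]}
    # The EOS id is appended only when the sequence was not already truncated
    # to model_max_length.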

    def generate_and_tokenize_prompt(self, data_point: Dict) -> Dict:
        """
        Generate a prompt based on the given data point and tokenize it.

        This function creates a prompt using the given data point, which consists
        of an instruction, input, and output. It then tokenizes the generated prompt
        and returns the tokenized representation. If the `train_on_inputs`
        attribute is False, the function additionally tokenizes the user prompt
        (the prompt without the expected output) and masks those input tokens
        in the "labels" field with -100, so the loss is computed only on the
        output.

        Args:
            data_point (Dict): A dictionary containing the following keys:
                - instruction: The instruction text for the prompt.
                - input: The input text for the prompt.
                - output: The output text for the prompt.

        Returns:
            Dict: A dictionary containing the tokenized prompt and associated data:
                - input_ids: The tokenized input IDs of the generated prompt.
                - attention_mask: The attention mask for the tokenized input IDs.
                - labels: The labels to be used during model training, with the output
                part unmasked and the rest masked with -100 if `train_on_inputs` is False.
        """
        prompt: str = self.generate_prompt(
            instruction=data_point.get("instruction", ""),
            input=data_point.get("input", ""),
            output=data_point.get("output", ""),
        )
        tokenized_prompt: Dict = self.tokenize(prompt)
        if not self.train_on_inputs:
            user_prompt: str = self.generate_prompt(
                instruction=data_point.get("instruction", ""), input=data_point.get("input", "")
            )
            tokenized_user_prompt: Dict = self.tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])
            # mask out the inputs
            tokenized_prompt["labels"] = [
                -100 if i < user_prompt_len else label
                for i, label in enumerate(tokenized_prompt["labels"])
            ]
        return tokenized_prompt
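
    # Illustration (hypothetical lengths): if the user prompt tokenizes to 10
    # ids and the full prompt (with EOS) to 14, the returned labels are
    #     [-100] * 10 + tokenized_prompt["input_ids"][10:14]
    # so the loss is computed only on the response tokens.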

    def generate_prompt(
        self,
        instruction: Optional[str] = None,
        input: Optional[str] = None,
        output: Optional[str] = None,
    ) -> str:
        """
        Generates a prompt for the given instruction, input and output using the specified prompt
        template.

        Args:
            instruction (Optional[str]):
                An optional string representing the instruction to be included in the prompt.
            input (Optional[str]):
                An optional string representing the input to be included in the prompt.
            output (Optional[str]):
                An optional string representing the output to be included in the prompt.

        Returns:
            str: The prompt string created using the specified prompt template.

        Raises:
            ValueError: If none of `instruction`, `input`, and `output` is defined.

        ## Example
        Using the template "prompt_templates/medalpaca.json":

        data_handler = DataHandler(tokenizer, "prompt_templates/medalpaca.json")
        prompt = data_handler.generate_prompt(
            instruction = "Provide a short answer to this medical question.",
            input = "What to expect if I have Aortic coarctation  (Outlook/Prognosis)?",
            output = (
                "The prognosis of aortic coarctation depends on whether balloon "
                "angioplasty and stenting or the surgery has been done or not."
            )
        )
        print(prompt)
        >>> Below is an instruction that describes a task, paired with an input that provides
            further context. Write a response that appropriately completes the request.

            ### Instruction:
            Provide a short answer to this medical question.

            ### Input:
            What to expect if I have Aortic coarctation  (Outlook/Prognosis)?

            ### Response:
            The prognosis of aortic coarctation depends on whether balloon angioplasty and
            stenting or the surgery has been done or not.
        """

        if not any([instruction, input, output]):
            raise ValueError("At least one of `instruction`, `input`, `output` should be defined")

        prompt = (
            f'{self.prompt_template["primer"]}'
            f'{self.prompt_template["instruction"]}{instruction or ""}'
            f'{self.prompt_template["input"]}{input or ""}'
            f'{self.prompt_template["output"]}{output or ""}'
        )

        return prompt

    def resolve_output(self, output: str):
        # Placeholder; not yet implemented.
        pass
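

# Minimal usage sketch, kept behind a __main__ guard so importing the module
# has no side effects. The ToyTokenizer below is hypothetical: a deterministic
# whitespace tokenizer with a fake vocabulary, used only so this demo runs
# without downloading model weights. In practice you would pass a real
# tokenizer (e.g. a Hugging Face LLaMA tokenizer). The template dict mirrors
# the keys generate_prompt expects ("primer", "instruction", "input",
# "output").
if __name__ == "__main__":
    import tempfile

    class ToyTokenizer:
        """Deterministic whitespace tokenizer, for demonstration only."""

        eos_token_id = 2

        def __call__(
            self,
            text: str,
            truncation: bool = True,
            max_length: int = 256,
            padding: bool = False,
            return_tensors: Optional[str] = None,
            add_special_tokens: bool = False,
        ) -> Dict[str, list]:
            # Map each whitespace-separated token to a stable fake id.
            ids = [sum(map(ord, tok)) % 30_000 for tok in text.split()]
            if truncation:
                ids = ids[:max_length]
            return {"input_ids": ids, "attention_mask": [1] * len(ids)}

    template = {
        "primer": "Below is an instruction that describes a task.\n\n",
        "instruction": "### Instruction:\n",
        "input": "\n\n### Input:\n",
        "output": "\n\n### Response:\n",
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as fp:
        json.dump(template, fp)
        template_path = fp.name

    handler = DataHandler(ToyTokenizer(), template_path, train_on_inputs=False)
    example = handler.generate_and_tokenize_prompt(
        {"instruction": "Say hello.", "input": "", "output": "Hello!"}
    )
    # Input positions are masked with -100; only the response tokens (and the
    # trailing EOS) keep their label ids.
    print(example["labels"])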