import json
import logging
from typing import Dict, Optional

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def load_json(fn: str):
    """Load and parse a JSON file (e.g. a prompt template) into a dict."""
    with open(fn, "r") as fp:
        d = json.load(fp)
    return d


class DataHandler:
    """Helper class to handle prompt generation and data tokenization.

    Args:
        tokenizer: The tokenizer to use for tokenization.
        prompt_template (str, optional):
            The path to the JSON file containing the prompt template.
            Defaults to "prompt_templates/medalpaca.json".
        model_max_length (int, optional):
            The maximum length of the tokenized sequence.
            Should not exceed 2048, as LLaMA was trained with this context
            length. Defaults to 256.
        train_on_inputs (bool, optional):
            If False, masks out the inputs in the loss. Defaults to True.

    Methods:
        tokenize(prompt: str, add_eos_token: bool = True) -> Dict:
            Tokenizes the given prompt and optionally adds an end-of-sequence
            (EOS) token.
        generate_and_tokenize_prompt(data_point: Dict) -> Dict:
            Generates a prompt based on the given data point and tokenizes it.
    """
    def __init__(
        self,
        tokenizer,
        prompt_template: str = "prompt_templates/medalpaca.json",
        model_max_length: int = 256,
        train_on_inputs: bool = True,
    ) -> None:
        if model_max_length > 2048:
            logger.warning(
                f"{model_max_length} exceeds the max token length LLaMA was trained with."
            )
        self.prompt_template = load_json(prompt_template)
        self.model_max_length = model_max_length
        self.train_on_inputs = train_on_inputs
        self.tokenizer = tokenizer
    def tokenize(
        self,
        prompt: str,
        add_eos_token: bool = True,
        return_tensors: str = None,
        truncation: bool = True,
    ) -> Dict[str, list]:
        """
        Tokenize the given prompt and optionally add an end-of-sequence (EOS) token.

        This function tokenizes the input prompt without adding special tokens by default.
        If the `add_eos_token` parameter is True and the tokenized sequence doesn't already
        end with an EOS token, an EOS token will be added to the end of the sequence.

        Args:
            prompt (str): The text to be tokenized.
            add_eos_token (bool, optional): Whether to add an EOS token at the end of
                the tokenized sequence. Defaults to True.
            return_tensors (str, optional): If tensors should be returned (and what type).
            truncation (bool, optional): Whether to truncate the input to model_max_length.

        Returns:
            Dict: A dictionary containing the tokenized data:
                - input_ids: The tokenized input IDs of the prompt.
                - attention_mask: The attention mask for the tokenized input IDs.
                - labels: The labels for the tokenized input IDs (identical to input_ids).
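
        Example (illustrative; assumes `self.tokenizer` is a Hugging Face
        LLaMA tokenizer, so the actual IDs are placeholders here):
            >>> handler.tokenize("Hello world")
            {'input_ids': [...], 'attention_mask': [...], 'labels': [...]}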
"""
        # TODO: optimize (roll back changes from debugging)
        result: Dict = self.tokenizer(
            prompt,
            truncation=truncation,
            max_length=self.model_max_length,
            padding=False,
            return_tensors=return_tensors,
            add_special_tokens=False,
        )
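        # Append an EOS token only when all three conditions hold: the sequence
        # does not already end in EOS, there is still room within
        # model_max_length, and the caller asked for it. The list indexing and
        # .append() below assume list output, i.e. return_tensors=None.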
        if (
            result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < self.model_max_length
            and add_eos_token
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()
        return result
    def generate_and_tokenize_prompt(self, data_point: Dict):
        """
        Generate a prompt based on the given data point and tokenize it.

        This function creates a prompt using the given data point, which consists
        of an instruction, input, and output. It then tokenizes the generated prompt
        and returns the tokenized representation. If the `train_on_inputs` attribute
        is False, the function will create a user prompt without the expected output,
        tokenize that part separately, and mask everything up to the output in the
        "labels" field with -100.

        Args:
            data_point (Dict): A dictionary containing the following keys:
                - instruction: The instruction text for the prompt.
                - input: The input text for the prompt.
                - output: The output text for the prompt.

        Returns:
            Dict: A dictionary containing the tokenized prompt and associated data:
                - input_ids: The tokenized input IDs of the generated prompt.
                - attention_mask: The attention mask for the tokenized input IDs.
                - labels: The labels to be used during model training, with the output
                    part unmasked and the rest masked with -100 if `train_on_inputs` is False.
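
        Example (illustrative; with `train_on_inputs=False` the leading labels
        are masked):
            >>> out = handler.generate_and_tokenize_prompt(
            ...     {"instruction": "...", "input": "...", "output": "..."}
            ... )
            >>> out["labels"][:3]
            [-100, -100, -100]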
"""
        prompt: str = self.generate_prompt(
            instruction=data_point.get("instruction", ""),
            input=data_point.get("input", ""),
            output=data_point.get("output", ""),
        )
        tokenized_prompt: Dict = self.tokenize(prompt)

        if not self.train_on_inputs:
            # Tokenize the prompt again without the expected output (and without
            # EOS) to measure how many leading tokens belong to the user prompt.
            user_prompt: str = self.generate_prompt(
                instruction=data_point.get("instruction", ""), input=data_point.get("input", "")
            )
            tokenized_user_prompt: Dict = self.tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])
            # Mask out the input tokens with -100, which the loss ignores, so
            # training only optimizes the response tokens.
            tokenized_prompt["labels"] = [
                -100 if i < user_prompt_len else label
                for i, label in enumerate(tokenized_prompt["labels"])
            ]
        return tokenized_prompt
    def generate_prompt(
        self,
        instruction: Optional[str] = None,
        input: Optional[str] = None,
        output: Optional[str] = None,
    ) -> str:
        """
        Generate a prompt for the given instruction, input and output using the
        specified prompt template.

        Args:
            instruction (Optional[str]):
                An optional string representing the instruction to be included in the prompt.
            input (Optional[str]):
                An optional string representing the input to be included in the prompt.
            output (Optional[str]):
                An optional string representing the output to be included in the prompt.

        Returns:
            str: The prompt string created using the specified prompt template.

        Raises:
            ValueError: If none of `instruction`, `input`, and `output` is defined.
        ## Example
        Using the template in "prompt_templates/medalpaca.json":

            data_handler = DataHandler(tokenizer, "prompt_templates/medalpaca.json")
            prompt = data_handler.generate_prompt(
                instruction="Provide a short answer to this medical question.",
                input="What to expect if I have Aortic coarctation (Outlook/Prognosis)?",
                output=(
                    "The prognosis of aortic coarctation depends on whether balloon "
                    "angioplasty and stenting or the surgery has been done or not."
                ),
            )
            print(prompt)
            >>> Below is an instruction that describes a task, paired with an input that provides
            further context. Write a response that appropriately completes the request.

            ### Instruction:
            Provide a short answer to this medical question.

            ### Input:
            What to expect if I have Aortic coarctation (Outlook/Prognosis)?

            ### Response:
            The prognosis of aortic coarctation depends on whether balloon angioplasty and
            stenting or the surgery has been done or not.
        """
        if not any([instruction, input, output]):
            raise ValueError("At least one of `instruction`, `input`, `output` should be defined")
        prompt = (
            f'{self.prompt_template["primer"]}'
            f'{self.prompt_template["instruction"]}{instruction or ""}'
            f'{self.prompt_template["input"]}{input or ""}'
            f'{self.prompt_template["output"]}{output or ""}'
        )
        return prompt
    def resolve_output(self, output: str):
        # Placeholder; not implemented yet.
        pass
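

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes a
    # LLaMA tokenizer is available through Hugging Face transformers and that
    # "prompt_templates/medalpaca.json" provides the "primer", "instruction",
    # "input", and "output" keys used by generate_prompt above.
    from transformers import LlamaTokenizer

    tokenizer = LlamaTokenizer.from_pretrained("path/to/llama-tokenizer")  # placeholder path
    data_handler = DataHandler(
        tokenizer,
        prompt_template="prompt_templates/medalpaca.json",
        model_max_length=256,
        train_on_inputs=False,
    )
    example = {
        "instruction": "Provide a short answer to this medical question.",
        "input": "What to expect if I have Aortic coarctation (Outlook/Prognosis)?",
        "output": (
            "The prognosis of aortic coarctation depends on whether balloon "
            "angioplasty and stenting or the surgery has been done or not."
        ),
    }
    tokenized = data_handler.generate_and_tokenize_prompt(example)
    # With train_on_inputs=False, labels for the user prompt are -100 and only
    # the response tokens contribute to the training loss.
    print(tokenized["labels"][:10])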