import json
import logging
from typing import Dict, Optional

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def load_json(fn: str):
    """Load and return the contents of a JSON file."""
    with open(fn, "r") as fp:
        return json.load(fp)


class DataHandler:
    """Helper class to handle prompt generation and data tokenization.

    Args:
        tokenizer: The tokenizer to use for tokenization.
        prompt_template (str, optional):
            The path to the JSON file containing the prompt template.
            Defaults to "prompt_templates/medalpaca.json".
        model_max_length (int, optional):
            The maximum length of the tokenized sequence.
            Should not exceed 2048, as LLaMA was trained with a context length
            of 2048. Defaults to 256.
        train_on_inputs (bool, optional):
            If False, the user prompt is masked out (set to -100) in the loss,
            so only the response is trained on. Defaults to True.

    Methods:
        tokenize(prompt: str, add_eos_token: bool = True) -> Dict:
            Tokenizes the given prompt and optionally adds an end-of-sequence
            (EOS) token.

        generate_and_tokenize_prompt(data_point: Dict) -> Dict:
            Generates a prompt based on the given data point and tokenizes it.
    """

    def __init__(
        self,
        tokenizer,
        prompt_template: str = "prompt_templates/medalpaca.json",
        model_max_length: int = 256,
        train_on_inputs: bool = True,
    ) -> None:
        if model_max_length > 2048:
            logger.warning(f"{model_max_length} exceeds the max token length LLaMA was trained with.")
        self.prompt_template = load_json(prompt_template)
        self.model_max_length = model_max_length
        self.train_on_inputs = train_on_inputs
        self.tokenizer = tokenizer

    def tokenize(
        self,
        prompt: str,
        add_eos_token: bool = True,
        return_tensors: Optional[str] = None,
        truncation: bool = True,
    ) -> Dict[str, list]:
        """
        Tokenize the given prompt and optionally add an end-of-sequence (EOS) token.

        This function tokenizes the input prompt without adding special tokens by
        default. If the `add_eos_token` parameter is True and the tokenized sequence
        does not already end with an EOS token, an EOS token is appended to the end
        of the sequence.

        Args:
            prompt (str): The text to be tokenized.
            add_eos_token (bool, optional): Whether to add an EOS token at the end
                of the tokenized sequence. Defaults to True.
            return_tensors (str, optional): If set, the type of tensors to return,
                e.g. "pt" for PyTorch tensors. Defaults to None.
            truncation (bool, optional): Whether to truncate the input to
                `model_max_length`. Defaults to True.

        Returns:
            Dict: A dictionary containing the tokenized data:
                - input_ids: The tokenized input IDs of the prompt.
                - attention_mask: The attention mask for the tokenized input IDs.
                - labels: The labels for the tokenized input IDs (identical to
                  input_ids).
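
        Example (a minimal sketch; assumes `tokenizer` is a Hugging Face LLaMA
        tokenizer, which this module does not construct itself):

            handler = DataHandler(tokenizer)
            result = handler.tokenize("What is hypertension?")
            # result["input_ids"] ends with tokenizer.eos_token_id, and
            # result["labels"] == result["input_ids"]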
        """
        result: Dict = self.tokenizer(
            prompt,
            truncation=truncation,
            max_length=self.model_max_length,
            padding=False,
            return_tensors=return_tensors,
            add_special_tokens=False,
        )
        # Append an EOS token unless one is already present or the sequence
        # was truncated to the maximum length.
        if (
            add_eos_token
            and len(result["input_ids"]) < self.model_max_length
            and result["input_ids"][-1] != self.tokenizer.eos_token_id
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        # For causal language modeling, the labels are a copy of the input ids.
        result["labels"] = result["input_ids"].copy()

        return result

    def generate_and_tokenize_prompt(self, data_point: Dict) -> Dict:
        """
        Generate a prompt based on the given data point and tokenize it.

        This function creates a prompt using the given data point, which consists
        of an instruction, an input, and an output. It then tokenizes the generated
        prompt and returns the tokenized representation. If the `train_on_inputs`
        attribute is False, the function additionally creates a user prompt without
        the expected output, tokenizes only that part, and masks the corresponding
        positions in the "labels" field with -100.

        Args:
            data_point (Dict): A dictionary containing the following keys:
                - instruction: The instruction text for the prompt.
                - input: The input text for the prompt.
                - output: The output text for the prompt.

        Returns:
            Dict: A dictionary containing the tokenized prompt and associated data:
                - input_ids: The tokenized input IDs of the generated prompt.
                - attention_mask: The attention mask for the tokenized input IDs.
                - labels: The labels to be used during model training, with the
                  output part unmasked and the rest masked with -100 if
                  `train_on_inputs` is False.
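
        Example (a minimal sketch; assumes `handler` was constructed with
        `train_on_inputs=False`):

            data_point = {
                "instruction": "Provide a short answer to this medical question.",
                "input": "What to expect if I have Aortic coarctation (Outlook/Prognosis)?",
                "output": "The prognosis depends on the treatment received.",
            }
            tokenized = handler.generate_and_tokenize_prompt(data_point)
            # tokenized["labels"] holds -100 for every token of the user prompt
            # and real token ids only for the response.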
        """
        prompt: str = self.generate_prompt(
            instruction=data_point.get("instruction", ""),
            input=data_point.get("input", ""),
            output=data_point.get("output", ""),
        )
        tokenized_prompt: Dict = self.tokenize(prompt)
        if not self.train_on_inputs:
            # Tokenize the prompt without the expected output to measure how
            # many leading tokens belong to the user prompt.
            user_prompt: str = self.generate_prompt(
                instruction=data_point.get("instruction", ""),
                input=data_point.get("input", ""),
            )
            tokenized_user_prompt: Dict = self.tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            # Mask the user prompt with -100, the index ignored by PyTorch's
            # cross-entropy loss, so only the response contributes to the loss.
            tokenized_prompt["labels"] = [
                -100 if i < user_prompt_len else label
                for i, label in enumerate(tokenized_prompt["labels"])
            ]
        return tokenized_prompt

    def generate_prompt(
        self,
        instruction: Optional[str] = None,
        input: Optional[str] = None,
        output: Optional[str] = None,
    ) -> str:
        """
        Generate a prompt for the given instruction, input, and output using the
        specified prompt template.

        Args:
            instruction (Optional[str]):
                An optional string representing the instruction to be included in
                the prompt.
            input (Optional[str]):
                An optional string representing the input to be included in the
                prompt.
            output (Optional[str]):
                An optional string representing the output to be included in the
                prompt.

        Returns:
            str: The prompt string created using the specified prompt template.

        Raises:
            ValueError: If none of `instruction`, `input`, or `output` is defined.

        ## Example
        Using the template "prompt_templates/medalpaca.json", which defines the
        keys "primer", "instruction", "input", and "output":

        data_handler = DataHandler(tokenizer, "prompt_templates/medalpaca.json")
        prompt = data_handler.generate_prompt(
            instruction="Provide a short answer to this medical question.",
            input="What to expect if I have Aortic coarctation (Outlook/Prognosis)?",
            output=(
                "The prognosis of aortic coarctation depends on whether balloon "
                "angioplasty and stenting or the surgery has been done or not."
            ),
        )
        print(prompt)
        >>> Below is an instruction that describes a task, paired with an input that
        provides further context. Write a response that appropriately completes the
        request.

        ### Instruction:
        Provide a short answer to this medical question.

        ### Input:
        What to expect if I have Aortic coarctation (Outlook/Prognosis)?

        ### Response:
        The prognosis of aortic coarctation depends on whether balloon angioplasty
        and stenting or the surgery has been done or not.
        """

        if not any([instruction, input, output]):
            raise ValueError("At least one of `instruction`, `input`, `output` should be defined")

        prompt = (
            f'{self.prompt_template["primer"]}'
            f'{self.prompt_template["instruction"]}{instruction or ""}'
            f'{self.prompt_template["input"]}{input or ""}'
            f'{self.prompt_template["output"]}{output or ""}'
        )

        return prompt

    def resolve_output(self, output: str):
        """Placeholder; not implemented yet."""
        pass
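

# Usage sketch (a minimal example, not part of the library API; the tokenizer
# checkpoint path below is a hypothetical placeholder, and the template file is
# assumed to exist relative to the working directory).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/llama-tokenizer")  # hypothetical path
    handler = DataHandler(tokenizer, prompt_template="prompt_templates/medalpaca.json")

    example = {
        "instruction": "Provide a short answer to this medical question.",
        "input": "What to expect if I have Aortic coarctation (Outlook/Prognosis)?",
        "output": (
            "The prognosis of aortic coarctation depends on whether balloon "
            "angioplasty and stenting or the surgery has been done or not."
        ),
    }
    tokenized = handler.generate_and_tokenize_prompt(example)
    print(len(tokenized["input_ids"]), len(tokenized["labels"]))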