""" |
|
Prompt Strategy for finetuning Llama2 chat models |
|
see also https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L213 for ma reference implementation. |
|
|
|
This implementation is based on the Vicuna PR and the fastchat repo, see also: |
|
https://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847 |
|
|
|
Use dataset type: "llama2_chat" in conig.yml to use this prompt style. |
|
|
|
E.g. in the config.yml: |
|
``` |
|
datasets: |
|
- path: llama_finetune_train.jsonl |
|
type: llama2_chat |
|
``` |
|
|
|
The dataset itself should look like this: |
|
``` |
|
{'conversations':[{"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "I am Vicuna"},...]} |
|
``` |
|
in a jsonl file. The first message should be from the human, the second from gpt. |
|
For a custom system message, the first "from" can be "system" (followed by alternating "human" and "gpt" turns). |
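For example, a record with a custom system message might look like this (illustrative values):
```
{'conversations': [{"from": "system", "value": "You are a pirate."}, {"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "Arr, I be Vicuna."}]}
```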
Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing!
"""

import logging
from dataclasses import dataclass, field
from typing import Generator, List, Sequence

from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE


@dataclass
class Llama2ChatConversation:
    """A class that manages prompt templates and keeps all conversation history.

    Copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
    """

    name: str = "llama2"
    system: str = (
        "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
    )
    roles: Sequence[str] = ("[INST]", "[/INST]")
    messages: List[List[str]] = field(default_factory=list)
    offset: int = 0
    sep = " "
    sep2 = " </s><s>"
    stop_token_ids = [2]

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        seps = [self.sep, self.sep2]
        ret = ""
        for i, (role, message) in enumerate(self.messages):
            if (i == len(self.messages) - 1) and (role == self.roles[0]):
                # The conversation ends with an unanswered human turn; return
                # the prompt without it, since there is no reply to learn from.
                return ret
            if i == 0:
                ret += self.system + message.strip()
            else:
                ret += role + " " + message.strip() + seps[i % 2]
        return ret

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

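# Illustrative example (a sketch, not executed): with the default system prompt,
#
#     conv = Llama2ChatConversation()
#     conv.append_message(conv.roles[0], "Who are you?")
#     conv.append_message(conv.roles[1], "I am Vicuna")
#     conv.get_prompt()
#
# returns:
#     "[INST] <<SYS>>\n...default system prompt...\n<</SYS>>\n\nWho are you?[/INST] I am Vicuna </s><s>"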


class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for Llama2 chat prompts in ShareGPT conversation format.
    Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Ensure a pad token is set; fall back to "<pad>" when the tokenizer has none.
        self.tokenizer.add_special_tokens(
            {"pad_token": getattr(self.tokenizer, "pad_token", None) or "<pad>"}
        )

    def tokenize_prompt(self, prompt):
        conv = next(self.prompter.build_prompt(prompt))
        conversation_str = conv.get_prompt()

        # Tokenize the whole conversation, padded/truncated to sequence_len.
        input_ids = self.tokenizer(
            conversation_str,
            return_tensors="pt",
            padding="max_length",
            max_length=self.sequence_len,
            truncation=True,
        ).input_ids[0]
        target = input_ids.clone()

        # Mask the targets so that loss is only computed on the assistant replies.
        sep = conv.roles[1]

        total_len = int(target.ne(self.tokenizer.pad_token_id).sum())

        turns = conversation_str.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for turn in turns:
            if turn == "":
                break
            turn_len = len(self.tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # "-1" compensates for the BOS token the LLaMA tokenizer prepends.
            instruction_len = len(self.tokenizer(parts[0]).input_ids) - 1

            # Ignore the user-instruction part of the turn.
            target[cur_len - 1 : cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len + 2  # advance past this turn and the " </s><s>" separator

        target[cur_len:] = IGNORE_TOKEN_ID

        if cur_len < self.sequence_len:
            if cur_len != total_len:
                # Masking got out of sync with the tokenization; mask the whole
                # sample rather than training on misaligned labels.
                target[:] = IGNORE_TOKEN_ID
                logging.warning(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).tolist()
        input_ids = input_ids.tolist()
        target = target.tolist()

        # Normalize how "[" is tokenized away from the sequence boundaries:
        # map id 29961 ("[") to id 518 ("▁[") so "[INST]" is encoded consistently.
        for i in range(2, total_len - 2):
            if input_ids[i] == 29961:
                input_ids[i] = 518
            if target[i] == 29961:
                target[i] = 518
        return {
            "input_ids": input_ids,
            "labels": target,
            "attention_mask": attention_mask,
        }


class Llama2ChatPrompter:
    """
    A prompter that generates prompts for Llama2 models.
    """

    system_prompt = (
        "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
    )

    def build_prompt(self, source) -> Generator[Llama2ChatConversation, None, None]:
        source = source["conversations"]

        # If the first message carries a custom system prompt, use it instead of the default.
        if source[0]["from"] == "system":
            system = f"[INST] <<SYS>>\n{source[0]['value']}\n<</SYS>>\n\n"
            source = source[1:]
        else:
            system = self.system_prompt

        conv = Llama2ChatConversation(system=system)

        if len(source) < 2:
            # A conversation needs at least one human/gpt exchange; bail out otherwise.
            raise IndexError

        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first message if it is not from the human.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
            if sentence["value"]:
                conv.append_message(role, sentence["value"])
        yield conv


def load(tokenizer, cfg) -> LLama2ChatTokenizingStrategy:
    return LLama2ChatTokenizingStrategy(
        Llama2ChatPrompter(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )

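
# Illustrative usage sketch (not part of the module). The tokenizer checkpoint and
# the cfg object below are assumptions for the example, not requirements:
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
#     strategy = load(tokenizer, cfg)  # cfg must provide train_on_inputs and sequence_len
#     sample = {
#         "conversations": [
#             {"from": "human", "value": "Who are you?"},
#             {"from": "gpt", "value": "I am Vicuna"},
#         ]
#     }
#     tokenized = strategy.tokenize_prompt(sample)
#     # tokenized has "input_ids", "labels" and "attention_mask", each cfg.sequence_len long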