add support for alpaca reflect training (#2)
- configs/vicuna_13B_4bit_reflect.yml +45 -0
- scripts/finetune.py +12 -2
- src/axolotl/prompt_tokenizers.py +61 -0
- src/axolotl/prompters.py +29 -0
configs/vicuna_13B_4bit_reflect.yml
ADDED
@@ -0,0 +1,45 @@
+base_model: anon8231489123/vicuna-13b-GPTQ-4bit-128g
+base_model_config: anon8231489123/vicuna-13b-GPTQ-4bit-128g
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+load_in_8bit: false
+load_4bit: true
+gptq_groupsize: 128
+gptq_model_v1: false
+datasets:
+  # https://github.com/vaguenebula/AlpacaDataReflect/blob/main/alpaca_reflect_pruned.json
+  - path: data/alpaca_reflect_pruned.jsonl
+    type: reflection
+dataset_prepared_path: data/last_run_prepared
+val_set_size: 0.04
+adapter: lora
+lora_model_dir:
+sequence_len: 2048
+max_packed_sequence_len: 2048
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - v_proj
+  # - k_proj
+  # - o_proj
+lora_fan_in_fan_out: false
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model: checkpoint
+output_dir: ./lora-reflect
+batch_size: 8
+micro_batch_size: 2
+num_epochs: 3
+learning_rate: 0.00003
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+gradient_checkpointing: false
+early_stopping_patience: 3
+resume_from_checkpoint:
+local_rank:
+flash_attention: true
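For context, each record in the reflection dataset is expected to carry the five fields that `AlpacaReflectionPTStrategy.parse_instruction_fields` (added below in src/axolotl/prompt_tokenizers.py) reads, with `input` optional. A hypothetical record for illustration; only the keys come from the code, the values are invented:

```python
# Hypothetical record shape for data/alpaca_reflect_pruned.jsonl.
# Keys taken from parse_instruction_fields below; values are placeholders.
record = {
    "instruction": "Name the largest planet in the solar system.",
    "input": "",  # optional; treated as "" when absent
    "output": "Saturn is the largest planet.",  # first-pass answer
    "reflection": "Saturn is large, but Jupiter has a greater mass and radius.",
    "corrected": "Jupiter is the largest planet in the solar system.",
}
```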
scripts/finetune.py
CHANGED
@@ -37,9 +37,9 @@ from axolotl.prompt_tokenizers import (
     ShareGPTPromptTokenizingStrategy,
     LLAMA_DEFAULT_PAD_TOKEN,
     GPTeacherPromptTokenizingStrategy,
-    OpenAssistantPromptTokenizingStrategy,
+    OpenAssistantPromptTokenizingStrategy, AlpacaReflectionPTStrategy,
 )
-from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
+from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter, ReflectAlpacaPrompter
 
 logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
 DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
@@ -395,6 +395,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
 
+    # TODO on_save callback to sync checkpoints to GCP/AWS in background
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(
             cfg.early_stopping_patience,
@@ -540,6 +541,15 @@ def train(
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
+        elif d.type == "reflection":
+            ds_strategy = AlpacaReflectionPTStrategy(
+                ReflectAlpacaPrompter(),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
+            )
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            datasets.append(ds_wrapper)
         elif d.type == "sharegpt":
             ds_strategy = ShareGPTPromptTokenizingStrategy(
                 ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
src/axolotl/prompt_tokenizers.py
CHANGED
@@ -107,6 +107,67 @@ class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     )
 
 
+class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
+        raise NotImplementedError
+
+    def tokenize_prompt(self, prompt):
+        instruction, input, output, reflection, corrected = self.parse_instruction_fields(prompt)
+        full_prompt = self._build_full_prompt(instruction, input, output, reflection, corrected)
+        tokenized_full_prompt = self._tokenize(full_prompt)
+        if not self.train_on_inputs:
+            user_prompt = self.prompter.build_prompt(
+                instruction,
+                input,
+            )
+            tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
+            user_prompt_len = len(tokenized_user_prompt["input_ids"])
+            # TODO this could be sped up using numpy array slicing
+            tokenized_full_prompt["labels"] = [
+                -100
+            ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
+
+        return tokenized_full_prompt
+
+    def _build_full_prompt(self, instruction, input, output, reflection, corrected):
+        return self.prompter.build_prompt(
+            instruction,
+            input,
+            output,
+            reflection,
+            corrected,
+        )
+
+    def _tokenize(self, prompt, add_eos_token=True):
+        result = self.tokenizer(
+            prompt,
+            truncation=True,
+            max_length=self.sequence_len,
+            padding=False,
+            return_tensors=None,
+        )
+        if (
+            result["input_ids"][-1] != self.tokenizer.eos_token_id
+            and len(result["input_ids"]) < self.sequence_len
+            and add_eos_token
+        ):
+            result["input_ids"].append(self.tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+
+        result["labels"] = result["input_ids"].copy()
+        return result
+
+
+class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
+        return (
+            prompt["instruction"],
+            prompt["input"] if "input" in prompt else "",
+            prompt["output"],
+            prompt["reflection"],
+            prompt["corrected"],
+        )
+
 class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
     def tokenize_prompt(self, prompt):
         try:
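The label masking in `tokenize_prompt` follows the usual causal-LM convention: positions set to -100 are ignored by the cross-entropy loss, so when `train_on_inputs` is false the model is only supervised on the output, reflection, and corrected spans. A toy illustration of the list arithmetic, with made-up token ids:

```python
# Toy numbers: suppose the full prompt tokenized to 8 ids, the user prompt to 5.
input_ids = [10, 11, 12, 13, 14, 20, 21, 22]
user_prompt_len = 5

labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
assert labels == [-100, -100, -100, -100, -100, 20, 21, 22]
# The first five positions are ignored by the loss; only the tokens after
# the "### Response:" marker contribute to training.
```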
src/axolotl/prompters.py
CHANGED
@@ -35,6 +35,35 @@ class GPTeacherPrompter(AlpacaPrompter):
     ...
 
 
+class ReflectAlpacaPrompter:
+    prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+    prompt_no_input = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Response:\n"
+    agent_label = "{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
+    response_split = "### Response:"
+
+    def build_prompt(
+        self,
+        instruction: str,
+        input: Union[None, str] = None,
+        output: Union[None, str] = None,
+        reflection: Union[None, str] = None,
+        corrected: Union[None, str] = None,
+    ) -> str:
+        # returns the full prompt from instruction and optional input
+        # if a label (=response, =output) is provided, it's also appended.
+        if input:
+            res = self.prompt_input.format(instruction=instruction, input=input)
+        else:
+            res = self.prompt_no_input.format(instruction=instruction)
+        if output and reflection and corrected:
+            label = self.agent_label.format(output=output, reflection=reflection, corrected=corrected)
+            res = f"{res}{label}"
+        return res
+
+    def get_response(self, output: str) -> str:
+        return output.split(self.response_split)[1].strip()
+
+
 class SeparatorStyle(Enum):
     """Different separator style."""
 
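To make the template concrete, here is a sketch of what `build_prompt` assembles when all five fields are present (the strings are placeholders; the section markers come from `agent_label` above):

```python
from axolotl.prompters import ReflectAlpacaPrompter

prompter = ReflectAlpacaPrompter()
full = prompter.build_prompt(
    instruction="Name the largest planet in the solar system.",
    output="Saturn is the largest planet.",
    reflection="Saturn is large, but Jupiter has a greater mass and radius.",
    corrected="Jupiter is the largest planet in the solar system.",
)
# `full` ends with:
#   ### Response:
#   Saturn is the largest planet.
#
#   ### Agent Reflection:
#   Saturn is large, but Jupiter has a greater mass and radius.
#
#   ### Final Response:
#   Jupiter is the largest planet in the solar system.
```

Note that `get_response` splits on the first "### Response:" marker, so at inference time the returned text still contains the Agent Reflection and Final Response sections.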