winglian committed on
Commit
8d959a7
1 Parent(s): ce24f5e

make it work with pythia in the cloud

Browse files
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ data/*.jsonl filter=lfs diff=lfs merge=lfs -text
configs/pythia_1_2B_alpaca.yml CHANGED
@@ -3,35 +3,36 @@ model_type: GPTNeoXForCausalLM
3
  tokenizer_type: AutoTokenizer
4
  load_in_8bit: true
5
  datasets:
6
- - path: ./data/alpaca_data_gpt4.jsonl
7
  type: alpaca
8
- - path: ./data/vicuna_cleaned.jsonl
9
  type: sharegpt
10
- - path: ./data/gpt4-instruct-similarity-0.6-dataset.jsonl
11
  type: gpteacher
12
- - path: ./data/roleplay-similarity_0.6-instruct-dataset.jsonl
13
  type: gpteacher
14
  val_set_size: 0.05
15
  adapter: lora
16
  sequence_len: 2048
17
- lora_r: 16
18
  lora_alpha: 32
19
  lora_dropout: 0.05
20
  lora_target_modules:
21
- - q_proj
22
- - v_proj
23
- wandb_project:
24
  wandb_watch:
25
- wandb:run_name:
26
  wandb_log_model: checkpoint
27
  output_dir: ./lora-alpaca
28
- batch_size: 128
29
- micro_batch_size: 8
30
  num_epochs: 5
31
  learning_rate: 0.0003
32
  train_on_inputs: false
 
33
  bf16: True
34
- fp16: True
35
  resume_from_checkpoint:
36
  local_rank:
37
  deepspeed:
 
3
  tokenizer_type: AutoTokenizer
4
  load_in_8bit: true
5
  datasets:
6
+ - path: data/alpaca_data_gpt4.jsonl
7
  type: alpaca
8
+ - path: data/vicuna_cleaned.jsonl
9
  type: sharegpt
10
+ - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
11
  type: gpteacher
12
+ - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
13
  type: gpteacher
14
  val_set_size: 0.05
15
  adapter: lora
16
  sequence_len: 2048
17
+ lora_r: 8
18
  lora_alpha: 32
19
  lora_dropout: 0.05
20
  lora_target_modules:
21
+ - query_key_value
22
+ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
23
+ wandb_project: pythia-1.4b-lora
24
  wandb_watch:
25
+ wandb_run_name:
26
  wandb_log_model: checkpoint
27
  output_dir: ./lora-alpaca
28
+ batch_size: 32
29
+ micro_batch_size: 4
30
  num_epochs: 5
31
  learning_rate: 0.0003
32
  train_on_inputs: false
33
+ group_by_length: false
34
  bf16: True
35
+ tf32: True
36
  resume_from_checkpoint:
37
  local_rank:
38
  deepspeed:
scripts/finetune.py CHANGED
@@ -1,26 +1,32 @@
 
1
  import os
 
2
  import sys
3
  from pathlib import Path
4
 
 
5
  import fire
6
  import torch
7
  import transformers
8
  import yaml
9
  from attrdict import AttrDict
10
- from datasets import load_dataset, IterableDataset
11
  from peft import (
12
  LoraConfig,
13
  get_peft_model,
14
- prepare_model_for_int8_training,
15
  )
 
16
  from transformers import AutoModelForCausalLM, AutoTokenizer
17
 
18
  # add src to the pythonpath so we don't need to pip install this
 
 
19
  project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
20
  src_dir = os.path.join(project_root, 'src')
21
  sys.path.insert(0, src_dir)
22
 
23
- from axolotl.datasets import TokenizedPromptDataset
24
  from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, \
25
  LLAMA_DEFAULT_PAD_TOKEN, GPTeacherPromptTokenizingStrategy
26
  from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
@@ -29,9 +35,9 @@ def setup_wandb_env_vars(cfg):
29
  if len(cfg.wandb_project) > 0:
30
  os.environ["WANDB_PROJECT"] = cfg.wandb_project
31
  cfg.use_wandb = True
32
- if len(cfg.wandb_watch) > 0:
33
  os.environ["WANDB_WATCH"] = cfg.wandb_watch
34
- if len(cfg.wandb_log_model) > 0:
35
  os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
36
 
37
 
@@ -61,6 +67,10 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
61
  if tokenizer.__class__.__name__ == "LlamaTokenizer":
62
  tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
63
 
 
 
 
 
64
  if cfg.load_in_8bit:
65
  model = prepare_model_for_int8_training(model)
66
 
@@ -69,6 +79,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
69
  lora_alpha=cfg.lora_alpha,
70
  target_modules=cfg.lora_target_modules,
71
  lora_dropout=cfg.lora_dropout,
 
72
  bias="none",
73
  task_type="CAUSAL_LM",
74
  )
@@ -79,7 +90,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
79
  # TODO resume_from_checkpoint handling
80
 
81
  model.print_trainable_parameters()
82
- return model, tokenizer
83
 
84
 
85
  def train(
@@ -88,7 +99,7 @@ def train(
88
  ):
89
  # load the config from the yaml file
90
  with open(config, 'r') as f:
91
- cfg: AttrDict = AttrDict(yaml.load(f))
92
  # if there are any options passed in the cli, if it is something that seems valid from the yaml,
93
  # then overwrite the value
94
  for k, v in enumerate(kwargs):
@@ -107,23 +118,116 @@ def train(
107
  setup_wandb_env_vars(cfg)
108
 
109
  # Load the model and tokenizer
110
- model, tokenizer = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
111
  datasets = []
112
  for d in cfg.datasets:
113
- ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, num_proc=4, split=None)
114
  if d.type == "alpaca":
115
  ds_strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
116
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
117
  datasets.append(ds_wrapper)
118
  elif d.type == "gpteacher":
119
  ds_strategy = GPTeacherPromptTokenizingStrategy(GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
120
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
121
  datasets.append(ds_wrapper)
122
  elif d.type == "sharegpt":
123
  ds_strategy = ShareGPTPromptTokenizingStrategy(ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
124
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
125
  datasets.append(ds_wrapper)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
 
127
 
128
  if __name__ == "__main__":
129
  fire.Fire(train)
 
1
+ import math
2
  import os
3
+ import signal
4
  import sys
5
  from pathlib import Path
6
 
7
+ import bitsandbytes as bnb
8
  import fire
9
  import torch
10
  import transformers
11
  import yaml
12
  from attrdict import AttrDict
13
+ from datasets import load_dataset, IterableDataset, Dataset
14
  from peft import (
15
  LoraConfig,
16
  get_peft_model,
17
+ prepare_model_for_int8_training, get_peft_model_state_dict,
18
  )
19
+ from torch import nn
20
  from transformers import AutoModelForCausalLM, AutoTokenizer
21
 
22
  # add src to the pythonpath so we don't need to pip install this
23
+ from transformers.trainer_pt_utils import get_parameter_names
24
+
25
  project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
26
  src_dir = os.path.join(project_root, 'src')
27
  sys.path.insert(0, src_dir)
28
 
29
+ from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
30
  from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, \
31
  LLAMA_DEFAULT_PAD_TOKEN, GPTeacherPromptTokenizingStrategy
32
  from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
 
35
  if len(cfg.wandb_project) > 0:
36
  os.environ["WANDB_PROJECT"] = cfg.wandb_project
37
  cfg.use_wandb = True
38
+ if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
39
  os.environ["WANDB_WATCH"] = cfg.wandb_watch
40
+ if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
41
  os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
42
 
43
 
 
67
  if tokenizer.__class__.__name__ == "LlamaTokenizer":
68
  tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
69
 
70
+ if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
71
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
72
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
73
+
74
  if cfg.load_in_8bit:
75
  model = prepare_model_for_int8_training(model)
76
 
 
79
  lora_alpha=cfg.lora_alpha,
80
  target_modules=cfg.lora_target_modules,
81
  lora_dropout=cfg.lora_dropout,
82
+ fan_in_fan_out=cfg.lora_fan_in_fan_out,
83
  bias="none",
84
  task_type="CAUSAL_LM",
85
  )
 
90
  # TODO resume_from_checkpoint handling
91
 
92
  model.print_trainable_parameters()
93
+ return model, tokenizer, lora_config
94
 
95
 
96
  def train(
 
99
  ):
100
  # load the config from the yaml file
101
  with open(config, 'r') as f:
102
+ cfg: AttrDict = AttrDict(yaml.load(f, Loader=yaml.Loader))
103
  # if there are any options passed in the cli, if it is something that seems valid from the yaml,
104
  # then overwrite the value
105
  for k, v in enumerate(kwargs):
 
118
  setup_wandb_env_vars(cfg)
119
 
120
  # Load the model and tokenizer
121
+ model, tokenizer, lora_config = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
122
  datasets = []
123
  for d in cfg.datasets:
124
+ ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, split=None)
125
  if d.type == "alpaca":
126
  ds_strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
127
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
128
  datasets.append(ds_wrapper)
129
  elif d.type == "gpteacher":
130
  ds_strategy = GPTeacherPromptTokenizingStrategy(GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
131
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
132
  datasets.append(ds_wrapper)
133
  elif d.type == "sharegpt":
134
  ds_strategy = ShareGPTPromptTokenizingStrategy(ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
135
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
136
  datasets.append(ds_wrapper)
137
+ constant_len_dataset = ConstantLengthDataset(tokenizer, datasets, seq_length=cfg.sequence_len)
138
+ constant_len_dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
139
+ test_size=cfg.val_set_size, shuffle=True, seed=42
140
+ )
141
+
142
+ print(constant_len_dataset)
143
+ train_dataset = constant_len_dataset["train"]
144
+ eval_dataset = constant_len_dataset["test"]
145
+
146
+ total_num_steps = int(math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size))
147
+ warmup_steps = min(int(0.03 * total_num_steps), 100)
148
+ logging_steps = min(int(0.005 * total_num_steps), 10)
149
+ save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
150
+
151
+ training_args = transformers.TrainingArguments(
152
+ per_device_train_batch_size=cfg.micro_batch_size,
153
+ gradient_accumulation_steps=cfg.gradient_accumulation_steps,
154
+ warmup_steps=warmup_steps,
155
+ num_train_epochs=cfg.num_epochs,
156
+ learning_rate=cfg.learning_rate,
157
+ bf16=cfg.bf16,
158
+ tf32=cfg.tf32,
159
+ logging_steps=logging_steps,
160
+ evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
161
+ save_strategy="steps",
162
+ eval_steps=eval_steps if cfg.val_set_size > 0 else None,
163
+ save_steps=save_steps,
164
+ output_dir=cfg.output_dir,
165
+ save_total_limit=3,
166
+ load_best_model_at_end=True if cfg.val_set_size > 0 else False,
167
+ ddp_find_unused_parameters=False if cfg.ddp else None,
168
+ group_by_length=cfg.group_by_length,
169
+ report_to="wandb" if cfg.use_wandb else None,
170
+ run_name=cfg.wandb_run_name if cfg.use_wandb else None,
171
+ )
172
+
173
+ decay_parameters = get_parameter_names(model, [nn.LayerNorm])
174
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
175
+ optimizer_grouped_parameters = [
176
+ {
177
+ "params": [p for n, p in model.named_parameters() if n in decay_parameters],
178
+ "weight_decay": training_args.weight_decay,
179
+ },
180
+ {
181
+ "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
182
+ "weight_decay": 0.0,
183
+ },
184
+ ]
185
+
186
+ adam_bnb_optim = bnb.optim.Adam8bit(
187
+ optimizer_grouped_parameters,
188
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
189
+ eps=training_args.adam_epsilon,
190
+ lr=training_args.learning_rate,
191
+ )
192
+
193
+ lr_scheduler = transformers.get_cosine_schedule_with_warmup(
194
+ adam_bnb_optim,
195
+ training_args.warmup_steps,
196
+ total_num_steps,
197
+ )
198
+
199
+ trainer = transformers.Trainer(
200
+ model=model,
201
+ train_dataset=train_dataset,
202
+ eval_dataset=eval_dataset,
203
+ args=training_args,
204
+ optimizers=(adam_bnb_optim, lr_scheduler),
205
+ data_collator=transformers.DataCollatorForSeq2Seq(
206
+ tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
207
+ ),
208
+ )
209
+ model.config.use_cache = False
210
+
211
+ old_state_dict = model.state_dict
212
+ model.state_dict = (
213
+ lambda self, *_, **__: get_peft_model_state_dict(
214
+ self, old_state_dict()
215
+ )
216
+ ).__get__(model, type(model))
217
+
218
+ if torch.__version__ >= "2" and sys.platform != "win32":
219
+ model = torch.compile(model)
220
+
221
+ signal.signal(signal.SIGINT, lambda signal, frame: (
222
+ model.save_pretrained(cfg.output_dir),
223
+ exit(0)
224
+ ))
225
+
226
+ # go ahead and presave the adapter config
227
+ lora_config.save_pretrained(cfg.output_dir)
228
+ trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
229
 
230
+ model.save_pretrained(cfg.output_dir)
231
 
232
  if __name__ == "__main__":
233
  fire.Fire(train)
src/axolotl/convert.py CHANGED
@@ -44,6 +44,7 @@ class JsonToJsonlConverter:
44
  def convert(self, input_file_path, output_file_path):
45
  content = self.file_reader.read(input_file_path)
46
  data = self.json_parser.parse(content)
 
47
  jsonl_content = self.jsonl_serializer.serialize(data)
48
  self.file_writer.write(jsonl_content)
49
 
 
44
  def convert(self, input_file_path, output_file_path):
45
  content = self.file_reader.read(input_file_path)
46
  data = self.json_parser.parse(content)
47
+ # data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations
48
  jsonl_content = self.jsonl_serializer.serialize(data)
49
  self.file_writer.write(jsonl_content)
50
 
src/axolotl/datasets.py CHANGED
@@ -2,7 +2,7 @@ from typing import List
2
 
3
  import torch
4
  from datasets import IterableDataset
5
- from .prompt_tokenizers import PromptTokenizingStrategy
6
 
7
 
8
  # We want this to be a wrapper for an existing dataset that we have loaded
@@ -23,7 +23,12 @@ class TokenizedPromptDataset(IterableDataset):
23
 
24
  def __iter__(self):
25
  iterator = iter(self.dataset)
26
- yield self.prompt_tokenizer.tokenize_prompt(next(iterator))
 
 
 
 
 
27
 
28
 
29
  class ConstantLengthDataset(IterableDataset):
@@ -32,55 +37,68 @@ class ConstantLengthDataset(IterableDataset):
32
  Args:
33
  tokenizer (Tokenizer): The processor used for processing the data.
34
  dataset (dataset.Dataset): Dataset with text files.
35
- infinite (bool): If True the iterator is reset after dataset reaches end else stops.
36
  seq_length (int): Length of token sequences to return.
37
- chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
38
  """
39
 
40
  def __init__(
41
  self,
42
  tokenizer,
43
  datasets,
44
- infinite=False,
45
  seq_length=2048,
46
- num_of_sequences=1024,
47
- chars_per_token=3.6,
48
  ):
49
  self.tokenizer = tokenizer
50
- self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else args.eos_token_id
51
  self.datasets: List[IterableDataset] = datasets
52
  self.seq_length = seq_length
53
- self.infinite = infinite
54
- self.current_size = 0
55
- self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
56
 
57
  def __iter__(self):
58
- iterator = iter(self.datasets)
59
- more_examples = True
60
- while more_examples:
61
- buffer, buffer_len = [], 0
62
- while True:
63
- if buffer_len >= self.max_buffer_size:
64
- break
65
  try:
66
- buffer.append(next(iterator))
67
- buffer_len += len(buffer[-1])
68
  except StopIteration:
69
- if self.infinite:
70
- iterator = iter(self.datasets)
71
- else:
72
- more_examples = False
73
- break
74
- tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
75
- all_token_ids = []
76
- for tokenized_input in tokenized_inputs:
77
- all_token_ids.extend(tokenized_input + [self.concat_token_id])
78
- for i in range(0, len(all_token_ids), self.seq_length):
79
- input_ids = all_token_ids[i : i + self.seq_length]
80
- if len(input_ids) == self.seq_length:
81
- self.current_size += 1
82
- yield {
83
- "input_ids": torch.LongTensor(input_ids),
84
- "labels": torch.LongTensor(input_ids),
85
- "attention_masks": torch.LongTensor(input_ids),
86
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import torch
4
  from datasets import IterableDataset
5
+ from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
6
 
7
 
8
  # We want this to be a wrapper for an existing dataset that we have loaded
 
23
 
24
  def __iter__(self):
25
  iterator = iter(self.dataset)
26
+ # Loop through the entire dataset
27
+ for example in iterator:
28
+ try:
29
+ yield self.prompt_tokenizer.tokenize_prompt(example)
30
+ except InvalidDataException:
31
+ pass
32
 
33
 
34
  class ConstantLengthDataset(IterableDataset):
 
37
  Args:
38
  tokenizer (Tokenizer): The processor used for processing the data.
39
  dataset (dataset.Dataset): Dataset with text files.
 
40
  seq_length (int): Length of token sequences to return.
 
41
  """
42
 
43
  def __init__(
44
  self,
45
  tokenizer,
46
  datasets,
 
47
  seq_length=2048,
 
 
48
  ):
49
  self.tokenizer = tokenizer
50
+ self.concat_token_id = tokenizer.eos_token_id
51
  self.datasets: List[IterableDataset] = datasets
52
  self.seq_length = seq_length
 
 
 
53
 
54
  def __iter__(self):
55
+ buffer = {"input_ids": [], "attention_mask": [], "labels": []}
56
+ buffer_len = 0
57
+ for dataset in self.datasets:
58
+ iterator = iter(dataset)
59
+ more_examples = True
60
+ while more_examples:
 
61
  try:
62
+ example = next(iterator)
 
63
  except StopIteration:
64
+ more_examples = False
65
+ example = None
66
+
67
+ add_concat_token = False
68
+ if example:
69
+ example_len = len(example["input_ids"])
70
+ add_concat_token = example["input_ids"][-1] != self.concat_token_id
71
+ else:
72
+ example_len = 0
73
+
74
+ if not example_len or buffer_len + int(add_concat_token) + example_len > self.seq_length:
75
+ if buffer["input_ids"]:
76
+ input_ids = torch.cat(buffer["input_ids"], dim=-1)[: self.seq_length]
77
+ attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[: self.seq_length]
78
+ labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
79
+ yield {
80
+ "input_ids": input_ids,
81
+ "labels": labels,
82
+ "attention_mask": attention_mask,
83
+ }
84
+ buffer = {"input_ids": [], "attention_mask": [], "labels": []}
85
+ buffer_len = 0
86
+
87
+ if example:
88
+ input_ids = example["input_ids"]
89
+ attention_mask = example["attention_mask"]
90
+ labels = example["labels"]
91
+
92
+ if add_concat_token:
93
+ input_ids.append(self.concat_token_id)
94
+ attention_mask.append(1)
95
+ labels.append(self.concat_token_id)
96
+
97
+ input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
98
+ attention_mask_with_concat = torch.tensor(attention_mask, dtype=torch.long)
99
+ labels_with_concat = torch.tensor(labels, dtype=torch.long)
100
+
101
+ buffer["input_ids"].append(input_ids_with_concat)
102
+ buffer["attention_mask"].append(attention_mask_with_concat)
103
+ buffer["labels"].append(labels_with_concat)
104
+ buffer_len += len(input_ids)
src/axolotl/prompt_tokenizers.py CHANGED
@@ -9,6 +9,10 @@ LLAMA_DEFAULT_BOS_TOKEN = "<s>"
9
  LLAMA_DEFAULT_UNK_TOKEN = "<unk>"
10
 
11
 
 
 
 
 
12
  class PromptTokenizingStrategy(abc.ABC):
13
  def __init__(
14
  self,
@@ -32,7 +36,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
32
  full_prompt = self._tokenize_full_prompt(prompt)
33
  tokenized_full_prompt = self._tokenize(full_prompt)
34
  if not self.train_on_inputs:
35
- user_prompt = self.prompter.generate_prompt(
36
  prompt["instruction"], prompt["input"]
37
  )
38
  tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
@@ -43,7 +47,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
43
  return tokenized_full_prompt
44
 
45
  def _tokenize_full_prompt(self, prompt):
46
- return self.prompter.generate_prompt(
47
  prompt["instruction"],
48
  prompt["input"],
49
  prompt["output"],
@@ -71,7 +75,7 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
71
 
72
  class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
73
  def _tokenize_full_prompt(self, prompt):
74
- return self.prompter.generate_prompt(
75
  prompt["instruction"],
76
  prompt["input"],
77
  prompt["response"],
@@ -80,4 +84,7 @@ class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
80
 
81
  class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
82
  def tokenize_prompt(self, prompt):
83
- pass
 
 
 
 
9
  LLAMA_DEFAULT_UNK_TOKEN = "<unk>"
10
 
11
 
12
+ class InvalidDataException(Exception):
13
+ pass
14
+
15
+
16
  class PromptTokenizingStrategy(abc.ABC):
17
  def __init__(
18
  self,
 
36
  full_prompt = self._tokenize_full_prompt(prompt)
37
  tokenized_full_prompt = self._tokenize(full_prompt)
38
  if not self.train_on_inputs:
39
+ user_prompt = self.prompter.build_prompt(
40
  prompt["instruction"], prompt["input"]
41
  )
42
  tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
 
47
  return tokenized_full_prompt
48
 
49
  def _tokenize_full_prompt(self, prompt):
50
+ return self.prompter.build_prompt(
51
  prompt["instruction"],
52
  prompt["input"],
53
  prompt["output"],
 
75
 
76
  class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
77
  def _tokenize_full_prompt(self, prompt):
78
+ return self.prompter.build_prompt(
79
  prompt["instruction"],
80
  prompt["input"],
81
  prompt["response"],
 
84
 
85
  class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
86
  def tokenize_prompt(self, prompt):
87
+ try:
88
+ return self.prompter.build_prompt(prompt["conversations"], self.tokenizer)
89
+ except (KeyError, AssertionError) as e:
90
+ raise InvalidDataException(str(e))
src/axolotl/prompters.py CHANGED
@@ -1,10 +1,160 @@
 
 
 
 
 
 
 
 
1
  class AlpacaPrompter:
2
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
 
5
  class ShareGPTPrompter:
6
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
 
 
 
8
 
9
- class GPTeacherPrompter:
10
- pass
 
1
+ import copy
2
+ import dataclasses
3
+ from enum import auto, Enum
4
+ from typing import List, Tuple, Any, Union
5
+
6
+ IGNORE_TOKEN_ID = -100
7
+
8
+
9
  class AlpacaPrompter:
10
+ prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
11
+ prompt_no_input = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
12
+ response_split = "### Response:"
13
+
14
+ def build_prompt(
15
+ self,
16
+ instruction: str,
17
+ input: Union[None, str] = None,
18
+ output: Union[None, str] = None,
19
+ ) -> str:
20
+ # returns the full prompt from instruction and optional input
21
+ # if a label (=response, =output) is provided, it's also appended.
22
+ if input:
23
+ res = self.prompt_input.format(
24
+ instruction=instruction, input=input
25
+ )
26
+ else:
27
+ res = self.prompt_no_input.format(
28
+ instruction=instruction
29
+ )
30
+ if output:
31
+ res = f"{res}{output}"
32
+ return res
33
+
34
+ def get_response(self, output: str) -> str:
35
+ return output.split(self.response_split)[1].strip()
36
+
37
+
38
+ class GPTeacherPrompter(AlpacaPrompter):
39
+ ...
40
+
41
+
42
+ class SeparatorStyle(Enum):
43
+ """Different separator style."""
44
+ SINGLE = auto()
45
+ TWO = auto()
46
+ DOLLY = auto()
47
+
48
+
49
+ # TODO clean this 💩 up
50
+ @dataclasses.dataclass
51
+ class Conversation:
52
+ """A class that keeps all conversation history."""
53
+ system: str
54
+ roles: List[str]
55
+ messages: List[List[str]]
56
+ offset: int
57
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
58
+ sep: str = "###"
59
+ sep2: str = None
60
+
61
+ def get_prompt(self):
62
+ seps = [self.sep, self.sep2]
63
+ ret = self.system + seps[0]
64
+ for i, (role, message) in enumerate(self.messages):
65
+ if message:
66
+ ret += role + ": " + message + seps[i % 2]
67
+ else:
68
+ ret += role + ":"
69
+ return ret
70
+
71
+ def copy(self):
72
+ return Conversation(
73
+ system=self.system,
74
+ roles=self.roles,
75
+ messages=[[x, y] for x, y in self.messages],
76
+ offset=self.offset,
77
+ sep_style=self.sep_style,
78
+ sep=self.sep,
79
+ sep2=self.sep2,
80
+ )
81
+
82
+ def append_message(self, role, message):
83
+ self.messages.append([role, message])
84
+
85
+
86
+ conv_vicuna_v1_1 = Conversation(
87
+ system="A chat between a curious user and an artificial intelligence assistant. "
88
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
89
+ roles=["USER", "ASSISTANT"],
90
+ messages=[],
91
+ offset=0,
92
+ sep_style=SeparatorStyle.TWO,
93
+ sep=" ",
94
+ sep2="</s>",
95
+ )
96
 
97
 
98
  class ShareGPTPrompter:
99
+ def build_prompt(
100
+ self,
101
+ source,
102
+ tokenizer
103
+ ):
104
+ if len(source) < 2:
105
+ # If there isn't a back and forth conversation, ignore it
106
+ # also happens on the data splitting leaving empty conversations
107
+ raise IndexError
108
+
109
+ conv = conv_vicuna_v1_1.copy()
110
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
111
+
112
+ try:
113
+ # Apply prompt templates
114
+ if source[0]["from"] not in roles or roles[source[0]["from"]] != conv.roles[0]:
115
+ # Skip the first one if it is not from human
116
+ source = source[1:]
117
+ except IndexError as e:
118
+ # sometimes there is a bing or system chat
119
+ raise e
120
+
121
+ conv.messages = []
122
+ for j, sentence in enumerate(source):
123
+ role = roles[sentence["from"]]
124
+ assert role == conv.roles[j % 2]
125
+ conv.append_message(role, sentence["value"])
126
+ conversation = conv.get_prompt()
127
+
128
+ # Tokenize conversations
129
+ tokenized_result = tokenizer(
130
+ conversation,
131
+ truncation=True,
132
+ max_length=2048, # FIXME
133
+ padding=False,
134
+ return_tensors=None,
135
+ )
136
+ target = copy.deepcopy(tokenized_result["input_ids"])
137
+
138
+ # Mask targets
139
+ sep = conv.sep + conv.roles[1] + ": "
140
+
141
+ rounds = conversation.split(conv.sep2)
142
+ cur_len = 1
143
+ for i, rou in enumerate(rounds):
144
+ if rou == "":
145
+ break
146
+
147
+ parts = rou.split(sep)
148
+ if len(parts) != 2:
149
+ break
150
+ parts[0] += sep
151
+ round_len = len(tokenizer(rou)["input_ids"])
152
+ instruction_len = len(tokenizer(parts[0])["input_ids"]) - 2
153
+ target[cur_len:cur_len+instruction_len] = [IGNORE_TOKEN_ID] * instruction_len
154
 
155
+ cur_len += round_len
156
+ target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
157
+ attention_mask = [1 if x != tokenizer.pad_token_id else 0 for x in tokenized_result["input_ids"]]
158
 
159
+ return dict(input_ids=tokenized_result["input_ids"], labels=target,
160
+ attention_mask=attention_mask)