winglian committed
Commit 80b2ed2
1 Parent(s): 45f77dd

various bugfixes
configs/llama_65B_alpaca.yml CHANGED

@@ -1,4 +1,4 @@
-base_model: decapoda-research/llama-65b-hf
+base_model: huggyllama/llama-7b
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 load_in_8bit: true
@@ -33,8 +33,8 @@ num_epochs: 5
 learning_rate: 0.00003
 train_on_inputs: false
 group_by_length: false
-bf16: True
-tf32: True
+bf16: true
+tf32: true
 resume_from_checkpoint:
 local_rank:
 deepspeed:
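A quick aside on the bf16/tf32 edit: the config is read with PyYAML (finetune.py imports yaml), and PyYAML resolves both True and true to the same Python boolean, so this part of the change is a style normalization rather than a behavior fix. A minimal check, using a hypothetical two-key snippet rather than the real config:

```python
import yaml

# Toy snippet, not the repo's config file.
snippet = "bf16: True\ntf32: True\n"

# PyYAML's default resolver treats True/true/TRUE identically.
print(yaml.safe_load(snippet))          # {'bf16': True, 'tf32': True}
print(yaml.safe_load(snippet.lower()))  # identical result
```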
requirements.txt CHANGED

@@ -10,3 +10,6 @@ accelerate
 sentencepiece
 wandb
 flash-attn
+deepspeed
+einops
+
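Some context on the two new entries, with the caveat that the reasoning is inferred rather than stated in the commit: deepspeed backs the deepspeed: key already present in the configs, and einops is commonly needed by flash-attention monkeypatches, which rearrange packed attention tensors before calling the fused kernel. A toy sketch of that kind of reshape (shapes and names are illustrative, not taken from this repo):

```python
import torch
from einops import rearrange

# Illustrative shapes only.
bsz, seqlen, n_heads, head_dim = 2, 16, 4, 8
qkv = torch.randn(bsz, seqlen, 3, n_heads, head_dim)

# Flatten batch and sequence dims into one axis, the packed layout that
# variable-length flash-attention kernels typically expect.
qkv_packed = rearrange(qkv, "b s three h d -> (b s) three h d")
print(qkv_packed.shape)  # torch.Size([32, 3, 4, 8])
```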
scripts/finetune.py CHANGED

@@ -11,7 +11,7 @@ import torch
 import transformers
 import yaml
 from attrdict import AttrDefault
-from datasets import load_dataset, IterableDataset, Dataset
+from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
 from peft import (
     LoraConfig,
     get_peft_model,
@@ -52,8 +52,9 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
     if adapter != "lora":
         raise NotImplementedError(f"{adapter} peft adapter not available")
     if "llama" in base_model:
-        from axolotl.flash_attn import replace_llama_attn_with_flash_attn
-        replace_llama_attn_with_flash_attn()
+        if cfg.device not in ["mps", "cpu"]:
+            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            replace_llama_attn_with_flash_attn()
 
     try:
         if "llama" in base_model:
@@ -86,7 +87,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
     except:
         tokenizer = AutoTokenizer.from_pretrained(base_model)
 
-    if tokenizer.__class__.__name__ == "LlamaTokenizer":
+    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
 
     if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
@@ -255,8 +256,9 @@ def train(
         return
 
     datasets = []
-    if len(cfg.datasets) == 1 and cfg.datasets[0].type == "arrow":
-        dataset = load_dataset(cfg.datasets[0].path, split="train")
+    if not isinstance(cfg.datasets, list) and isinstance(cfg.datasets, str):
+        # assumption that we are loading a previously saved/cached dataset
+        dataset = load_from_disk(cfg.datasets)
     else:
         for d in cfg.datasets:
             ds: IterableDataset = load_dataset(
@@ -288,7 +290,6 @@ def train(
         [_ for _ in constant_len_dataset]
     ).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
     dataset.save_to_disk("data/last_run")
-    print(dataset)
 
     train_dataset = dataset["train"]
     eval_dataset = dataset["test"]
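Two of the finetune.py changes deserve a note. The flash-attention monkeypatch is now skipped when cfg.device is mps or cpu, since the fused kernels require CUDA. More substantially, a plain string value for datasets is now interpreted as a directory previously written by dataset.save_to_disk("data/last_run"), so a cached, already-tokenized run can be reloaded with load_from_disk instead of being rebuilt. A minimal round trip with the datasets library, using toy data and an illustrative path rather than the repo's defaults:

```python
from datasets import Dataset, load_from_disk

# Build and cache a toy split, mirroring what finetune.py writes to data/last_run.
ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5, 6]]})
split = ds.train_test_split(test_size=0.5, shuffle=True, seed=42)
split.save_to_disk("data/toy_last_run")  # illustrative path

# A config whose datasets value is this string (rather than a list of dataset
# specs) would now take the load_from_disk branch and skip re-tokenization.
reloaded = load_from_disk("data/toy_last_run")
print(reloaded["train"][0], reloaded["test"][0])
```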
setup.cfg CHANGED

@@ -23,6 +23,7 @@ install_requires =
     sentencepiece
     wandb
     flash-attn
+    einops
 
 [options.packages.find]
 where = src
src/axolotl/datasets.py CHANGED

@@ -93,22 +93,24 @@ class ConstantLengthDataset(IterableDataset):
                 buffer_len = 0
 
             if example:
-                input_ids = example["input_ids"]
-                attention_mask = example["attention_mask"]
-                labels = example["labels"]
-
-                if add_concat_token:
-                    input_ids.append(self.concat_token_id)
-                    attention_mask.append(1)
-                    labels.append(self.concat_token_id)
-
-                input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
-                attention_mask_with_concat = torch.tensor(
-                    attention_mask, dtype=torch.long
-                )
-                labels_with_concat = torch.tensor(labels, dtype=torch.long)
-
-                buffer["input_ids"].append(input_ids_with_concat)
-                buffer["attention_mask"].append(attention_mask_with_concat)
-                buffer["labels"].append(labels_with_concat)
-                buffer_len += len(input_ids)
+                # just going to drop data points that are too long
+                if len(example["input_ids"]) <= self.seq_length:
+                    input_ids = example["input_ids"]
+                    attention_mask = example["attention_mask"]
+                    labels = example["labels"]
+
+                    if add_concat_token:
+                        input_ids.append(self.concat_token_id)
+                        attention_mask.append(1)
+                        labels.append(self.concat_token_id)
+
+                    input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
+                    attention_mask_with_concat = torch.tensor(
+                        attention_mask, dtype=torch.long
+                    )
+                    labels_with_concat = torch.tensor(labels, dtype=torch.long)
+
+                    buffer["input_ids"].append(input_ids_with_concat)
+                    buffer["attention_mask"].append(attention_mask_with_concat)
+                    buffer["labels"].append(labels_with_concat)
+                    buffer_len += len(input_ids)
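The net effect of the new guard in ConstantLengthDataset is that any tokenized example longer than seq_length is dropped outright rather than packed or truncated. A standalone sketch of that rule, written as a hypothetical helper rather than the class itself:

```python
def keep_example(example: dict, seq_length: int) -> bool:
    # Mirrors the new check: drop data points that are too long to pack.
    return len(example["input_ids"]) <= seq_length

examples = [
    {"input_ids": list(range(10))},  # fits within seq_length=16
    {"input_ids": list(range(40))},  # too long: dropped, not truncated
]
kept = [ex for ex in examples if keep_example(ex, seq_length=16)]
print(len(kept))  # 1
```

One consequence worth keeping in mind: because oversized examples are discarded silently, the packed dataset can end up smaller than the raw one when many prompts exceed the configured sequence length.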