winglian commited on
Commit
1edc30c
·
1 Parent(s): 41e4f6c

add support for opimum bettertransformers

Browse files
configs/gpt_neox_20b.yml CHANGED
@@ -1,24 +1,25 @@
1
  base_model: EleutherAI/gpt-neox-20b
 
2
  base_model_ignore_patterns: pytorch* # prefer safetensors
3
  model_type: GPTNeoXForCausalLM
4
  tokenizer_type: AutoTokenizer
5
- load_in_8bit: true
 
 
6
  datasets:
7
- - path: nomic-ai/gpt4all-j-prompt-generations
8
  type: alpaca
9
- shards: 4
10
- shards_index: 0
11
  dataset_prepared_path: last_run_prepared
12
  val_set_size: 0.05
13
- adapter: lora
14
  lora_model_dir:
15
  sequence_len: 2048
16
  max_packed_sequence_len: 2048
17
- lora_r: 8
18
  lora_alpha: 32
19
- lora_dropout: 0.05
20
  lora_target_modules:
21
- - query_key_value
22
  lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
23
  wandb_project: gpt4all-neox-20b
24
  wandb_watch:
@@ -26,14 +27,19 @@ wandb_run_id:
26
  wandb_log_model:
27
  output_dir: ./gpt4all-neox-20b
28
  gradient_accumulation_steps: 1
29
- micro_batch_size: 4
30
  num_epochs: 5
31
  learning_rate: 0.00003
32
- lr_scheduler: one_cycle
 
33
  train_on_inputs: false
34
  group_by_length: false
35
- bf16: True
36
- tf32: True
 
 
 
37
  early_stopping_patience:
38
  resume_from_checkpoint:
39
  local_rank:
 
 
1
  base_model: EleutherAI/gpt-neox-20b
2
+ base_model_config: EleutherAI/gpt-neox-20b
3
  base_model_ignore_patterns: pytorch* # prefer safetensors
4
  model_type: GPTNeoXForCausalLM
5
  tokenizer_type: AutoTokenizer
6
+ load_in_8bit: false
7
+ load_in_4bit: true
8
+ load_4bit: false
9
  datasets:
10
+ - path: vicgalle/alpaca-gpt4
11
  type: alpaca
 
 
12
  dataset_prepared_path: last_run_prepared
13
  val_set_size: 0.05
14
+ adapter:
15
  lora_model_dir:
16
  sequence_len: 2048
17
  max_packed_sequence_len: 2048
18
+ lora_r: 64
19
  lora_alpha: 32
20
+ lora_dropout: 0.0
21
  lora_target_modules:
22
+ lora_target_linear: true
23
  lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
24
  wandb_project: gpt4all-neox-20b
25
  wandb_watch:
 
27
  wandb_log_model:
28
  output_dir: ./gpt4all-neox-20b
29
  gradient_accumulation_steps: 1
30
+ micro_batch_size: 2
31
  num_epochs: 5
32
  learning_rate: 0.00003
33
+ optimizer: paged_adamw_32bit
34
+ lr_scheduler: cosine
35
  train_on_inputs: false
36
  group_by_length: false
37
+ bf16: false
38
+ fp16: false
39
+ float16: true
40
+ tf32: true
41
+ flash_optimum: true
42
  early_stopping_patience:
43
  resume_from_checkpoint:
44
  local_rank:
45
+ gradient_checkpointing: true
requirements.txt CHANGED
@@ -11,6 +11,7 @@ sentencepiece
11
  wandb
12
  einops
13
  xformers
 
14
  # qlora things
15
  bert-score==0.3.13
16
  evaluate==0.4.0
 
11
  wandb
12
  einops
13
  xformers
14
+ optimum
15
  # qlora things
16
  bert-score==0.3.13
17
  evaluate==0.4.0
scripts/finetune.py CHANGED
@@ -6,6 +6,7 @@ import os
6
  import random
7
  import signal
8
  import sys
 
9
  from pathlib import Path
10
  from typing import Any, Dict, List, Optional, Union
11
 
@@ -19,6 +20,8 @@ from axolotl.utils.dict import DictDefault
19
  from axolotl.utils.models import load_model, load_tokenizer
20
 
21
  # add src to the pythonpath so we don't need to pip install this
 
 
22
  from axolotl.utils.tokenization import check_dataset_labels
23
  from axolotl.utils.trainer import setup_trainer
24
  from axolotl.utils.validation import validate_config
@@ -264,12 +267,14 @@ def train(
264
 
265
  # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
266
  if cfg.local_rank == 0:
 
 
 
 
 
267
  signal.signal(
268
  signal.SIGINT,
269
- lambda signal, frame: (
270
- model.save_pretrained(cfg.output_dir),
271
- sys.exit(0),
272
- ),
273
  )
274
 
275
  logging.info("Starting trainer...")
@@ -299,6 +304,8 @@ def train(
299
  # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
300
  # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
301
  if cfg.local_rank == 0:
 
 
302
  model.save_pretrained(cfg.output_dir)
303
 
304
  # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
 
6
  import random
7
  import signal
8
  import sys
9
+ from functools import partial
10
  from pathlib import Path
11
  from typing import Any, Dict, List, Optional, Union
12
 
 
20
  from axolotl.utils.models import load_model, load_tokenizer
21
 
22
  # add src to the pythonpath so we don't need to pip install this
23
+ from optimum.bettertransformer import BetterTransformer
24
+
25
  from axolotl.utils.tokenization import check_dataset_labels
26
  from axolotl.utils.trainer import setup_trainer
27
  from axolotl.utils.validation import validate_config
 
267
 
268
  # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
269
  if cfg.local_rank == 0:
270
+ def terminate_handler(signum, frame, model):
271
+ if cfg.flash_optimum:
272
+ model = BetterTransformer.reverse(model)
273
+ model.save_pretrained(cfg.output_dir)
274
+ sys.exit(0)
275
  signal.signal(
276
  signal.SIGINT,
277
+ lambda signum, frame: terminate_handler(signum, frame, model)
 
 
 
278
  )
279
 
280
  logging.info("Starting trainer...")
 
304
  # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
305
  # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
306
  if cfg.local_rank == 0:
307
+ if cfg.flash_optimum:
308
+ model = BetterTransformer.reverse(model)
309
  model.save_pretrained(cfg.output_dir)
310
 
311
  # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
src/axolotl/utils/models.py CHANGED
@@ -11,7 +11,8 @@ import bitsandbytes as bnb
11
  import torch
12
  import transformers
13
  from transformers import PreTrainedModel # noqa: F401
14
- from transformers import ( # noqa: F401
 
15
  AutoConfig,
16
  AutoModelForCausalLM,
17
  AutoTokenizer,
@@ -137,7 +138,7 @@ def load_model(
137
 
138
  if cfg.bf16:
139
  torch_dtype = torch.bfloat16
140
- elif cfg.load_in_8bit or cfg.fp16:
141
  torch_dtype = torch.float16
142
  else:
143
  torch_dtype = torch.float32
@@ -342,6 +343,9 @@ def load_model(
342
  logging.warning("there are no parameters that require gradient updates")
343
  model.config.use_cache = False
344
 
 
 
 
345
  # TODO resume_from_checkpoint handling
346
  return model, lora_config
347
 
 
11
  import torch
12
  import transformers
13
  from transformers import PreTrainedModel # noqa: F401
14
+ from optimum.bettertransformer import BetterTransformer
15
+ from transformers import (
16
  AutoConfig,
17
  AutoModelForCausalLM,
18
  AutoTokenizer,
 
138
 
139
  if cfg.bf16:
140
  torch_dtype = torch.bfloat16
141
+ elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
142
  torch_dtype = torch.float16
143
  else:
144
  torch_dtype = torch.float32
 
343
  logging.warning("there are no parameters that require gradient updates")
344
  model.config.use_cache = False
345
 
346
+ if cfg.flash_optimum:
347
+ model = BetterTransformer.transform(model)
348
+
349
  # TODO resume_from_checkpoint handling
350
  return model, lora_config
351
 
src/axolotl/utils/validation.py CHANGED
@@ -57,6 +57,14 @@ def validate_config(cfg):
57
  if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
58
  raise ValueError("FSDP is not supported for falcon models")
59
 
 
 
 
 
 
 
 
 
60
  # TODO
61
  # MPT 7b
62
  # https://github.com/facebookresearch/bitsandbytes/issues/25
 
57
  if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
58
  raise ValueError("FSDP is not supported for falcon models")
59
 
60
+ if cfg.flash_optimum is True:
61
+ if cfg.adapter:
62
+ logging.warning("BetterTransformers probably doesn't work with PEFT adapters")
63
+ if cfg.fp16 or cfg.bf16:
64
+ raise ValueError("AMP is not supported with BetterTransformer")
65
+ if cfg.float16 is not True:
66
+ logging.warning("You should probably set float16 to true")
67
+
68
  # TODO
69
  # MPT 7b
70
  # https://github.com/facebookresearch/bitsandbytes/issues/25