winglian committed
Commit 9b6ee83
1 Parent(s): 638c2da

FSDP + QLoRA (#1378)


* wip qlora + fsdp fixes

* more fixes

* make sure to load the lora :facepalm:

* only set up quantized meta on non-zero ranks

* only run setup_quantized_peft_meta_for_training for qlora+fsdp

* more fixes for qlora+fsdp

* chore: lint

* add example yml

* support mistral too

* fix for model_type and add mixtral support too

* set cpu_offload: false to reduce VRAM, constrain new accelerator logic to qlora + fsdp

* refactor for duplicate code

examples/llama-2/qlora-fsdp.yml ADDED
@@ -0,0 +1,70 @@
+ base_model: NousResearch/Llama-2-7b-hf
+ model_type: LlamaForCausalLM
+ tokenizer_type: LlamaTokenizer
+
+ load_in_8bit: false
+ load_in_4bit: true
+ strict: false
+
+ datasets:
+   - path: yahma/alpaca-cleaned
+     type: alpaca
+ dataset_prepared_path: last_run_prepared
+ val_set_size: 0.05
+ output_dir: ./qlora-out
+
+ adapter: qlora
+ lora_model_dir:
+
+ sequence_len: 512
+ sample_packing: false
+ pad_to_sequence_len: true
+
+ lora_r: 32
+ lora_alpha: 16
+ lora_dropout: 0.05
+ lora_target_modules:
+ lora_target_linear: true
+ lora_fan_in_fan_out:
+
+ wandb_project:
+ wandb_entity:
+ wandb_watch:
+ wandb_name:
+ wandb_log_model:
+
+ gradient_accumulation_steps: 4
+ micro_batch_size: 4
+ num_epochs: 4
+ optimizer: paged_adamw_8bit
+ lr_scheduler: cosine
+ learning_rate: 0.00001
+
+ train_on_inputs: false
+ group_by_length: false
+ bf16: auto
+ fp16:
+ tf32: false
+
+ gradient_checkpointing: true
+ gradient_checkpointing_kwargs:
+   use_reentrant: true
+ early_stopping_patience:
+ resume_from_checkpoint:
+ local_rank:
+ logging_steps: 1
+ xformers_attention:
+ flash_attention: true
+
+ warmup_steps: 10
+ evals_per_epoch: 4
+ eval_table_size:
+ saves_per_epoch: 1
+ debug:
+ deepspeed:
+ weight_decay: 0.0
+ fsdp:
+   - full_shard
+ fsdp_config:
+   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+ special_tokens:
examples/mistral/mixtral-qlora-fsdp.yml ADDED
@@ -0,0 +1,74 @@
+ base_model: mistralai/Mixtral-8x7B-v0.1
+ model_type: AutoModelForCausalLM
+ tokenizer_type: LlamaTokenizer
+ trust_remote_code: true
+
+ load_in_8bit: false
+ load_in_4bit: true
+ strict: false
+
+ datasets:
+   - path: tatsu-lab/alpaca
+     type: alpaca
+ dataset_prepared_path: last_run_prepared
+ val_set_size: 0.02
+ output_dir: ./qlora-out
+
+ model_config:
+   output_router_logits: true
+
+ adapter: qlora
+ lora_model_dir:
+
+ sequence_len: 1024
+ sample_packing: false
+ pad_to_sequence_len: false
+
+ lora_r: 32
+ lora_alpha: 16
+ lora_dropout: 0.05
+ lora_target_linear: true
+ lora_fan_in_fan_out:
+
+ wandb_project:
+ wandb_entity:
+ wandb_watch:
+ wandb_name:
+ wandb_log_model:
+
+ gradient_accumulation_steps: 4
+ micro_batch_size: 2
+ num_epochs: 1
+ optimizer: paged_adamw_8bit
+ lr_scheduler: cosine
+ learning_rate: 0.0002
+
+ train_on_inputs: false
+ group_by_length: false
+ bf16: auto
+ fp16:
+ tf32: false
+
+ gradient_checkpointing: true
+ early_stopping_patience:
+ resume_from_checkpoint:
+ local_rank:
+ logging_steps: 1
+ xformers_attention:
+ flash_attention: true
+
+ loss_watchdog_threshold: 5.0
+ loss_watchdog_patience: 3
+
+ warmup_steps: 10
+ evals_per_epoch: 4
+ eval_table_size:
+ eval_max_new_tokens: 128
+ saves_per_epoch: 1
+ debug:
+ weight_decay: 0.0
+ fsdp:
+   - full_shard
+ fsdp_config:
+   fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
+ special_tokens:
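Both example configs rely on the same three settings: `adapter: qlora`, a non-empty `fsdp:` list, and a base model whose `model_type` the new auto-wrap policies support. A small standalone sketch (values copied from the YAML above; nothing here is axolotl-specific) of the gate that `load_model()` later in this diff uses to pick the QLoRA + FSDP path:

```python
# Standalone sketch of the qlora_fsdp gate introduced in src/axolotl/utils/models.py.
SUPPORTED_AUTO_WRAP_MODEL_TYPES = ["llama", "mistral", "mixtral"]

cfg = {
    "adapter": "qlora",          # from `adapter: qlora` above
    "fsdp": ["full_shard"],      # from the `fsdp:` block above
}
model_type = "mixtral"           # resolved from the base_model's HF config

qlora_fsdp = (
    bool(cfg["fsdp"])
    and cfg["adapter"] == "qlora"
    and model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES
)
print(qlora_fsdp)  # True for both example configs
```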
requirements.txt CHANGED
@@ -3,7 +3,7 @@ packaging==23.2
  peft==0.9.0
  transformers==4.38.2
  tokenizers==0.15.0
- bitsandbytes>=0.41.1
+ bitsandbytes>=0.43.0
  accelerate==0.26.1
  deepspeed==0.13.1
  pydantic==2.6.3
@@ -40,3 +40,4 @@ gcsfs
  # adlfs
 
  trl>=0.7.9
+ fastcore>=1.5.29
src/axolotl/core/policies/__init__.py ADDED
File without changes
src/axolotl/core/policies/auto_wrap.py ADDED
@@ -0,0 +1,55 @@
+ """module for building the auto wrap policy for FSDP"""
+ import functools
+
+ from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
+ from torch.distributed.fsdp.wrap import (
+     _or_policy,
+     lambda_auto_wrap_policy,
+     transformer_auto_wrap_policy,
+ )
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+ from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
+ from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
+
+ SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
+     "llama",
+     "mistral",
+     "mixtral",
+ ]
+
+
+ def get_wrapping_policy_factory(model_type):
+     if model_type == "llama":
+         layer_to_wrap = LlamaDecoderLayer
+     elif model_type == "mistral":
+         layer_to_wrap = MistralDecoderLayer
+     elif model_type == "mixtral":
+         layer_to_wrap = MixtralDecoderLayer
+
+     def get_wrapping_policy():
+         """This checks for lora layers (has weight and requires_grad)"""
+
+         def lambda_policy_fn(module):
+             return (
+                 len(list(module.named_children())) == 0
+                 and getattr(module, "weight", None) is not None
+                 and module.weight.requires_grad
+             )
+
+         lambda_policy = functools.partial(
+             lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
+         )
+         transformer_layer_name = layer_to_wrap
+         transformer_wrap_policy = functools.partial(
+             transformer_auto_wrap_policy,
+             transformer_layer_cls=(
+                 PrefixEncoder,
+                 PromptEncoder,
+                 PromptEmbedding,
+                 transformer_layer_name,
+             ),
+         )
+         policies = [lambda_policy, transformer_wrap_policy]
+         return functools.partial(_or_policy, policies=policies)
+
+     return get_wrapping_policy
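The factory's return value is an ordinary FSDP auto-wrap policy: a callable that FSDP invokes per module with `module`, `recurse`, and `nonwrapped_numel`, wrapping trainable LoRA leaf modules and the decoder layers separately from the frozen quantized base weights. A minimal sketch of probing it directly (assumes axolotl with this patch installed; the plain `nn.Linear` stands in for a LoRA `lora_A`/`lora_B` layer):

```python
# Probe the auto-wrap policy returned by get_wrapping_policy_factory.
import torch.nn as nn

from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory

policy = get_wrapping_policy_factory("llama")()  # model_type comes from the HF config

# A leaf module with a trainable weight matches lambda_policy_fn, so it gets its own wrap.
trainable_leaf = nn.Linear(8, 8)
print(policy(module=trainable_leaf, recurse=False, nonwrapped_numel=0))  # True

# A frozen leaf (like the quantized base weights) matches neither policy.
frozen_leaf = nn.Linear(8, 8)
frozen_leaf.weight.requires_grad_(False)
print(policy(module=frozen_leaf, recurse=False, nonwrapped_numel=0))  # False
```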
src/axolotl/core/trainer_builder.py CHANGED
@@ -8,6 +8,7 @@ import importlib
  import importlib.util
  import logging
  import math
+ import os
  import sys
  from abc import abstractmethod
  from dataclasses import dataclass, field
@@ -17,7 +18,10 @@ from typing import List, Optional, Type, Union
 
  import torch
  import transformers
+ from accelerate import FullyShardedDataParallelPlugin
+ from accelerate.utils import str_to_bool
  from datasets import Dataset
+ from torch.distributed.fsdp import MixedPrecision
  from torch.optim.lr_scheduler import OneCycleLR
  from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
  from transformers import (
@@ -30,6 +34,7 @@ from transformers.trainer_utils import seed_worker
  from transformers.utils import is_sagemaker_mp_enabled
  from trl import DPOTrainer
 
+ from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
  from axolotl.loraplus import create_loraplus_optimizer
  from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
  from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -191,6 +196,10 @@ class AxolotlTrainingArguments(TrainingArguments):
          default=1e-6,
          metadata={"help": "loraplus learning rate for lora embedding layers."},
      )
+     qlora: bool = field(
+         default=False,
+         metadata={"help": "whether this is a qlora training"},
+     )
 
 
  class AxolotlTrainer(Trainer):
@@ -468,6 +477,56 @@ class AxolotlTrainer(Trainer):
 
          return super().push_to_hub(*args, **kwargs)
 
+     @wraps(Trainer.create_accelerator_and_postprocess)
+     def create_accelerator_and_postprocess(self):
+         rank = int(os.environ.get("LOCAL_RANK", 0))
+         res = super().create_accelerator_and_postprocess()
+
+         if self.args.qlora is False:
+             return res
+
+         # the rest of this method override is specific to fsdp + qlora (for now)
+         sync_module_states = (
+             str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
+         )
+
+         mp_policy = None
+         amp = os.environ["ACCELERATE_MIXED_PRECISION"]
+         if amp == "fp16":
+             mp_policy = MixedPrecision(
+                 param_dtype=torch.float32,
+                 reduce_dtype=torch.float32,
+                 buffer_dtype=torch.float32,
+             )
+         elif amp == "bf16":
+             mp_policy = MixedPrecision(
+                 param_dtype=torch.float32,
+                 reduce_dtype=torch.float32,
+                 buffer_dtype=torch.float32,
+             )
+
+         # If somehow we figure out how we want to parameterize we want to autocast buffers...
+         # mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
+         # load_param_skip_names = ['inv_freq']
+
+         if self.is_fsdp_enabled:
+             wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
+             fsdp_plugin = FullyShardedDataParallelPlugin(
+                 auto_wrap_policy=wrapping_policy(),
+                 cpu_offload=False,
+                 use_orig_params=False,
+                 limit_all_gathers=True,
+                 param_init_fn=lambda module: module.to_empty(
+                     device=torch.device("cuda"), recurse=False
+                 )
+                 if (rank != 0 and sync_module_states)
+                 else None,
+                 mixed_precision_policy=mp_policy,
+             )
+             self.accelerator.state.fsdp_plugin = fsdp_plugin
+
+         return res
+
 
  class AxolotlMambaTrainer(AxolotlTrainer):
      """
@@ -787,6 +846,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
          if self.cfg.fsdp_config:
              training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
 
+         if self.cfg.adapter == "qlora":
+             training_arguments_kwargs["qlora"] = True
+
          # deepspeed
          if self.cfg.deepspeed:
              training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
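All of the gating in `create_accelerator_and_postprocess` above is environment-driven. A standalone sketch of that gating (the defaults below are assumptions; under `accelerate launch` each process receives these variables from the launcher):

```python
# Rank/env gating mirrored from the trainer override above, outside the Trainer.
import os

import torch
from accelerate.utils import str_to_bool

rank = int(os.environ.get("LOCAL_RANK", 0))
sync_module_states = str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
amp = os.environ.get("ACCELERATE_MIXED_PRECISION", "bf16")  # the override reads this key directly

# Non-zero ranks keep parameters on the meta device; FSDP materializes them through
# param_init_fn and then syncs the real (quantized) weights from rank 0 because
# sync_module_states is enabled. Rank 0 therefore needs no init fn.
param_init_fn = (
    (lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
    if (rank != 0 and sync_module_states)
    else None
)
print(rank, sync_module_states, amp, param_init_fn is not None)
```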
src/axolotl/utils/bench.py CHANGED
@@ -24,9 +24,9 @@ def check_cuda_device(default_value):
                  or not torch.cuda.is_available()
                  or device == "auto"
                  or torch.device(device).type == "cpu"
+                 or torch.device(device).type == "meta"
              ):
                  return default_value
-
              return func(*args, **kwargs)
 
          return wrapper
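The new branch matters because, with QLoRA + FSDP, non-zero ranks now hold weights on the meta device until FSDP materializes them, and meta tensors have no storage to measure. A quick illustration (plain PyTorch, not axolotl code):

```python
# Meta tensors carry shape/dtype metadata but allocate no memory, so the GPU
# benchmarking helpers have nothing to measure and now return their default.
import torch

weight = torch.empty(1024, 1024, device="meta")
print(weight.device.type)                    # "meta"
print(weight.shape, weight.dtype)            # metadata is still available
print(torch.device("meta").type == "meta")   # the new condition in check_cuda_device
```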
src/axolotl/utils/models.py CHANGED
@@ -1,13 +1,20 @@
  """Module for models and model loading"""
+ # pylint: disable=too-many-lines
+
  import logging
  import math
  import os
- from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401
+ import types
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union  # noqa: F401
 
  import addict
  import bitsandbytes as bnb
+ import safetensors
  import torch
  import transformers
+ from accelerate import init_empty_weights
+ from bitsandbytes.nn import Linear4bit, Params4bit
+ from fastcore.parallel import parallel
  from peft import (
      LoftQConfig,
      PeftConfig,
@@ -16,6 +23,7 @@ from peft import (
      prepare_model_for_kbit_training,
  )
  from peft.tuners.lora import QuantLinear
+ from torch import Tensor, nn
  from transformers import (  # noqa: F401
      AddedToken,
      AutoConfig,
@@ -27,7 +35,9 @@ from transformers import (  # noqa: F401
      PreTrainedTokenizerBase,
  )
  from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
 
+ from axolotl.core.policies.auto_wrap import SUPPORTED_AUTO_WRAP_MODEL_TYPES
  from axolotl.models.mamba import fix_mamba_attn_for_loss
  from axolotl.monkeypatch.multipack import (
      SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -262,6 +272,117 @@ def load_tokenizer(cfg):
      return tokenizer
 
 
+ def replace_linear(
+     model: nn.Module,
+     linear_replacement: Type[nn.Module],
+     quant_config: Union[dict, None] = None,
+     skip_modules=None,
+     **kwargs,
+ ):
+     """
+     Replace linear modules with a new Linear module.
+     Parameters:
+         model (`torch.nn.Module`):
+             Input model or `torch.nn.Module` as the function is run recursively.
+         linear_replacement (`torch.nn.Module`):
+             The linear module that replaces the old one. Only expects standard arguments.
+             If other arguments need to be passed, use a lambda.
+         skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
+             List of modules names not to convert. Defaults to `lm_head`.
+     """
+     if skip_modules is None:
+         skip_modules = ["lm_head"]
+     for name, module in model.named_children():
+         if len(list(module.children())) > 0:
+             replace_linear(
+                 module, linear_replacement, quant_config, skip_modules, **kwargs
+             )
+
+         if isinstance(module, torch.nn.Linear) and name not in skip_modules:
+             if issubclass(linear_replacement, Linear4bit):
+                 model._modules[  # pylint: disable=protected-access
+                     name
+                 ] = linear_replacement(
+                     module.in_features,
+                     module.out_features,
+                     module.bias is not None,
+                     **kwargs,
+                 )
+             else:
+                 raise ValueError(
+                     f"Unsupported linear replacement: {type(linear_replacement)}"
+                 )
+     return model
+
+
+ def load_and_quantize(
+     module: nn.Module,
+     name: str,
+     value: Tensor,
+     device: torch.device = None,
+     dtype: torch.dtype = None,
+     skip_names: Optional[List[str]] = None,
+     is_meta_rank: bool = False,
+     low_memory: bool = True,
+     verbose: bool = False,
+     quant_method: str = "bnb",
+ ):
+     """
+     Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
+
+     Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True.
+     """
+
+     if skip_names is None:
+         skip_names = []
+
+     def place_on_device(value):
+         if is_meta_rank:
+             device = "meta"
+         elif low_memory:
+             device = "cpu"
+         else:
+             device = "cuda"
+         return value.to(device=device, dtype=dtype)
+
+     if any(skip_name in name for skip_name in skip_names):
+         if verbose:
+             print(f"Skipping {name} because it is in skip_names")
+         return
+
+     module_key, _, value_key = name.rpartition(".")
+     try:
+         submodule = module.get_submodule(module_key)
+     except AttributeError as exc:
+         print(f"Module {module_key} not found:\n{exc}")
+         return
+
+     try:
+         if quant_method == "bnb":
+             param = submodule.get_parameter(value_key)
+             if isinstance(param, Params4bit):
+                 # With `sync_module_states=True`, a meta device Params4bit needs to be the same
+                 # shape as the quantized Params4bit with an initialized quant_state. However,
+                 # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
+                 # workaround quantizes Params4bit to initialize quant_state on all ranks, then
+                 # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
+                 value = type(param)(
+                     value.to(device=device, dtype=dtype).data, **param.__dict__
+                 ).cuda(device)
+                 if is_meta_rank:
+                     value = type(param)(value.data.to("meta"), **value.__dict__)
+                 elif low_memory:
+                     value = type(param)(value.data.to("cpu"), **value.__dict__)
+             else:
+                 value = type(param)(place_on_device(value).data)
+
+     except AttributeError:
+         # it's a buffer
+         value = place_on_device(value)
+
+     setattr(submodule, value_key, value)
+
+
  def load_model(
      cfg: DictDefault,
      tokenizer: PreTrainedTokenizerBase,
@@ -394,7 +515,7 @@ def load_model(
 
      if max_memory is not None:
          # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
-         from accelerate import infer_auto_device_map, init_empty_weights
+         from accelerate import infer_auto_device_map
 
          with init_empty_weights():
              model_canvas = AutoModelForCausalLM.from_config(model_config)
@@ -496,8 +617,78 @@ def load_model(
          model_kwargs["attn_implementation"] = "eager"
          model_config._attn_implementation = "eager"  # pylint: disable=protected-access
 
+     qlora_fsdp = (
+         cfg.fsdp
+         and cfg.adapter == "qlora"
+         and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES
+     )
+
      try:
-         if (
+         if qlora_fsdp:
+             if cfg.bf16 or cfg.bfloat16:
+                 torch_dtype, compute_dtype = torch.float32, torch.bfloat16
+             elif cfg.fp16 or cfg.float16:
+                 torch_dtype, compute_dtype = torch.float32, torch.float16
+             else:
+                 torch_dtype, compute_dtype = torch.float32, torch.float16
+
+             with init_empty_weights():
+                 LOG.info("Loading model with empty weights.")
+                 model = AutoModelForCausalLM.from_config(model_config)
+                 model.model = replace_linear(
+                     model.model,
+                     Linear4bit,
+                     compute_dtype=compute_dtype,
+                     quant_type="nf4",
+                     quant_storage=torch_dtype,
+                 )
+
+             model.is_loaded_in_4bit = True
+
+             # Grab the safetensors files that hold the weights
+             try:
+                 idx = hub.cached_file(base_model, SAFE_WEIGHTS_INDEX_NAME)
+                 files, _ = hub.get_checkpoint_shard_files(base_model, idx)
+             except OSError:
+                 try:
+                     # This means the model doesn't have a model.safetensors.index.json because it is not sharded
+                     files = []
+                     files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME))
+                 except OSError as exc:
+                     # This means the model probably doesn't have a safetensors file
+                     raise exc
+
+             # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
+             # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
+             def load_and_quantize_parallel(name_param, model, **kwargs):
+                 name, param = name_param
+                 load_and_quantize(model, name, param, **kwargs)
+
+             param_count = sum((p.numel() for n, p in model.named_parameters()))
+             for filename in files:
+                 weights = safetensors.torch.load_file(filename)
+                 quant_method = "bnb"
+                 devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
+                 left = int(os.cpu_count() / torch.cuda.device_count())
+                 right = int(
+                     8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9))
+                 )
+                 n_workers = min(left, right)
+                 parallel(
+                     load_and_quantize_parallel,
+                     weights.items(),
+                     n_workers=n_workers,
+                     threadpool=True,
+                     model=model,
+                     dtype=torch_dtype,
+                     device=cfg.local_rank,
+                     skip_names=[],
+                     is_meta_rank=(cfg.local_rank != 0),
+                     verbose=False,
+                     quant_method=quant_method,
+                 )
+
+         elif (
              model_config.model_type == "llama"
              and not cfg.trust_remote_code
              and not cfg.gptq
@@ -613,7 +804,7 @@ def load_model(
          LOG.exception(err)
          raise err
 
-     if isinstance(model, (PeftModel, PeftModelForCausalLM)):
+     if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp:
          model = model.merge_and_unload()
 
      embeddings_len = (
@@ -692,6 +883,9 @@ def load_model(
      if cfg.adapter == "lora" and loftq_bits:
          skip_prepare_model_for_kbit_training = True
 
+     if qlora_fsdp:
+         skip_prepare_model_for_kbit_training = True
+
      if cfg.adapter in ["lora", "qlora"]:
          if cfg.gradient_checkpointing:
              model.gradient_checkpointing_enable()
@@ -706,7 +900,7 @@ def load_model(
 
      # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
      # convert them back to fp16/bf16 for flash-attn compatibility.
-     if needs_fa2_dtype or cfg.flash_attention:
+     if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp:
          LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
          for name, module in model.named_modules():
              if "norm" in name:
@@ -724,7 +918,12 @@ def load_model(
      else:
          model, lora_config = load_adapter(model, cfg, cfg.adapter)
 
-     if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
+     if (
+         cfg.ddp
+         and not load_in_8bit
+         and not (cfg.rl and cfg.load_in_4bit)
+         and not qlora_fsdp
+     ):
          # TODO revaldate this conditional
          model.to(f"cuda:{cfg.local_rank}")
 
@@ -813,6 +1012,30 @@ def find_all_linear_names(model):
      return list(lora_module_names)
 
 
+ def setup_quantized_meta_for_peft(model: nn.Module):
+     """Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device"""
+
+     def temp_to_method(self, *args, **kwargs):  # pylint: disable=unused-argument
+         return self
+
+     for param in model.parameters():
+         if isinstance(param, Params4bit):
+             param.quant_state._orig_to = (  # pylint: disable=protected-access
+                 param.quant_state.to
+             )
+             param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
+
+
+ def setup_quantized_peft_meta_for_training(model: nn.Module):
+     """Replaces dummy `quant_state.to` method with the original function to allow training to continue"""
+     for param in model.parameters():
+         if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"):
+             param.quant_state.to = (
+                 param.quant_state._orig_to  # pylint: disable=protected-access
+             )
+             param.quant_state._orig_to = None  # pylint: disable=protected-access
+
+
  def load_lora(model, cfg, inference=False, config_only=False):
      # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
 
@@ -849,6 +1072,11 @@ def load_lora(model, cfg, inference=False, config_only=False):
      if config_only:
          return None, lora_config
 
+     rank = int(os.environ.get("LOCAL_RANK", 0))
+
+     if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
+         setup_quantized_meta_for_peft(model)
+
      if cfg.lora_model_dir:
          LOG.debug("Loading pretrained PEFT - LoRA")
          model_kwargs: Any = {}
@@ -864,6 +1092,9 @@ def load_lora(model, cfg, inference=False, config_only=False):
      else:
          model = get_peft_model(model, lora_config)
 
-     model.print_trainable_parameters()
+     if rank == 0:
+         model.print_trainable_parameters()
+     elif cfg.fsdp and cfg.adapter == "qlora":
+         setup_quantized_peft_meta_for_training(model)
 
      return model, lora_config
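For context on the two `quant_state` helpers: the pattern is simply to stash a bound method, install a no-op while PEFT wraps the model (so `quant_state` never gets moved to the meta device), and restore it before training. A toy illustration with a stand-in object (not bitsandbytes, just the same mechanics):

```python
# Patch/restore a bound method with types.MethodType, mirroring
# setup_quantized_meta_for_peft / setup_quantized_peft_meta_for_training.
import types


class FakeQuantState:
    """Stand-in for bitsandbytes' quant_state object."""

    def to(self, device):
        return f"moved to {device}"


def temp_to_method(self, *args, **kwargs):  # no-op replacement, returns itself
    return self


qs = FakeQuantState()
qs._orig_to = qs.to                           # stash the real bound method
qs.to = types.MethodType(temp_to_method, qs)  # install the no-op
print(qs.to("meta") is qs)                    # True: nothing is moved

qs.to = qs._orig_to                           # restore for training
qs._orig_to = None
print(qs.to("cuda:0"))                        # "moved to cuda:0"
```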