dg-kalle (Karl-Johan Alm) and winglian committed
Commit bdfefaf
1 Parent(s): 63fb3eb

feature: better device mapping for large models (#918)


* fix: improved memory handling when model is bigger than existing VRAM

* feature: add lora_on_cpu flag to do LoRA loading on CPU (RAM)

For big models where the base model takes up the entire GPU VRAM, loading the LoRA part will fail unless it is done on the CPU only.

* doc: add README

* fix: enable progress bars in do_merge_lora()

* doc: mention gpu_memory_limit and lora_on_cpu in merge part of README

* Update src/axolotl/utils/models.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* fix: remove deletion of removed model_kwargs key

* fix: validate that gpu_memory_limit and max_memory are not both set

---------

Co-authored-by: Karl-Johan Alm <kalle@gmail.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
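For context, the device-mapping change in src/axolotl/utils/models.py builds a per-device `max_memory` budget and lets accelerate infer a device map from it. The sketch below mirrors that flow in isolation; the model name, the 20 GiB figure, and the float16 dtype are placeholder values, not part of the commit.

```python
# Standalone sketch of the gpu_memory_limit -> device_map flow added in this
# commit (placeholder model name and memory figures; axolotl derives these
# from its cfg object instead).
import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

gpu_memory_limit = 20  # interpreted as GiB when given as an integer

# Cap every visible GPU at the limit and give the CPU a large budget so
# overflow layers spill into system RAM instead of raising CUDA OOM.
max_memory = {i: f"{gpu_memory_limit}GiB" for i in range(torch.cuda.device_count())}
max_memory["cpu"] = "256GiB"

model_config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder

# Build an empty (meta-device) copy of the model purely to measure layer sizes,
# then let accelerate decide which device each module should live on.
with init_empty_weights():
    model_canvas = AutoModelForCausalLM.from_config(model_config)
model_canvas.tie_weights()

device_map = infer_auto_device_map(
    model_canvas,
    max_memory=max_memory,
    dtype=torch.float16,
)

# The resulting device_map is passed to from_pretrained in place of max_memory.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder
    device_map=device_map,
    torch_dtype=torch.float16,
)
```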

README.md CHANGED
@@ -550,6 +550,11 @@ tf32: true # require >=ampere
 bfloat16: true # require >=ampere
 float16: true
 
+# Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset
+gpu_memory_limit: 20GiB
+# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
+lora_on_cpu: true
+
 # A list of one or more datasets to finetune the model with
 datasets:
   # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
@@ -1042,12 +1047,14 @@ The following command will merge your LORA adapater with your base model. You c
 python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model"
 ```
 
-If you run out of CUDA memory, you can try to merge in system RAM with
+You may need to use the `gpu_memory_limit` and/or `lora_on_cpu` config options to avoid running out of memory. If you still run out of CUDA memory, you can try to merge in system RAM with
 
 ```bash
 CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
 ```
 
+although this will be very slow; using the config options above is recommended instead.
+
 ## Common Errors 🧰
 
 See also the [FAQ's](./docs/faq.md).
src/axolotl/cli/__init__.py CHANGED
@@ -73,7 +73,7 @@ def do_merge_lora(
     safe_serialization = cfg.save_safetensors is True
 
     LOG.info("running merge of LoRA with base model")
-    model = model.merge_and_unload()
+    model = model.merge_and_unload(progressbar=True)
     model.to(dtype=cfg.torch_dtype)
 
     if cfg.local_rank == 0:
@@ -81,6 +81,7 @@ def do_merge_lora(
         model.save_pretrained(
             str(Path(cfg.output_dir) / "merged"),
             safe_serialization=safe_serialization,
+            progressbar=True,
         )
         tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
 
src/axolotl/utils/config.py CHANGED
@@ -457,6 +457,11 @@ def validate_config(cfg):
             "lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
         )
 
+    if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
+        raise ValueError(
+            "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
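To illustrate the new guard: with both keys set it would be ambiguous which memory limit should apply, so validate_config now fails fast. A minimal stand-in for the config object (a SimpleNamespace used purely for illustration, not axolotl's DictDefault) triggers the same error:

```python
# Illustrative only: a SimpleNamespace stands in for axolotl's cfg object.
from types import SimpleNamespace

cfg = SimpleNamespace(max_memory={0: "20GiB", "cpu": "256GiB"}, gpu_memory_limit=20)

# Mirrors the check added to validate_config() above.
if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
    raise ValueError(
        "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
    )
```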
src/axolotl/utils/models.py CHANGED
@@ -2,7 +2,7 @@
 import logging
 import math
 import os
-from typing import Optional, Tuple  # noqa: F401
+from typing import Any, Optional, Tuple  # noqa: F401
 
 import addict
 import bitsandbytes as bnb
@@ -288,8 +288,37 @@ def load_model(
 
     model_kwargs = {}
 
-    model_kwargs["device_map"] = cfg.device_map
-    model_kwargs["max_memory"] = cfg.max_memory
+    max_memory = cfg.max_memory
+    device_map = cfg.device_map
+
+    if cfg.gpu_memory_limit:
+        gpu_memory_limit = (
+            str(cfg.gpu_memory_limit) + "GiB"
+            if isinstance(cfg.gpu_memory_limit, int)
+            else cfg.gpu_memory_limit
+        )
+
+        max_memory = {}
+        for i in range(torch.cuda.device_count()):
+            max_memory[i] = gpu_memory_limit
+        max_memory["cpu"] = "256GiB"  # something sufficiently large to fit anything
+
+    if max_memory is not None:
+        # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
+        from accelerate import infer_auto_device_map, init_empty_weights
+
+        with init_empty_weights():
+            model_canvas = AutoModelForCausalLM.from_config(model_config)
+        model_canvas.tie_weights()
+        device_map = infer_auto_device_map(
+            model_canvas,
+            max_memory=max_memory,
+            dtype=cfg.torch_dtype,
+        )
+        # We can discard max_memory now as we have a device map set up for us
+        max_memory = None
+
+    model_kwargs["device_map"] = device_map
     model_kwargs["torch_dtype"] = cfg.torch_dtype
     # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
     # if cfg.rl:
@@ -426,7 +455,6 @@ def load_model(
             model_kwargs["device"] = torch.cuda.current_device()
             del model_kwargs["torch_dtype"]
             del model_kwargs["device_map"]
-            del model_kwargs["max_memory"]
 
             model = MambaLMHeadModel.from_pretrained(
                 base_model,
@@ -683,10 +711,15 @@ def load_lora(model, cfg, inference=False):
 
     if cfg.lora_model_dir:
         LOG.debug("Loading pretained PEFT - LoRA")
+        model_kwargs: Any = {}
+        if cfg.lora_on_cpu:
+            model_kwargs["max_memory"] = {"cpu": "256GiB"}
+            model_kwargs["device_map"] = {"": "cpu"}
         model = PeftModel.from_pretrained(
            model,
            cfg.lora_model_dir,
            is_trainable=(not inference),
+           **model_kwargs,
        )
     else:
         model = get_peft_model(model, lora_config)
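As a usage sketch of the `lora_on_cpu` path: the adapter is loaded with a CPU-only device map and then merged with a progress bar, which is what `load_lora()` and `do_merge_lora()` now do internally. The model name and adapter path below are placeholders; in axolotl they come from the config and CLI arguments.

```python
# Sketch of CPU-only LoRA loading followed by a merge, mirroring the changes
# in load_lora() and do_merge_lora(); paths and dtype are placeholders.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder base model
    torch_dtype=torch.float16,
)

# Equivalent of lora_on_cpu: true -- keep the adapter weights off the GPU so a
# base model that already fills VRAM can still be merged.
model = PeftModel.from_pretrained(
    base,
    "./completed-model",  # placeholder lora_model_dir
    is_trainable=False,
    max_memory={"cpu": "256GiB"},
    device_map={"": "cpu"},
)

# Same call do_merge_lora() now makes, with the progress bar enabled.
merged = model.merge_and_unload(progressbar=True)
merged.save_pretrained("./merged", safe_serialization=True)
```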