winglian committed
Commit 02af082 · unverified · 1 Parent(s): 4155e99

Jamba (#1451)

* fixes for larger models

* add qlora example for deepspeed

* add readme for jamba

examples/jamba/README.md ADDED
@@ -0,0 +1,5 @@
+ # Jamba
+
+ QLoRA w/ DeepSpeed needs at least 2x GPUs and 35 GiB of VRAM per GPU.
+
+ QLoRA on a single GPU: training will start, but the loss is off by an order of magnitude.
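For reference, a quick hardware check (not part of this commit) against the README's 2x GPU / 35 GiB-per-GPU guidance, using standard PyTorch CUDA queries:

```python
import torch

# Sanity check before launching: confirm at least 2 CUDA devices with ~35 GiB each.
n_gpus = torch.cuda.device_count()
assert n_gpus >= 2, f"expected >= 2 GPUs for QLoRA w/ DeepSpeed, found {n_gpus}"

for i in range(n_gpus):
    total_gib = torch.cuda.get_device_properties(i).total_memory / 2**30
    print(f"cuda:{i} -> {total_gib:.1f} GiB")
    assert total_gib >= 35, f"cuda:{i} has only {total_gib:.1f} GiB of VRAM"
```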
examples/jamba/qlora_deepspeed.yaml ADDED
@@ -0,0 +1,62 @@
+ base_model: ai21labs/Jamba-v0.1
+ trust_remote_code: true
+
+ load_in_8bit: false
+ load_in_4bit: true
+ strict: false
+
+ datasets:
+   - path: mhenrichsen/alpaca_2k_test
+     type: alpaca
+ dataset_prepared_path:
+ val_set_size: 0.0
+ output_dir: ./out
+
+ sequence_len: 4096
+ sample_packing: false
+ pad_to_sequence_len: false
+ eval_sample_packing: false
+
+ wandb_project:
+ wandb_entity:
+ wandb_watch:
+ wandb_name:
+ wandb_log_model:
+
+ adapter: qlora
+ lora_r: 8
+ lora_alpha: 16
+ lora_dropout: 0.05
+ lora_target_linear: true
+
+ low_cpu_mem_usage: true
+ gradient_accumulation_steps: 4
+ micro_batch_size: 1
+ num_epochs: 2
+ optimizer: paged_adamw_8bit
+ lr_scheduler: cosine
+ learning_rate: 0.00001
+
+ train_on_inputs: false
+ group_by_length: false
+ bf16: auto
+ fp16:
+ tf32: false
+
+ gradient_checkpointing: true
+ gradient_checkpointing_kwargs:
+   use_reentrant: false
+ early_stopping_patience:
+ resume_from_checkpoint:
+ local_rank:
+ logging_steps: 1
+ xformers_attention:
+ flash_attention: true
+
+ warmup_steps: 10
+ evals_per_epoch:
+ saves_per_epoch: 1
+ debug:
+ deepspeed: deepspeed_configs/zero2.json
+ weight_decay: 0.0
+ special_tokens:
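As a rough orientation only (not axolotl's actual loader), the loading-related keys above approximately correspond to the following plain transformers call; axolotl layers its own handling (PEFT adapters, DeepSpeed, packing, tokenizer setup) on top of this:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Rough stand-in for what the config's loading keys imply when calling
# transformers directly. Not the axolotl implementation.
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1",
    trust_remote_code=True,                                      # trust_remote_code: true
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),   # load_in_4bit: true
    torch_dtype=torch.bfloat16,                                  # bf16: auto
    attn_implementation="flash_attention_2",                     # flash_attention: true
    low_cpu_mem_usage=True,                                      # low_cpu_mem_usage: true
)
```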
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -533,6 +533,7 @@ class AxolotlInputConfig(
      Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
  ] = None
  gpu_memory_limit: Optional[Union[int, str]] = None
+ low_cpu_mem_usage: Optional[bool] = None
 
  chat_template: Optional[ChatTemplate] = None
  default_system_message: Optional[str] = None
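A minimal sketch of what the new field does at the config-validation layer, using a simplified hypothetical pydantic model rather than the real AxolotlInputConfig:

```python
from typing import Optional

import yaml
from pydantic import BaseModel


# Hypothetical stand-in for the new AxolotlInputConfig field: when the key is
# absent from the YAML, the value stays None and the loader leaves the
# transformers default untouched.
class LoaderOptions(BaseModel):
    low_cpu_mem_usage: Optional[bool] = None


assert LoaderOptions().low_cpu_mem_usage is None
assert LoaderOptions(**yaml.safe_load("low_cpu_mem_usage: true")).low_cpu_mem_usage is True
```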
src/axolotl/utils/models.py CHANGED
@@ -402,7 +402,9 @@ def load_model(
          from accelerate import infer_auto_device_map
 
          with init_empty_weights():
-             model_canvas = AutoModelForCausalLM.from_config(model_config)
+             model_canvas = AutoModelForCausalLM.from_config(
+                 model_config, trust_remote_code=cfg.trust_remote_code or False
+             )
          model_canvas.tie_weights()
          device_map = infer_auto_device_map(
              model_canvas,
@@ -502,6 +504,9 @@ def load_model(
          model_kwargs["attn_implementation"] = "eager"
          model_config._attn_implementation = "eager"  # pylint: disable=protected-access
 
+     if cfg.low_cpu_mem_usage:
+         model_kwargs["low_cpu_mem_usage"] = True
+
      qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"
 
      try:
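In isolation, the meta-device pattern the first hunk extends looks roughly like this; the memory limits below are illustrative placeholders:

```python
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# Standalone sketch of the pattern touched above: build the model on the
# "meta" device (no real weights) so infer_auto_device_map can plan a
# placement. trust_remote_code has to be forwarded for models like Jamba
# whose modeling code ships with the checkpoint.
model_config = AutoConfig.from_pretrained("ai21labs/Jamba-v0.1", trust_remote_code=True)

with init_empty_weights():
    model_canvas = AutoModelForCausalLM.from_config(
        model_config, trust_remote_code=True
    )
model_canvas.tie_weights()

device_map = infer_auto_device_map(
    model_canvas,
    max_memory={0: "35GiB", 1: "35GiB", "cpu": "120GiB"},  # illustrative limits
)
```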
src/axolotl/utils/trainer.py CHANGED
@@ -312,6 +312,8 @@ def setup_fsdp_envs(cfg):
          os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
      if cfg.fsdp_config.fsdp_state_dict_type:
          os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
+     if cfg.fsdp_config.fsdp_auto_wrap_policy:
+         os.environ["FSDP_AUTO_WRAP_POLICY"] = cfg.fsdp_config.fsdp_auto_wrap_policy
      if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap:
          os.environ[
              "FSDP_TRANSFORMER_CLS_TO_WRAP"