winglian committed on
Commit a5bf838
1 Parent(s): a4f1241

add logging and make sure model unloads to float16

scripts/finetune.py CHANGED
@@ -176,6 +176,7 @@ def train(
     if "merge_lora" in kwargs and cfg.adapter is not None:
         logging.info("running merge of LoRA with base model")
         model = model.merge_and_unload()
+        model.to(dtype=torch.float16)
 
         if cfg.local_rank == 0:
            logging.info("saving merged model")
src/axolotl/utils/validation.py CHANGED
@@ -1,3 +1,6 @@
+import logging
+
+
 def validate_config(cfg):
     if cfg.adapter == "qlora":
         if cfg.merge_lora:
@@ -9,6 +12,9 @@ def validate_config(cfg):
             assert cfg.load_in_8bit is False
             assert cfg.load_4bit is False
             assert cfg.load_in_4bit is True
+    if cfg.load_in_8bit and cfg.adapter == "lora":
+        logging.warning("we recommend setting `load_in_8bit: true`")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
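A quick sketch of how the new warning branch fires; the SimpleNamespace config and its field values are assumptions standing in for axolotl's real config object, used only to exercise validate_config as shown above.

from types import SimpleNamespace

from axolotl.utils.validation import validate_config

# hypothetical config values chosen to reach the new check
cfg = SimpleNamespace(
    adapter="lora",
    load_in_8bit=True,   # combined with adapter="lora", this hits the new logging.warning branch
    load_4bit=False,
    load_in_4bit=False,
    merge_lora=False,
)
validate_config(cfg)  # emits the warning text added in this commit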