winglian committed
Commit
2255bb7
1 Parent(s): 55baef0

support llama-adapter zero init attention

Files changed (2)
  1. scripts/finetune.py +4 -4
  2. src/axolotl/utils/models.py +50 -21
scripts/finetune.py CHANGED
@@ -146,8 +146,8 @@ def train(
         cfg.bf16 = False
 
     # Load the model and tokenizer
-    logging.info("loading model, tokenizer, and lora_config...")
-    model, tokenizer, lora_config = load_model(
+    logging.info("loading model, tokenizer, and peft_config...")
+    model, tokenizer, peft_config = load_model(
         cfg.base_model,
         cfg.base_model_config,
         cfg.model_type,
@@ -186,9 +186,9 @@ def train(
         model = torch.compile(model)
 
     # go ahead and presave, so we have the adapter config available to inspect
-    if lora_config:
+    if peft_config:
         logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
-        lora_config.save_pretrained(cfg.output_dir)
+        peft_config.save_pretrained(cfg.output_dir)
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
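
Aside from the diff itself: `peft_config` is whatever PEFT config object `load_model` returns (a `LoraConfig`, or the new `AdaptionPromptConfig` added below), and calling `save_pretrained` on such a config writes an `adapter_config.json` into the output directory so the adapter settings can be inspected before training finishes. A minimal sketch of that pre-save step, using a LoRA config with placeholder values:

from peft import LoraConfig

# Placeholder hyperparameters; any PeftConfig subclass can be pre-saved the same way.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
lora_config.save_pretrained("./out")  # writes ./out/adapter_config.json for inspection
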
src/axolotl/utils/models.py CHANGED
@@ -195,11 +195,41 @@ def load_adapter(model, cfg, adapter):
         return model, None
     if adapter == "lora":
         return load_lora(model, cfg)
-    # TODO support Llama-Adapter once merged into peft https://github.com/huggingface/peft/pulls
+    if adapter == "llama-adapter":
+        return load_llama_adapter(model, cfg)
 
     raise NotImplementedError(f"{adapter} peft adapter not available")
 
 
+def load_llama_adapter(model, cfg):
+    # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    from peft import (
+        AdaptionPromptConfig,
+        get_peft_model,
+        PeftModel,
+    )
+
+    peft_config = AdaptionPromptConfig(
+        adapter_layers=cfg.peft_adapter.layers,  # layers (L)
+        adapter_len=cfg.peft_adapter.len,  # prompt length (K)
+        task_type="CAUSAL_LM",
+    )
+
+    if cfg.peft_model_dir:
+        model = PeftModel.from_pretrained(
+            model,
+            cfg.lora_model_dir,
+            device_map=cfg.device_map,
+            torch_dtype=torch.float16,
+        )
+    else:
+        model = get_peft_model(model, peft_config)
+
+    model.print_trainable_parameters()
+
+    return model, peft_config
+
+
 def load_lora(model, cfg):
     # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
@@ -211,27 +241,26 @@ def load_lora(model, cfg):
 
     lora_config = None
 
-    if cfg.adapter == "lora":
-        lora_config = LoraConfig(
-            r=cfg.lora_r,
-            lora_alpha=cfg.lora_alpha,
-            target_modules=cfg.lora_target_modules,
-            lora_dropout=cfg.lora_dropout,
-            fan_in_fan_out=cfg.lora_fan_in_fan_out,
-            bias="none",
-            task_type="CAUSAL_LM",
-        )
+    lora_config = LoraConfig(
+        r=cfg.lora_r,
+        lora_alpha=cfg.lora_alpha,
+        target_modules=cfg.lora_target_modules,
+        lora_dropout=cfg.lora_dropout,
+        fan_in_fan_out=cfg.lora_fan_in_fan_out,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
 
-        if cfg.lora_model_dir:
-            model = PeftModel.from_pretrained(
-                model,
-                cfg.lora_model_dir,
-                device_map=cfg.device_map,
-                torch_dtype=torch.float16,
-            )
-        else:
-            model = get_peft_model(model, lora_config)
+    if cfg.lora_model_dir:
+        model = PeftModel.from_pretrained(
+            model,
+            cfg.lora_model_dir,
+            device_map=cfg.device_map,
+            torch_dtype=torch.float16,
+        )
+    else:
+        model = get_peft_model(model, lora_config)
 
-        model.print_trainable_parameters()
+    model.print_trainable_parameters()
 
     return model, lora_config
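
For context on the new code path: PEFT's `AdaptionPromptConfig` implements LLaMA-Adapter, which learns a short prompt of `adapter_len` tokens for the topmost `adapter_layers` transformer layers and mixes it in through zero-initialized attention gates, so the adapted model initially behaves exactly like the frozen base model (the "zero init attention" in the commit title). A minimal standalone sketch of the same wiring, with a placeholder model id and placeholder hyperparameters standing in for `cfg.peft_adapter.layers` and `cfg.peft_adapter.len`:

import torch
from transformers import AutoModelForCausalLM
from peft import AdaptionPromptConfig, get_peft_model

# Placeholder base model and hyperparameters, for illustration only.
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    torch_dtype=torch.float16,
    device_map="auto",
)
peft_config = AdaptionPromptConfig(
    adapter_layers=30,  # number of top layers that receive adaption prompts (L)
    adapter_len=10,     # learned prompt tokens per adapted layer (K)
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # only the adaption prompts and gates are trainable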