# Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG>
# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
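#
# An illustrative DeepSpeed config for the --deepspeed flag above. This is only a
# sketch: the ZeRO stage, offload, and batch-size settings are assumptions, not a
# config shipped with this repo.
#
#   {
#     "fp16": {"enabled": "auto"},
#     "bf16": {"enabled": "auto"},
#     "zero_optimization": {
#       "stage": 2,
#       "offload_optimizer": {"device": "cpu"}
#     },
#     "gradient_accumulation_steps": "auto",
#     "train_micro_batch_size_per_gpu": "auto",
#     "train_batch_size": "auto"
#   }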

from dataclasses import dataclass, field
import logging
import os
import pathlib
import typing

from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch
import transformers
from transformers import Trainer, BitsAndBytesConfig, deepspeed

from fastchat.train.train import (
    DataArguments,
    ModelArguments,
    make_supervised_data_module,
)
from fastchat.train.llama_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: typing.Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    flash_attn: bool = False


@dataclass
class LoraArguments:
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_target_modules: typing.List[str] = field(
        default_factory=lambda: ["q_proj", "v_proj"]
    )
    lora_weight_path: str = ""
    lora_bias: str = "none"
    q_lora: bool = False
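
# HfArgumentParser turns each dataclass field above into a CLI flag. An illustrative
# invocation (the model/data paths are placeholders and the flag values are examples,
# not recommended settings):
#   deepspeed train_lora.py --deepspeed ds_config.json \
#       --model_name_or_path <base-model> --data_path <data.json> \
#       --output_dir ./lora-out --lora_r 8 --lora_alpha 16 --q_lora True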


def maybe_zero_3(param):
    # Gather a (possibly ZeRO-3 partitioned) parameter and return a full CPU copy.
    if hasattr(param, "ds_id"):
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
    return to_return
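
# Illustrative use; the exact parameter key names depend on PEFT's internal naming
# scheme and are assumptions here:
#   lora_sd = get_peft_state_maybe_zero_3(model.named_parameters(), "none")
#   # -> {"base_model.model...q_proj.lora_A.weight": <cpu tensor>, ...}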


def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

    if training_args.flash_attn:
        replace_llama_attn_with_flash_attn()

    device_map = None
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if lora_args.q_lora:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            logging.warning(
                "FSDP and ZeRO3 are both currently incompatible with QLoRA."
            )

    compute_dtype = (
        torch.float16
        if training_args.fp16
        else (torch.bfloat16 if training_args.bf16 else torch.float32)
    )

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        device_map=device_map,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
        )
        if lora_args.q_lora
        else None,
    )
    lora_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.lora_bias,
        task_type="CAUSAL_LM",
    )

    if lora_args.q_lora:
        model = prepare_model_for_kbit_training(
            model, use_gradient_checkpointing=training_args.gradient_checkpointing
        )
        if not ddp and torch.cuda.device_count() > 1:
            # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
            model.is_parallelizable = True
            model.model_parallel = True

    model = get_peft_model(model, lora_config)
    if training_args.flash_attn:
        for name, module in model.named_modules():
            if "norm" in name:
                module = module.to(compute_dtype)
            if "lm_head" in name or "embed_tokens" in name:
                if hasattr(module, "weight"):
                    module = module.to(compute_dtype)
    if training_args.deepspeed is not None and training_args.local_rank == 0:
        model.print_trainable_parameters()

    if training_args.gradient_checkpointing:
        model.enable_input_require_grads()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )

    model.config.use_cache = False

    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    # Check whether DeepSpeed ZeRO-3 is enabled.
    if deepspeed.is_deepspeed_zero3_enabled():
        # Use DeepSpeed's engine-internal helper to gather the full state dict.
        # state_dict_zero3 contains the whole parameters of the base model and the
        # LoRA adapters; we do not extract the LoRA parameters here because peft's
        # save_pretrained does that for us:
        # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125
        # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19
        state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
        if training_args.local_rank == 0:
            state_dict = state_dict_zero3
    else:
        # Otherwise, fall back to the original FastChat code to keep our changes minimal.
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), lora_args.lora_bias
        )
    if training_args.local_rank == 0:
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)


if __name__ == "__main__":
    train()
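
# A minimal sketch of loading the adapter saved above for inference; the base-model
# and output-dir paths are placeholders:
#   from peft import PeftModel
#   base = transformers.AutoModelForCausalLM.from_pretrained("<base-model>")
#   model = PeftModel.from_pretrained(base, "<output_dir>")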