"""Module for models and model loading"""
# pylint: disable=too-many-lines
import logging
import math
import os
import types
from typing import Any, Dict, List, Optional, Tuple, Type, Union # noqa: F401
import addict
import bitsandbytes as bnb
import safetensors
import torch
import transformers
from accelerate import init_empty_weights
from bitsandbytes.nn import Linear4bit, Params4bit
from fastcore.parallel import parallel
from peft import (
LoftQConfig,
PeftConfig,
PeftModel,
PeftModelForCausalLM,
prepare_model_for_kbit_training,
)
from peft.tuners.lora import QuantLinear
from torch import Tensor, nn
from transformers import ( # noqa: F401
AddedToken,
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
GPTQConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
from axolotl.core.policies.auto_wrap import SUPPORTED_AUTO_WRAP_MODEL_TYPES
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack,
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
LOG = logging.getLogger("axolotl")
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
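    """Validate quantization (GPTQ/bnb) and LoRA-related settings in `cfg` against the loaded model config."""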
quant_config_exists = (
hasattr(model_config, "quantization_config")
and model_config.quantization_config
)
quant_config_method_is_gptq = (
quant_config_exists
and "quant_method" in model_config.quantization_config
and model_config.quantization_config["quant_method"] == "gptq"
)
if cfg.gptq and not quant_config_method_is_gptq:
raise ValueError(
"model_config.quantization_config is not set or quant_method is not set to gptq. "
"Please make sure to point to a GPTQ model."
)
if not cfg.gptq and quant_config_exists:
raise ValueError(
"model_config.quantization_config is set but `gptq` flag is not. "
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
)
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
if (
cfg.adapter
and cfg.tokens
and (
not cfg.lora_modules_to_save
or not all(x in cfg.lora_modules_to_save for x in lora_modules_to_save)
)
):
lora_modules_to_save = ", ".join(map(lambda x: f"`{x}`", lora_modules_to_save))
raise ValueError(
f"`lora_modules_to_save` not properly set when adding new tokens. Please include [{lora_modules_to_save}] in `lora_modules_to_save`."
)
def load_model_config(cfg):
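    """Load the transformers config for `cfg.base_model_config` (or `cfg.base_model`), applying any `overrides_of_model_config`."""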
model_config_name = cfg.base_model_config or cfg.base_model
if not model_config_name and cfg.tokenizer_config:
model_config_name = cfg.tokenizer_config
trust_remote_code = cfg.trust_remote_code is True
config_kwargs = {}
if cfg.revision_of_model:
config_kwargs["revision"] = cfg.revision_of_model
try:
model_config = AutoConfig.from_pretrained(
model_config_name,
trust_remote_code=trust_remote_code,
**config_kwargs,
)
except ValueError as err:
if "mamba" in model_config_name:
return addict.Dict(
{
"model_type": "mamba",
}
)
raise err
if cfg.overrides_of_model_config:
for key, val in cfg.overrides_of_model_config.items():
setattr(model_config, key, val)
check_model_config(cfg, model_config)
return model_config
def load_tokenizer(cfg):
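    """Load the tokenizer described by `cfg` and configure special tokens, padding side, added tokens, and chat template."""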
model_config = load_model_config(cfg)
tokenizer_kwargs = {}
use_fast = True # this is the default
if cfg.tokenizer_use_fast is not None:
use_fast = cfg.tokenizer_use_fast
if cfg.tokenizer_legacy is not None:
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
tokenizer_cls = AutoTokenizer
if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
tokenizer = tokenizer_cls.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
if (
tokenizer.__class__.__name__
in [
"LlamaTokenizer",
"LlamaTokenizerFast",
"CodeLlamaTokenizer",
"CodeLlamaTokenizerFast",
]
and hasattr(tokenizer, "pad_token")
and not tokenizer.pad_token
):
# set a pad_token, but use eos_token so we don't add a new token
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Mistral's official FA implementation requires left padding
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
tokenizer.padding_side = "left"
    # The Qwen base tokenizer only has a single special token, so we need to set the rest
if cfg.is_qwen_derived_model:
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
for attr_name in token_ids:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, tokenizer.eod_id)
token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
for attr_name in token_names:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
additional_special_tokens = None
if cfg.special_tokens:
special_tokens = cfg.special_tokens.to_dict()
additional_special_tokens = special_tokens.pop(
"additional_special_tokens", None
)
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
for k, val in special_tokens.items():
            # If the new special token is not already in the tokenizer and we are
            # training an adapter, make sure lora_modules_to_save is set correctly
# pylint: disable=too-many-boolean-expressions
if (
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
and cfg.adapter
and (
not cfg.lora_modules_to_save
or not all(
x in cfg.lora_modules_to_save for x in lora_modules_to_save
)
)
):
lora_modules_to_save = ", ".join(
[f"`{x}`" for x in lora_modules_to_save]
)
raise ValueError(
f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens."
)
tokenizer.add_special_tokens(
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
)
# If we add bos_token and eos_token, we need to update the post processor to
# handle them correctly.
# https://github.com/huggingface/transformers/pull/24132
        bos_and_eos_in_special_tokens = (
            "bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
        )
if (
tokenizer.__class__.__name__
in (
"LlamaTokenizerFast",
"CodeLlamaTokenizerFast",
)
            and bos_and_eos_in_special_tokens
):
tokenizer.update_post_processor()
if cfg.tokens:
tokenizer.add_tokens(
[
AddedToken(token, rstrip=False, lstrip=False, normalized=False)
for token in cfg.tokens
]
)
# Additional special tokens are a List, and need to be treated differently than regular special
# tokens. We add them after we have called `add_tokens` in case these additional special tokens
# are new tokens.
#
# Usage:
#
# ```py
# special_tokens:
# additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
# ```
if additional_special_tokens is not None:
tokenizer.add_special_tokens(
{"additional_special_tokens": additional_special_tokens}
)
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if cfg.chat_template:
chat_template_string = chat_templates(cfg.chat_template)
if cfg.default_system_message and cfg.chat_template == "chatml":
chat_template_string = chat_template_string.replace(
"You are a helpful assistant.", cfg.default_system_message
)
tokenizer.chat_template = chat_template_string
else:
LOG.info(
"No Chat template selected. Consider adding a chat template for easier inference."
)
return tokenizer
def replace_linear(
model: nn.Module,
linear_replacement: Type[nn.Module],
quant_config: Union[dict, None] = None,
skip_modules=None,
**kwargs,
):
"""
Replace linear modules with a new Linear module.
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
linear_replacement (`torch.nn.Module`):
The linear module that replaces the old one. Only expects standard arguments.
If other arguments need to be passed, use a lambda.
        skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`):
            List of module names not to convert.
"""
if skip_modules is None:
skip_modules = ["lm_head"]
for name, module in model.named_children():
if len(list(module.children())) > 0:
replace_linear(
module, linear_replacement, quant_config, skip_modules, **kwargs
)
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
if issubclass(linear_replacement, Linear4bit):
model._modules[ # pylint: disable=protected-access
name
] = linear_replacement(
module.in_features,
module.out_features,
module.bias is not None,
**kwargs,
)
else:
raise ValueError(
f"Unsupported linear replacement: {type(linear_replacement)}"
)
return model
def load_and_quantize(
module: nn.Module,
name: str,
value: Tensor,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
skip_names: Optional[List[str]] = None,
is_meta_rank: bool = False,
low_memory: bool = True,
verbose: bool = False,
quant_method: str = "bnb",
):
"""
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
    Quantizes `Params4bit` on `device`, then places it on "cpu" if low_memory=True or on "meta" if is_meta_rank=True.
"""
if skip_names is None:
skip_names = []
def place_on_device(value):
if is_meta_rank:
device = "meta"
elif low_memory:
device = "cpu"
else:
device = "cuda"
return value.to(device=device, dtype=dtype)
if any(skip_name in name for skip_name in skip_names):
if verbose:
print(f"Skipping {name} because it is in skip_names")
return
module_key, _, value_key = name.rpartition(".")
try:
submodule = module.get_submodule(module_key)
except AttributeError as exc:
print(f"Module {module_key} not found:\n{exc}")
return
try:
if quant_method == "bnb":
param = submodule.get_parameter(value_key)
if isinstance(param, Params4bit):
# With `sync_module_states=True`, a meta device Params4bit needs to be the same
# shape as the quantized Params4bit with an initialized quant_state. However,
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
value = type(param)(
value.to(device=device, dtype=dtype).data, **param.__dict__
).cuda(device)
if is_meta_rank:
value = type(param)(value.data.to("meta"), **value.__dict__)
elif low_memory:
value = type(param)(value.data.to("cpu"), **value.__dict__)
else:
value = type(param)(place_on_device(value).data)
except AttributeError:
# it's a buffer
value = place_on_device(value)
setattr(submodule, value_key, value)
def load_model(
cfg: DictDefault,
tokenizer: PreTrainedTokenizerBase,
inference: bool = False,
reference_model: bool = False,
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
"""
    Load a model for a given configuration and tokenizer.
    Returns the model and, when an adapter is configured, the corresponding PEFT config.
"""
base_model = cfg.base_model
model_type = cfg.type_of_model
model_config = load_model_config(cfg)
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
if cfg.flash_attention:
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
replace_btlm_attn_with_flash_attn,
)
replace_btlm_attn_with_flash_attn(cfg.base_model)
if (
hasattr(model_config, "model_type")
and model_config.model_type == "stablelm_epoch"
):
if cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
replace_stablelm_attn_with_flash_attn,
)
replace_stablelm_attn_with_flash_attn(cfg.base_model)
if cfg.sample_packing and cfg.s2_attention:
        raise ValueError(
            "Received `sample_packing=true` and `s2_attention=true`; however, "
            "shifted-sparse attention does not currently support sample packing."
        )
if (
cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and cfg.flash_attention
and cfg.sample_packing
):
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
elif cfg.is_llama_derived_model:
# Modify all llama derived models in one block
if cfg.flash_attention:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
)
if cfg.sample_packing:
if cfg.device not in ["mps", "cpu"] and not inference:
LOG.info("patching with flash attention for sample packing")
replace_llama_attn_with_flash_attn(
packed=True,
cross_entropy=cfg.flash_attn_cross_entropy,
rms_norm=cfg.flash_attn_rms_norm,
)
elif cfg.s2_attention:
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
replace_llama_attn_with_flash_attn(
packed=False,
cross_entropy=cfg.flash_attn_cross_entropy,
rms_norm=cfg.flash_attn_rms_norm,
use_shifted_sparse_attn=True,
)
elif cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
)
LOG.info("patching with xformers attention")
hijack_llama_attention()
elif cfg.sample_packing:
from axolotl.monkeypatch.llama_patch_multipack import (
hijack_llama_prepare_4d_mask,
)
LOG.info("patching llama _prepare_4d_causal_attention_mask*")
hijack_llama_prepare_4d_mask()
elif cfg.s2_attention:
raise NotImplementedError(
"Shifted-sparse attention not currently implemented without flash attention."
)
# Modify mistral derived models
if (
cfg.model_config_type == "mistral"
and cfg.flash_attention
and cfg.sample_packing
):
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
replace_mistral_attn_with_flash_attn,
)
LOG.info("patching mistral with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
LOG.info("patching _expand_mask")
hijack_expand_mask()
model_kwargs: Dict[str, Any] = {}
if cfg.model_kwargs:
for key, val in cfg.model_kwargs.items():
model_kwargs[key] = val
max_memory = cfg.max_memory
device_map = cfg.device_map
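    # `gpu_memory_limit` may be an int (interpreted as GiB) or a string such as "20GiB",
    # e.g. (illustrative):
    #
    # ```yaml
    # gpu_memory_limit: 20
    # ```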
if cfg.gpu_memory_limit:
gpu_memory_limit = (
str(cfg.gpu_memory_limit) + "GiB"
if isinstance(cfg.gpu_memory_limit, int)
else cfg.gpu_memory_limit
)
max_memory = {}
for i in range(torch.cuda.device_count()):
max_memory[i] = gpu_memory_limit
max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything
if max_memory is not None:
# Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
from accelerate import infer_auto_device_map
with init_empty_weights():
model_canvas = AutoModelForCausalLM.from_config(model_config)
model_canvas.tie_weights()
device_map = infer_auto_device_map(
model_canvas,
max_memory=max_memory,
dtype=cfg.torch_dtype,
)
# We can discard max_memory now as we have a device map set up for us
max_memory = None
model_kwargs["device_map"] = device_map
model_kwargs["torch_dtype"] = cfg.torch_dtype
if torch.backends.mps.is_available():
model_kwargs["device_map"] = "mps:0"
    # TODO can we put the reference model on its own GPU? I think we have to move logits around to calculate loss
# if cfg.rl:
# if torch.cuda.device_count() > 1:
# if reference_model:
# model_kwargs["device_map"] = "cuda:" + str(
# torch.cuda.current_device() + 1
# )
# else:
# model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device())
if is_deepspeed_zero3_enabled():
del model_kwargs["device_map"]
if cfg.revision_of_model:
model_kwargs["revision"] = cfg.revision_of_model
if cfg.gptq:
if not hasattr(model_config, "quantization_config"):
LOG.warning("model config does not contain quantization_config information")
else:
if cfg.gptq_disable_exllama is not None:
model_config.quantization_config[
"disable_exllama"
] = cfg.gptq_disable_exllama
model_kwargs["quantization_config"] = GPTQConfig(
**model_config.quantization_config
)
if cfg.adapter == "qlora" and cfg.load_in_4bit:
bnb_config = {
"load_in_4bit": True,
"llm_int8_threshold": 6.0,
"llm_int8_has_fp16_weight": False,
"bnb_4bit_compute_dtype": cfg.torch_dtype,
"bnb_4bit_use_double_quant": True,
"bnb_4bit_quant_type": "nf4",
}
if cfg.bnb_config_kwargs:
bnb_config.update(cfg.bnb_config_kwargs)
model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
if cfg.load_in_8bit and cfg.adapter is not None:
model_kwargs["load_in_8bit"] = True
if cfg.load_in_4bit and cfg.adapter is not None:
model_kwargs["load_in_4bit"] = True
# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in model_kwargs or cfg.gptq:
if "load_in_8bit" in model_kwargs:
del model_kwargs["load_in_8bit"]
if "load_in_4bit" in model_kwargs:
del model_kwargs["load_in_4bit"]
# sample packing uses custom FA2 patch
if cfg.flash_attention:
if not cfg.sample_packing:
if cfg.s2_attention:
pass
            # most other models support flash attention; we can define exceptions as they come up
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
else:
if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
else:
model_kwargs["attn_implementation"] = "eager"
model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
elif cfg.sdp_attention:
model_kwargs["attn_implementation"] = "sdpa"
model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
elif cfg.eager_attention:
model_kwargs["attn_implementation"] = "eager"
model_config._attn_implementation = "eager" # pylint: disable=protected-access
qlora_fsdp = (
cfg.fsdp
and cfg.adapter == "qlora"
and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES
)
try:
if qlora_fsdp:
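            # With FSDP+QLoRA, keep fp32 as the quant storage / load dtype so the 4-bit weights
            # share a dtype with the rest of the sharded parameters, while bnb computes in bf16/fp16.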
if cfg.bf16 or cfg.bfloat16:
torch_dtype, compute_dtype = torch.float32, torch.bfloat16
elif cfg.fp16 or cfg.float16:
torch_dtype, compute_dtype = torch.float32, torch.float16
else:
torch_dtype, compute_dtype = torch.float32, torch.float16
with init_empty_weights():
LOG.info("Loading model with empty weights.")
model = AutoModelForCausalLM.from_config(model_config)
model.model = replace_linear(
model.model,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=torch_dtype,
)
model.is_loaded_in_4bit = True
# Grab the safetensors files that hold the weights
try:
idx = hub.cached_file(base_model, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(base_model, idx)
except OSError:
try:
# This means the model doesn't have a model.safetensors.index.json because it is not sharded
files = []
files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME))
except OSError as exc:
# This means the model probably doesn't have a safetensors file
raise exc
# Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
# and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
def load_and_quantize_parallel(name_param, model, **kwargs):
name, param = name_param
load_and_quantize(model, name, param, **kwargs)
param_count = sum((p.numel() for n, p in model.named_parameters()))
for filename in files:
weights = safetensors.torch.load_file(filename)
quant_method = "bnb"
devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
left = int(os.cpu_count() / torch.cuda.device_count())
right = int(
8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9))
)
n_workers = min(left, right)
parallel(
load_and_quantize_parallel,
weights.items(),
n_workers=n_workers,
threadpool=True,
model=model,
dtype=torch_dtype,
device=cfg.local_rank,
skip_names=[],
is_meta_rank=(cfg.local_rank != 0),
verbose=False,
quant_method=quant_method,
)
elif (
model_config.model_type == "llama"
and not cfg.trust_remote_code
and not cfg.gptq
):
from transformers import LlamaForCausalLM
model = LlamaForCausalLM.from_pretrained(
base_model,
config=model_config,
**model_kwargs,
)
if cfg.flash_attention and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
is_xformers_swiglu_available,
replace_llama_mlp_with_swiglu,
replace_llama_qkv_with_fused,
)
if cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
LOG.info("patching with SwiGLU")
replace_llama_mlp_with_swiglu(model)
if cfg.flash_attn_fuse_qkv:
LOG.info("patching with fused QKV")
replace_llama_qkv_with_fused(model)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
# This is a WIP, still an issue with the backward pass
# RuntimeError: grad can be implicitly created only for scalar outputs
# TODO: try config.sequence_parallel = False
# # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
# # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
# # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
# from flash_attn.utils.pretrained import state_dict_from_pretrained
# from flash_attn.models.gpt import GPTLMHeadModel
# from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
# from transformers import GPTNeoXConfig
# config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
# config.use_flash_attn = True
# config.fused_bias_fc = True
# config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast"
# config.activation_function = "gelu_fast"
# config.fused_dropout_add_ln = True
# # config.residual_in_fp32 = True
#
# model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained(
# base_model,
# config,
# dtype=torch_dtype,
# device=cfg.device,
# )
# model.train() # sets to train instead of eval mode
elif model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
model_kwargs["dtype"] = model_kwargs["torch_dtype"]
model_kwargs["device"] = torch.cuda.current_device()
del model_kwargs["torch_dtype"]
del model_kwargs["device_map"]
model = MambaLMHeadModel.from_pretrained(
base_model,
**model_kwargs,
)
elif model_type and not cfg.trust_remote_code:
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = getattr(transformers, model_type).from_pretrained(
base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
            # Shouldn't be a problem most of the time; it will obviously error out when
            # training starts if the model doesn't actually support the longer context
if (
hasattr(model_config, "max_seq_len")
and model_config.max_seq_len
and cfg.sequence_len > model_config.max_seq_len
):
model_config.max_seq_len = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
elif (
hasattr(model_config, "max_sequence_length")
and model_config.max_sequence_length
and cfg.sequence_len > model_config.max_sequence_length
):
model_config.max_sequence_length = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
except Exception as err: # pylint: disable=broad-exception-caught
LOG.exception(err)
raise err
if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp:
model = model.merge_and_unload()
embeddings_len = (
math.ceil(len(tokenizer) / 32) * 32
if cfg.resize_token_embeddings_to_32x
else len(tokenizer)
)
if (
hasattr(model, "get_input_embeddings")
and model.get_input_embeddings().num_embeddings < embeddings_len
):
model.resize_token_embeddings(embeddings_len)
else:
model.tie_weights()
if (
hasattr(model, "config")
and hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
and cfg.sequence_len > model.config.max_position_embeddings
):
LOG.warning(
f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}"
)
model.config.max_position_embeddings = cfg.sequence_len
if (
hasattr(model, "config")
and hasattr(model.config, "bos_token_id")
and model.config.bos_token_id
and model.config.bos_token_id != tokenizer.bos_token_id
):
model.config.bos_token_id = tokenizer.bos_token_id
if (
hasattr(model, "config")
and hasattr(model.config, "eos_token_id")
and model.config.eos_token_id
and model.config.eos_token_id != tokenizer.eos_token_id
):
model.config.eos_token_id = tokenizer.eos_token_id
if hasattr(model, "device") and model.device.type in ("cuda", "mps"):
log_gpu_memory_usage(LOG, "after model load", model.device)
# make sure these are fp32 per Ramesh et al. (2021)
embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
if not cfg.fsdp:
# FSDP doesn't like mixed Float and BFloat16
for name, module in model.named_modules():
if "norm" in name or name.endswith(".gate"):
module.to(torch.float32)
if model_config.model_type == "btlm":
# don't upcast lm_head for btlm
continue
if any(m in name for m in embedding_modules):
if hasattr(module, "weight"):
module.to(torch.float32)
needs_fa2_dtype = cfg.adapter or cfg.fsdp
skip_prepare_model_for_kbit_training = False
if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
from deepspeed.utils import ( # pylint: disable=no-name-in-module
set_z3_leaf_modules,
)
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
skip_prepare_model_for_kbit_training = True
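    # LoRA initialized with LoftQ also skips prepare_model_for_kbit_training; configured e.g.
    # (illustrative):
    #
    # ```yaml
    # peft:
    #   loftq_config:
    #     loftq_bits: 4
    # ```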
loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
if cfg.adapter == "lora" and loftq_bits:
skip_prepare_model_for_kbit_training = True
if qlora_fsdp:
skip_prepare_model_for_kbit_training = True
if cfg.adapter in ["lora", "qlora"]:
if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs
)
if (
cfg.load_in_8bit or cfg.load_in_4bit
) and not skip_prepare_model_for_kbit_training:
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
needs_fa2_dtype = True
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp:
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
for name, module in model.named_modules():
if "norm" in name:
module.to(cfg.torch_dtype)
if any(m in name for m in embedding_modules):
if hasattr(module, "weight"):
module.to(cfg.torch_dtype)
lora_config = None
if not reference_model or cfg.lora_model_dir:
        # If we're not loading the reference model, then we're loading the model for training.
        # The DPO trainer doesn't want the PEFT model loaded over it; it only needs the LoRA/PEFT config.
if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
else:
model, lora_config = load_adapter(model, cfg, cfg.adapter)
if (
cfg.ddp
and not load_in_8bit
and not (cfg.rl and cfg.load_in_4bit)
and not qlora_fsdp
):
        # TODO revalidate this conditional
model.to(f"cuda:{cfg.local_rank}")
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
setattr(model, "is_parallelizable", True)
setattr(model, "model_parallel", True)
requires_grad = []
for name, param in model.named_parameters(recurse=True):
if param.requires_grad:
requires_grad.append(f"{name}: {param.requires_grad}")
if len(requires_grad) == 0:
LOG.warning("there are no parameters that require gradient updates")
if hasattr(model, "config"):
model.config.use_cache = False
if cfg.flash_optimum:
from optimum.bettertransformer import BetterTransformer
model = BetterTransformer.transform(model)
if cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", model.device)
# TODO resume_from_checkpoint handling
return model, lora_config
def load_adapter(model, cfg, adapter, inference=False):
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
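    """Wrap `model` with the PEFT adapter named by `adapter` ("lora", "qlora", or
    "llama-adapter"); returns the wrapped model and its PEFT config."""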
if adapter is None:
return model, None
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
if adapter in ["lora", "qlora"]:
return load_lora(model, cfg, inference=inference)
if adapter == "llama-adapter":
return load_llama_adapter(model, cfg)
raise NotImplementedError(f"{adapter} peft adapter not available")
def load_llama_adapter(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
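    """Attach an AdaptionPrompt ("llama-adapter") PEFT adapter to `model`, or load a pretrained one from `cfg.lora_model_dir`."""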
from peft import AdaptionPromptConfig, get_peft_model
peft_config = AdaptionPromptConfig(
adapter_layers=cfg.peft_adapter.layers, # layers (L)
adapter_len=cfg.peft_adapter.len, # prompt length (K)
task_type="CAUSAL_LM",
)
if cfg.lora_model_dir:
LOG.debug("Loading pretrained PEFT - llama_adapter")
model = PeftModel.from_pretrained(
model,
cfg.lora_model_dir,
torch_dtype=torch.float16,
)
else:
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
return model, peft_config
def find_all_linear_names(model):
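    """Collect the names of all linear-like modules to use as LoRA target modules, excluding the output embedding layer."""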
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
lora_module_names = set()
for name, module in model.named_modules():
if (
isinstance(module, cls)
or "Linear" in module.__class__.__name__
and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
):
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
embedding_modules = get_linear_embedding_layers(model.config.model_type)
output_embedding = embedding_modules[1]
if output_embedding in lora_module_names: # needed for 16-bit
lora_module_names.remove(output_embedding)
return list(lora_module_names)
def setup_quantized_meta_for_peft(model: nn.Module):
"""Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device"""
def temp_to_method(self, *args, **kwargs): # pylint: disable=unused-argument
return self
for param in model.parameters():
if isinstance(param, Params4bit):
param.quant_state._orig_to = ( # pylint: disable=protected-access
param.quant_state.to
)
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
def setup_quantized_peft_meta_for_training(model: nn.Module):
"""Replaces dummy `quant_state.to` method with the original function to allow training to continue"""
for param in model.parameters():
if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"):
param.quant_state.to = (
param.quant_state._orig_to # pylint: disable=protected-access
)
param.quant_state._orig_to = None # pylint: disable=protected-access
def load_lora(model, cfg, inference=False, config_only=False):
# type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
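    """Build a LoraConfig from `cfg`; unless `config_only`, also wrap `model` with it (or load pretrained LoRA weights from `cfg.lora_model_dir`)."""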
from peft import LoraConfig, get_peft_model
lora_target_modules = list(cfg.lora_target_modules or [])
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(linear_names)}")
lora_target_modules = list(set(lora_target_modules + linear_names))
lora_config_kwargs = {}
loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
if loftq_bits:
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
if cfg.peft_use_rslora:
lora_config_kwargs["use_rslora"] = cfg.use_rslora
lora_config = LoraConfig(
r=cfg.lora_r,
lora_alpha=cfg.lora_alpha,
target_modules=lora_target_modules,
layers_to_transform=cfg.peft_layers_to_transform,
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
bias="none",
task_type="CAUSAL_LM",
**lora_config_kwargs,
)
if config_only:
return None, lora_config
rank = int(os.environ.get("LOCAL_RANK", 0))
if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
setup_quantized_meta_for_peft(model)
if cfg.lora_model_dir:
LOG.debug("Loading pretrained PEFT - LoRA")
model_kwargs: Any = {}
if cfg.lora_on_cpu:
model_kwargs["max_memory"] = {"cpu": "256GiB"}
model_kwargs["device_map"] = {"": "cpu"}
model = PeftModel.from_pretrained(
model,
cfg.lora_model_dir,
is_trainable=(not inference),
**model_kwargs,
)
else:
model = get_peft_model(model, lora_config)
if rank == 0:
model.print_trainable_parameters()
elif cfg.fsdp and cfg.adapter == "qlora":
setup_quantized_peft_meta_for_training(model)
return model, lora_config