praeclarumjj3's picture
:zap: add code
9fa3d89
raw
history blame
43.8 kB
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.generation.utils import GenerateNonBeamOutput
from transformers.utils import logging, is_accelerate_available
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import (
LogitsProcessorList,
)
from transformers.generation.streamers import BaseStreamer
from transformers.generation.stopping_criteria import (
StoppingCriteriaList,
)
from transformers.utils import ModelOutput, logging
import os
logger = logging.get_logger(__name__)
import collections
import gc
import itertools
import os
import re
import shutil
import tempfile
from transformers import PreTrainedModel
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.pytorch_utils import id_tensor_storage
from transformers.modeling_utils import (
is_fsdp_enabled, is_local_dist_rank_0,
load_state_dict, set_initialized_submodules,
_load_state_dict_into_model,
_load_state_dict_into_meta_model,
expand_device_map, get_disk_only_shard_files,
get_disk_only_shard_files,
)
if is_accelerate_available():
from accelerate.utils import (
find_tied_parameters,
load_offloaded_weights,
save_offload_index,
set_module_tensor_to_device,
)
from transformers.utils import logging
from dataclasses import dataclass
PARAM_RENAME_WARNING = "A parameter name that contains `{}` will be renamed internally to `{}`. Please use a different name to suppress this warning."
@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
"""
Outputs of decoder-only generation models, when using non-beam methods.
Args:
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
"""
sequences: torch.LongTensor = None
scores: Optional[Tuple[torch.FloatTensor]] = None
logits: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if "gamma" in key and ("vision_tower.vision_tower" not in key and "dav2_model" not in key):
logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
new_key = key.replace("gamma", "weight")
if "beta" in key and "vision_tower.vision_tower" not in key:
logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
new_key = key.replace("beta", "bias")
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
error_msgs = []
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
if is_deepspeed_zero3_enabled():
import deepspeed
# In sharded models, each shard has only part of the full state_dict, so only gather
# parameters that are in the current state_dict.
named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
if len(params_to_gather) > 0:
# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
if torch.distributed.get_rank() == 0:
module._load_from_state_dict(*args)
else:
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
# it's safe to delete it.
del state_dict
return error_msgs
def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
"""
Checks if `model_to_load` supports param buffer assignment (such
as when loading in empty weights) by first checking
if the model explicitly disables it, then by ensuring that the state dict keys
are a subset of the model's parameters.
Note: We fully disable this if we are using `deepspeed`
"""
if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
return False
if is_deepspeed_zero3_enabled():
return False
# Some models explicitly do not support param buffer assignment
if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
logger.debug(
f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
)
return False
# If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
first_key = list(model_to_load.state_dict().keys())[0]
if start_prefix + first_key in state_dict:
return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
# For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`)
return False
class BaseCausalLM(PreTrainedModel):
def __init__(self, config):
super().__init__(config)
def _sample(
self,
input_ids: torch.LongTensor,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
streamer: Optional["BaseStreamer"],
logits_warper: Optional[LogitsProcessorList] = None,
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
r"""
Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
logits_warper (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
to warp the prediction score distribution of the language modeling head applied before multinomial
sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
`generation_config`)
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
[`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
"""
# init values
pad_token_id = generation_config.pad_token_id
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample
if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
raise ValueError(
"`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
f"{logits_warper})."
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
# keep track of which sequences are already finished
batch_size = input_ids.shape[0]
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
if do_sample:
next_token_scores = logits_warper(input_ids, next_token_scores)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_logits:
raw_logits += (next_token_logits,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
probs = nn.functional.softmax(next_token_scores, dim=-1)
# token selection
if do_sample:
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
# finished sentences should have their next token be a padding token
if has_eos_stopping_criteria:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
if streamer is not None:
streamer.end()
if return_dict_in_generate:
return GenerateDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids
@classmethod
def _load_pretrained_model(
cls,
model,
state_dict,
loaded_keys,
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=False,
sharded_metadata=None,
_fast_init=True,
low_cpu_mem_usage=False,
device_map=None,
offload_folder=None,
offload_state_dict=None,
dtype=None,
hf_quantizer=None,
keep_in_fp32_modules=None,
gguf_path=None,
):
is_safetensors = False
is_quantized = hf_quantizer is not None
state_dict_folder = None
state_dict_index = None
if device_map is not None and "disk" in device_map.values():
archive_file = (
resolved_archive_file[0] if isinstance(resolved_archive_file, (list, tuple)) else resolved_archive_file
)
is_safetensors = archive_file.endswith(".safetensors")
if offload_folder is None and not is_safetensors:
raise ValueError(
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
if offload_folder is not None:
os.makedirs(offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True
is_sharded_safetensors = is_safetensors and sharded_metadata is not None
for key, param in model.state_dict().items():
if param.device == torch.device("meta"):
try:
set_module_tensor_to_device(
model, key, "cuda", torch.empty(*param.size(), dtype=dtype)
)
except:
pass
# tie the model weights before retrieving the state_dict
model.tie_weights()
# Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
prefix = model.base_model_prefix
def _fix_key(key):
if "beta" in key and "vision_tower.vision_tower" not in key:
return key.replace("beta", "bias")
if "gamma" in key and ("vision_tower.vision_tower" not in key and "dav2_model" not in key):
return key.replace("gamma", "weight")
return key
original_loaded_keys = loaded_keys
loaded_keys = [_fix_key(key) for key in loaded_keys]
if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
else:
has_prefix_module = False
expects_prefix_module = False
# key re-naming operations are never done on the keys
# that are loaded, but always on the keys of the newly initialized model
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
add_prefix_to_model = has_prefix_module and not expects_prefix_module
if remove_prefix_from_model:
_prefix = f"{prefix}."
expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)]
expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys]
elif add_prefix_to_model:
expected_keys = [".".join([prefix, s]) for s in expected_keys]
missing_keys = sorted(set(expected_keys) - set(loaded_keys))
unexpected_keys = set(loaded_keys) - set(expected_keys)
# Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
# buffers
model_buffers = {n for n, _ in model.named_buffers()}
if remove_prefix_from_model:
model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers}
elif add_prefix_to_model:
model_buffers = {".".join([prefix, key]) for key in model_buffers}
unexpected_keys = sorted(unexpected_keys - model_buffers)
model.tie_weights()
if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
id_tensor = id_tensor_storage(tensor)
ptrs[id_tensor].append(name)
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
else:
# id function doesn't work for meta tensor so we need this function
tied_params = find_tied_parameters(model)
for group in tied_params:
if remove_prefix_from_model:
group = [key[len(_prefix) :] if key.startswith(_prefix) else key for key in group]
elif add_prefix_to_model:
group = [".".join([prefix, key]) for key in group]
missing_in_group = [k for k in missing_keys if k in group]
if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
missing_keys = [k for k in missing_keys if k not in missing_in_group]
# Some models may have keys that are not in the state by design, removing them before needlessly warning
# the user.
if cls._keys_to_ignore_on_load_missing is not None:
for pat in cls._keys_to_ignore_on_load_missing:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if hf_quantizer is not None:
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
# retrieve weights on meta device and put them back on CPU.
# This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step
if low_cpu_mem_usage:
for key in missing_keys:
if key in list(model_state_dict.keys()):
key = key
elif f"{prefix}.{key}" in list(model_state_dict.keys()):
key = f"{prefix}.{key}"
elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()):
key = ".".join(key.split(".")[1:])
param = model_state_dict[key]
# upcast in fp32 if any
target_dtype = dtype
if (
keep_in_fp32_modules is not None
and dtype == torch.float16
and any(
module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
)
):
target_dtype = torch.float32
if param.device == torch.device("meta"):
value = torch.empty(*param.size(), dtype=target_dtype)
if (
not is_quantized
or getattr(hf_quantizer, "requires_parameters_quantization", False)
or not hf_quantizer.check_quantized_param(
model, param_value=value, param_name=key, state_dict={}
)
):
set_module_tensor_to_device(model, key, "cpu", value)
else:
hf_quantizer.create_quantized_param(model, value, key, "cpu", state_dict, unexpected_keys)
# retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init:
if not ignore_mismatched_sizes:
if remove_prefix_from_model:
_loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
elif add_prefix_to_model:
_loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
else:
_loaded_keys = loaded_keys
not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
# If we're about to tie the output embeds to the input embeds we don't need to init them
if hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings:
output_embeddings = model.get_output_embeddings()
if output_embeddings is not None:
# Still need to initialize if there is a bias term since biases are not tied.
if not hasattr(output_embeddings, "bias") or output_embeddings.bias is None:
output_embeddings._is_hf_initialized = True
else:
not_initialized_submodules = dict(model.named_modules())
# This will only initialize submodules that are not marked as initialized by the line above.
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
not_initialized_parameters = list(
set(
itertools.chain.from_iterable(
submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values()
)
)
)
with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
model.apply(model._initialize_weights)
else:
model.apply(model._initialize_weights)
# Set some modules to fp32 if any
if keep_in_fp32_modules is not None:
for name, param in model.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
# param = param.to(torch.float32) does not work here as only in the local scope.
param.data = param.data.to(torch.float32)
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ""
model_to_load = model
if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
start_prefix = cls.base_model_prefix + "."
if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
model_to_load = getattr(model, cls.base_model_prefix)
base_model_expected_keys = list(model_to_load.state_dict().keys())
if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys):
raise ValueError(
"The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
"properly saved?"
)
if device_map is not None:
device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}
def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
# If the checkpoint is sharded, we may not have the key here.
if checkpoint_key not in state_dict:
continue
model_key = checkpoint_key
if remove_prefix_from_model:
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
model_key = f"{prefix}.{checkpoint_key}"
elif add_prefix_to_model:
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key = ".".join(checkpoint_key.split(".")[1:])
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
if (
state_dict[checkpoint_key].shape[-1] == 1
and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
):
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
# Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights.
pass
else:
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
if resolved_archive_file is not None:
folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
else:
folder = None
if device_map is not None and is_safetensors:
param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None:
archive_file = (
resolved_archive_file[0]
if isinstance(resolved_archive_file, (list, tuple))
else resolved_archive_file
)
weight_map = {p: archive_file for p in original_loaded_keys}
else:
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
offload_index = {
p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
for p, f in weight_map.items()
if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
}
else:
offload_index = None
if state_dict is not None:
# Whole checkpoint
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
)
# For GGUF models `state_dict` is never set to None as the state dict is always small
if gguf_path:
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
hf_quantizer=hf_quantizer,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
else:
# Sharded checkpoint or whole but low_cpu_mem_usage==True
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
error_msgs = _load_state_dict_into_model(
model_to_load, state_dict, start_prefix, assign_to_params_buffers
)
else:
# This should always be a list but, just to be sure.
if not isinstance(resolved_archive_file, list):
resolved_archive_file = [resolved_archive_file]
error_msgs = []
mismatched_keys = []
if not is_safetensors:
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
if offload_state_dict:
state_dict_folder = tempfile.mkdtemp()
state_dict_index = {}
else:
state_dict_folder = None
state_dict_index = None
if is_sharded_safetensors:
disk_only_shard_files = get_disk_only_shard_files(
device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
)
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
else:
disk_only_shard_files = []
if len(resolved_archive_file) > 1:
resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
assign_to_params_buffers = None
for shard_file in resolved_archive_file:
# Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
if shard_file in disk_only_shard_files:
continue
state_dict = load_state_dict(shard_file, is_quantized=is_quantized)
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys += _find_mismatched_keys(
state_dict,
model_state_dict,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
)
if low_cpu_mem_usage:
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
for key, param in model_to_load.state_dict().items():
if param.device == torch.device("meta"):
set_module_tensor_to_device(
model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
hf_quantizer=hf_quantizer,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
error_msgs += new_error_msgs
else:
# Sharded checkpoint or whole but low_cpu_mem_usage==True
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
error_msgs += _load_state_dict_into_model(
model_to_load, state_dict, start_prefix, assign_to_params_buffers
)
# force memory release
del state_dict
gc.collect()
if offload_index is not None and len(offload_index) > 0:
if model != model_to_load:
# We need to add the prefix of the base model
prefix = cls.base_model_prefix
if not is_safetensors:
for weight_name in offload_index:
shutil.move(
os.path.join(offload_folder, f"{weight_name}.dat"),
os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
)
offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
if not is_safetensors:
save_offload_index(offload_index, offload_folder)
offload_index = None
if offload_state_dict:
# Load back temporarily offloaded state dict
load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
shutil.rmtree(state_dict_folder)
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
if "size mismatch" in error_msg:
error_msg += (
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
if len(unexpected_keys) > 0:
archs = [] if model.config.architectures is None else model.config.architectures
warner = logger.warning if model.__class__.__name__ in archs else logger.info
warner(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
" with another architecture (e.g. initializing a BertForSequenceClassification model from a"
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
" (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
elif len(mismatched_keys) == 0:
logger.info(
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
" training."
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
" to use it for predictions and inference."
)
return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs